BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
function_dot.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_DOT_H_
10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_DOT_H_
11 
13 #include "tree_evaluator.h"
15 
16 namespace bc {
17 namespace tensors {
18 namespace exprs {
19 
20 
// Expression-template node for the dot product: the partial specialization of
// Bin_Op chosen when the operation type is oper::dot<SystemTag>.
// The node derives from Shape<0> and declares tensor_dim == 0, i.e. its result
// is a scalar (eval() below requires a 0-dimensional output).
//
// NOTE(review): this is a numbered listing extract. Listing lines 46, 48, 52,
// 60, 67, 74 and 75 are absent, so a few declarations used below (constructor
// signature, get_operation signature, the callee of the dot call, lv_scalar,
// rv_scalar, blas_tools) are not visible here.
21 template<class lv, class rv, class SystemTag>
22 struct Bin_Op<oper::dot<SystemTag>, lv, rv>:
23  Expression_Base<Bin_Op<oper::dot<SystemTag>, lv, rv>>,
24  Shape<0>,
25  oper::dot<SystemTag> {
26 
    // Both operands must have the same value_type.
27  static_assert(std::is_same<
28  typename lv::value_type,
29  typename rv::value_type>::value,
30  "ValueType must be the same");
31 
    // The left operand must be 1-dimensional; the right operand may be
    // 1-dimensional or 0-dimensional.
32  static_assert(
33  lv::tensor_dim == 1 &&
34  (rv::tensor_dim == 1 || rv::tensor_dim ==0),
35  "DOT DIMENSION MISMATCH, INTERNAL BUG, REPORT PLEASE");
36 
37  using value_type = typename lv::value_type;
38  using system_tag = SystemTag;
39 
    // The result of this node is 0-dimensional.
40  static constexpr int tensor_dim = 0;
41  static constexpr int tensor_iterator_dim = 0;
42 
    // The two operand expressions, stored by value.
43  lv left;
44  rv right;
45 
47 
    // Constructor member-initializer list: stores both operands.
    // NOTE(review): the signature line (listing line 48) is missing here; the
    // member index at the end of this page gives it as
    // Bin_Op(lv left, rv right, oper::dot<system_tag> op = oper::dot<system_tag>()).
49  left(left),
50  right(right) {}
51 
    // Body of get_operation() (signature on the missing listing line 52):
    // returns a default-constructed oper::dot<system_tag>.
53  return oper::dot<system_tag>();
54  }
55 
    // Evaluates the dot product of `left` and `right` into output.data().
    //   Core   - output tensor type; must have tensor_dim == 0.
    //   Alpha  - compile-time integer of Output_Data; not referenced in the
    //            visible lines.
    //   Beta   - compile-time integer of Output_Data; used to derive beta_value.
    //   Stream - passed through unchanged to every call made below.
56  template<class Core, int Alpha, int Beta, class Stream>
57  void eval(Output_Data<Core, Alpha, Beta> output, Stream stream) const {
58  static_assert(Core::tensor_dim == 0,"Output must be a scalar");
59 
61 
    // Evaluate both operands first; the results are then queried for
    // rows(), data() and leading_dim(0).
    // presumably this materialises any nested expression - TODO confirm
62  auto X = greedy_evaluate(left, stream);
63  auto Y = greedy_evaluate(right, stream);
64  auto& out = output.data();
65 
    // NOTE(review): the callee of the call below is on listing line 67, which
    // is missing from this extract. The previous comment said "outer product",
    // but this is the dot specialization writing to a scalar output, so this is
    // presumably the backend dot-product routine - confirm against the source.
66  //call the dot product routine (see note above)
68  stream,
69  X.rows(), out.data(),
70  X.data(), X.leading_dim(0),
71  Y.data(), Y.leading_dim(0));
72 
    // A Beta of 0 is replaced by 1, so beta_value is never 0.
73  constexpr int beta_value = Beta == 0 ? 1 : Beta;
76 
    // lv_scalar / rv_scalar are declared on the missing listing lines 74-75;
    // presumably they flag an operand that carries a scalar multiplier - TODO confirm.
    // blas_tools is likewise not declared in the visible lines.
77  if (lv_scalar || rv_scalar) {
    // Fetch the scalar of each operand and pass both, together with
    // beta_value, to scalar_multiply on the output.
    // presumably scales the result by all three factors - TODO confirm
78  auto alpha_lv = blas_expression_traits<lv>::get_scalar(left);
79  auto alpha_rv = blas_expression_traits<rv>::get_scalar(right);
80  blas_tools::scalar_multiply(stream, out.data(), beta_value, alpha_lv, alpha_rv);
81  } else if (beta_value != 1) {
    // No operand scalars: only beta_value is applied, and only when it is not 1.
82  blas_tools::scalar_multiply(stream, out.data(), out.data(), beta_value);
83  }
84  }
85 };
86 
87 
88 } //ns exprs
89 } //ns tensors
90 } //ns bc
91 
92 
93 
94 #endif /* BC_EXPRESSION_TEMPLATES_FUNCTION_DOT_H_ */
void eval(Output_Data< Core, Alpha, Beta > output, Stream stream) const
Definition: function_dot.h:57
Definition: tree_output_data.h:18
Definition: shape.h:17
Definition: device.h:17
size_t value_type
Definition: shape.h:120
Definition: blas_expression_template_traits.h:126
Bin_Op(lv left, rv right, oper::dot< system_tag > op=oper::dot< system_tag >())
Definition: function_dot.h:48
static auto get_scalar(const T &expression) -> decltype(detail::remove_scalar_mul< T >::get_scalar(expression))
Definition: blas_expression_template_traits.h:160
SystemTag system_tag
Definition: function_dot.h:38
Definition: common.h:18
Definition: expression_template_base.h:77
static oper::dot< system_tag > get_operation()
Definition: function_dot.h:52
Definition: blas.h:37
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
static constexpr int tensor_dim
Definition: expression_binary.h:30
const Tensor & data() const
Definition: tree_output_data.h:26
Definition: device.h:27
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22