BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
function_ger.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_GER_H_
10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_GER_H_
11 
13 #include "tree_evaluator.h"
14 #include "array_scalar_constant.h"
16 
17 namespace bc {
18 namespace tensors {
19 namespace exprs {
20 
// Specialization of Bin_Op for the oper::ger<SystemTag> operation.
// The node stores two operands (left, right), reports itself as a
// 2-dimensional expression, and in eval() forwards the work to
// bc::blas::BLAS<system_tag>::ger (presumably the BLAS rank-1 update /
// outer product -- confirm against the BLAS wrapper).
//
// Template parameters:
//   lv        - type of the left operand expression
//   rv        - type of the right operand expression
//   SystemTag - tag selecting the BLAS backend (exposed as system_tag)
//
// NOTE(review): this listing is an extract with a few source lines missing
// (listing lines 34, 46, 50 and 76-77); see the notes at each gap.
21 template<class lv, class rv, class SystemTag>
22 struct Bin_Op<oper::ger<SystemTag>, lv, rv>:
23  Expression_Base<Bin_Op<oper::ger<SystemTag>, lv, rv>>,
24  oper::ger<SystemTag> {
25 
   // Both operands must share one element type.
26  static_assert(
27  std::is_same<
28  typename lv::value_type,
29  typename rv::value_type>::value,
30  "GER arguments must have the same value_type");
31 
   // Both operands must be 1-dimensional.
   // NOTE(review): listing line 34 is missing here, so the final operand of
   // this && chain is not visible; as shown, the chain ends in the message
   // string itself. Restore the missing condition from the original header.
32  static_assert(lv::tensor_dim == 1 &&
33  rv::tensor_dim == 1 &&
35  "GER DIMENSION MISMATCH, INTERNAL BUG, REPORT PLEASE");
36 
37  using value_type = typename lv::value_type;
38  using system_tag = SystemTag;
39 
   // The result of this expression is declared as 2-dimensional.
40  static constexpr int tensor_dim = 2;
41  static constexpr int tensor_iterator_dim = 1;
42 
   // Operands are held by value.
43  lv left;
44  rv right;
45 
   // NOTE(review): the constructor signature (listing line 46) is missing from
   // this extract. The symbol index at the end of the file gives it as
   //   Bin_Op(lv left, rv right, oper::ger<system_tag> op = oper::ger<system_tag>())
   // The two lines below are its member-initializer list and empty body.
47  left(left),
48  right(right) {}
49 
   // NOTE(review): the signature on listing line 50 is missing. The symbol
   // index gives it as
   //   static oper::ger<system_tag> get_operation()
   // It returns a default-constructed operation object.
51  return oper::ger<system_tag>();
52  }
53 
   // Element count: product of the two operand sizes.
54  BCINLINE bc::size_t size() const { return left.size() * right.size(); }
   // Extent of dimension i: left.rows() for 0, right.cols() for 1, otherwise 1.
55  BCINLINE bc::size_t dim(int i) const { return i == 0 ? left.rows() : i == 1 ? right.cols() : 1; }
56  BCINLINE bc::size_t rows() const { return dim(0); }
57  BCINLINE bc::size_t cols() const { return dim(1); }
58 
59 
   // Evaluates this expression into `output` using `stream`.
   //   core   - type of the output tensor; must have tensor_dim == 2
   //   Alpha  - compile-time scalar modifier applied to the product
   //   Beta   - compile-time scalar modifier for the existing output
   //   Stream - stream type handed to evaluate() and to the BLAS call
60  template<class core, int Alpha, int Beta, class Stream>
61  void eval(Output_Data<core, Alpha, Beta> output, Stream stream) const {
62  static_assert(core::tensor_dim==2, "Ger out must be a matrix");
63 
64  using self_t = Bin_Op<oper::ger<system_tag>, lv, rv>;
65  using traits = blas_expression_traits<self_t>;
66 
67  auto& out = output.data();
68 
69  //if we need to negate or zero the output
70  //If Beta != 1 consider using gemm (to enable zeroing/modifying the output)
   // NOTE(review): this builds an Assign expression of `out` with a scalar
   // constant equal to Beta, which appears to fill `out` with Beta. That
   // zeroes the output for Beta == 0, but for Beta == -1 it would presumably
   // store -1 rather than negate the existing values -- TODO confirm.
71  if (Beta != 1) {
72  auto expr = make_bin_expr<oper::Assign>(out, make_scalar_constant<value_type>(Beta));
73  evaluate(expr, stream);
74  }
75 
   // NOTE(review): listing lines 76-77 are missing. They hold the `if (...) {`
   // that opens the branch below and pairs with the `} else {` further down;
   // the condition cannot be recovered from this extract.
78 
   // Branch 1: let the traits parse the expression into its operands and an
   // alpha scalar, call ger, then let the traits run their post-evaluation
   // step on the parsed contents.
79  auto contents = traits::template parse_expression<Alpha, Beta>(stream, *this);
80  auto A = contents.left;
81  auto B = contents.right;
82  auto alpha = contents.alpha;
83  bc::blas::BLAS<system_tag>::ger(stream, left.rows(), right.cols(),
84  alpha.data(), A.data(), A.leading_dim(0),
85  B.data(), B.leading_dim(0),
86  out.data(), out.leading_dim(1));
87  traits::post_parse_expression_evaluation(stream, contents);
88  } else {
   // Branch 2: build alpha as a host-tagged constexpr scalar (an Alpha of 0
   // is replaced by 1), greedily evaluate both operands after removing
   // their BLAS modifiers, switch the stream to host pointer mode, and
   // call ger with the same argument layout as branch 1.
89  auto alpha = make_constexpr_scalar<bc::host_tag, (Alpha == 0 ? 1 : Alpha), value_type>();
90  auto A = greedy_evaluate(blas_expression_traits<lv>::remove_blas_modifiers(left), stream);
91  auto B = greedy_evaluate(blas_expression_traits<rv>::remove_blas_modifiers(right), stream);
92  stream.set_blas_pointer_mode_host();
93  bc::blas::BLAS<system_tag>::ger(stream, left.rows(), right.cols(),
94  alpha.data(), A.data(), A.leading_dim(0),
95  B.data(), B.leading_dim(0),
96  out.data(), out.leading_dim(1));
97  }
98  }
99 };
100 
101 
102 } //ns exprs
103 } //ns tensors
104 } //ns bc
105 
106 
107 #endif /* BC_EXPRESSION_TEMPLATES_FUNCTION_GER_H_ */
Bin_Op(lv left, rv right, oper::ger< system_tag > op=oper::ger< system_tag >())
Definition: function_ger.h:46
SystemTag system_tag
Definition: function_ger.h:38
Definition: tree_output_data.h:18
#define BCINLINE
Definition: common.h:96
Definition: device.h:17
typename lv::value_type value_type
Definition: function_ger.h:37
void eval(Output_Data< core, Alpha, Beta > output, Stream stream) const
Definition: function_ger.h:61
__host__ __device__ bc::size_t rows() const
Definition: function_ger.h:56
Definition: blas_expression_template_traits.h:126
__host__ __device__ bc::size_t dim(int i) const
Definition: function_ger.h:55
Definition: blas.h:32
static oper::ger< system_tag > get_operation()
Definition: function_ger.h:50
Definition: common.h:18
int size_t
Definition: common.h:283
__host__ __device__ bc::size_t cols() const
Definition: function_ger.h:57
__host__ __device__ bc::size_t size() const
Definition: function_ger.h:54
Definition: expression_template_base.h:77
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
static constexpr int tensor_dim
Definition: expression_binary.h:30
const Tensor & data() const
Definition: tree_output_data.h:26
Definition: device.h:27
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22