BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
function_gemv.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_GEMV_H_
10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_GEMV_H_
11
13 #include "tree_evaluator.h"
15
16
17 namespace bc {
18 namespace tensors {
19 namespace exprs {
20
21
// Bin_Op specialization for oper::gemv: an expression node representing the
// matrix-vector product  left * right.
//   - lv : left operand; the static_assert below requires lv::tensor_dim == 2 (a matrix).
//   - rv : right operand; the static_assert below requires rv::tensor_dim == 1 (a vector).
//   - Both operands must share the same value_type.
// The node itself is a vector (tensor_dim == 1) with left.rows() elements.
// It inherits the oper::gemv<SystemTag> functor, and eval() writes the product
// into an output tensor through a BLAS-style gemv call.
22 template<class lv, class rv, class SystemTag>
23 struct Bin_Op<oper::gemv<SystemTag>, lv, rv>:
24  Expression_Base<Bin_Op<oper::gemv<SystemTag>, lv, rv>>,
25  oper::gemv<SystemTag> {
26
27  static_assert(std::is_same<
28  typename lv::value_type,
29  typename rv::value_type>::value,
30  "GEMV arguments must have the same value_type");
31
32  static_assert(lv::tensor_dim == 2 && rv::tensor_dim == 1,
33  "Lv must be a Matrix and Rv must be a Vector");
34
35  using value_type = typename lv::value_type;
36  using system_tag = SystemTag;
37
 // The result of matrix * vector is a vector.
38  static constexpr int tensor_dim = 1;
39  static constexpr int tensor_iterator_dim = 1;
40
 // Operands, stored by value: the matrix (left) and the vector (right).
41  lv left;
42  rv right;
43
 // Constructor: stores both operands, then checks that the inner dimensions
 // agree (matrix columns == vector rows) with BC_ASSERT.
45  left(left),
46  right(right)
47  {
48  BC_ASSERT(left.cols() == right.rows(),
49  "gemv requires left.cols() == right.rows()");
50  }
51
 // get_operation(): returns a default-constructed oper::gemv<system_tag>.
53  return oper::gemv<system_tag>();
54  }
55
 // Shape of the result: left.rows() elements along dimension 0. dim(i) is 1
 // for every i != 0, so rows() == size() and cols() == 1.
56  BCINLINE bc::size_t size() const { return left.rows(); }
57  BCINLINE bc::size_t dim(int i) const { return i == 0 ? left.rows() : 1; }
58  BCINLINE bc::size_t rows() const { return dim(0); }
59  BCINLINE bc::size_t cols() const { return dim(1); }
60
 // eval: writes the matrix-vector product into `output` (which must be a
 // vector, see the static_assert) on the given stream.
 //   1. blas_expression_traits::parse_expression resolves both operands into
 //      A (matrix), X (vector), the alpha/beta scalars derived from the
 //      Output_Data<core, Alpha, Beta> template arguments, and whether the
 //      left operand was transposed.
 //   2. A gemv routine is invoked; per the BLAS contract it computes
 //      y := alpha * op(A) * x + beta * y, with y being `out`.
 //   3. post_parse_expression_evaluation is called with the parsed contents
 //      (presumably to release temporaries created in step 1 — confirm in
 //      blas_expression_traits).
61  template<class core, int Alpha, int Beta, class Stream>
62  void eval(Output_Data<core, Alpha, Beta> output, Stream stream) const {
63  static_assert(core::tensor_dim==1, "Gemv out must be a vector");
64
65  using self_t = Bin_Op<oper::gemv<system_tag>, lv, rv>;
66  using traits = blas_expression_traits<self_t>;
67
68 //evaluate the left and right branches (computes only if necessary)
69  auto contents = traits::template parse_expression<Alpha, Beta>(stream, *this);
70  auto A = contents.left;
71  auto X = contents.right;
72  auto alpha = contents.alpha;
73  auto beta = contents.beta;
 // transA: whether the left operand was found to be transposed; A is then
 // presumably the un-transposed underlying matrix (see the m/n note below).
74  bool transA = contents.lv_is_transposed;
75
76  auto& out = output.data();
77
 // Argument order below (after the stream) follows BLAS gemv:
 // trans, m, n, alpha, A, lda, x, incx, beta, y, incy.
 // lda = A.leading_dim(1), incx = X.leading_dim(0), incy = out.leading_dim(0).
78 //gemv uses the [m,n] to refer to dim ignoring op(A)
79 //http://www.netlib.org/lapack/explore-html/d6/d30/group__single__blas__level2_gafc92361b74c6d41c7e5afa0aa5d13ec9.html#gafc92361b74c6d41c7e5afa0aa5d13ec9
81  stream, transA, A.rows(), A.cols(),
82  alpha.data(), A.data(), A.leading_dim(1),
83  X.data(), X.leading_dim(0)/*inc_X*/,
84  beta.data(),
85  out.data()/*Y*/, out.leading_dim(0)/*incy*/);
86
87  traits::post_parse_expression_evaluation(stream, contents);
88  }
89 };
90
91
92 } //ns exprs
93 } //ns tensors
94 } //ns bc
95
96
97 #endif /* BC_EXPRESSION_TEMPLATES_FUNCTION_GEMV_H_ */
SystemTag system_tag
Definition: function_gemv.h:36
Definition: tree_output_data.h:18
#define BCINLINE
Definition: common.h:96
__host__ __device__ bc::size_t size() const
Definition: function_gemv.h:56
Definition: device.h:17
Definition: blas.h:27
static oper::gemv< system_tag > get_operation()
Definition: function_gemv.h:52
typename lv::value_type value_type
Definition: function_gemv.h:35
Definition: blas_expression_template_traits.h:126
Bin_Op(lv left, rv right, oper::gemv< system_tag > op=oper::gemv< system_tag >())
Definition: function_gemv.h:44
Definition: common.h:18
int size_t
Definition: common.h:283
void eval(Output_Data< core, Alpha, Beta > output, Stream stream) const
Definition: function_gemv.h:62
Definition: expression_template_base.h:77
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
#define BC_ASSERT(condition, message)
Definition: common.h:185
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
static constexpr int tensor_dim
Definition: expression_binary.h:30
const Tensor & data() const
Definition: tree_output_data.h:26
Definition: device.h:27
__host__ __device__ bc::size_t cols() const
Definition: function_gemv.h:59
__host__ __device__ bc::size_t dim(int i) const
Definition: function_gemv.h:57
__host__ __device__ bc::size_t rows() const
Definition: function_gemv.h:58
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22