BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
function_gemm.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_GEMM_H_
10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_GEMM_H_
11
13 #include "tree_evaluator.h"
15
16 namespace bc {
17 namespace tensors {
18 namespace exprs {
19
20
// NOTE(review): this listing is a Doxygen rendering and source lines 12, 14,
// 44, 52 and 87 are missing from the extracted text. According to the
// Doxygen index on this page, line 44 completes the constructor parameter
// list (`oper::gemm<system_tag> op = oper::gemm<system_tag>()`) and line 52
// opens `static oper::gemm<SystemTag> get_operation()`. Line 87 opens the
// call whose arguments appear on lines 88-91; it is presumably a BLAS gemm
// wrapper (the index references blas.h) -- confirm against the original file.

/**
 * Expression-template node for the matrix-matrix product `left * right`
 * (the specialization of Bin_Op for oper::gemm<SystemTag>).
 *
 * Construction only stores the two operands and asserts that their inner
 * dimensions agree. The product is written to a destination only when
 * eval() is called.
 *
 * @tparam lv        left operand expression; must have tensor_dim == 2
 * @tparam rv        right operand expression; must have tensor_dim == 2 and
 *                   the same value_type as lv
 * @tparam SystemTag tag forwarded to oper::gemm and exposed as system_tag
 *                   (presumably selects the host/device backend -- confirm)
 */
21 template<class lv, class rv, class SystemTag>
22 struct Bin_Op<oper::gemm<SystemTag>, lv, rv>:
23  Expression_Base<Bin_Op<oper::gemm<SystemTag>, lv, rv>>,
24  oper::gemm<SystemTag> {
25

 // Compile-time preconditions: both operands share a value_type and are matrices.
26  static_assert(std::is_same<
27  typename lv::value_type,
28  typename rv::value_type>::value,
29  "GEMM arguments must have the same value_type");
30
31  static_assert(lv::tensor_dim==2 && rv::tensor_dim==2,
32  "Error: GEMM Expression initialized with non matrix tensor");
33
34  using value_type = typename lv::value_type;
35  using system_tag = SystemTag;
36

 // Rank of the result. rv::tensor_dim is forced to 2 by the static_assert
 // above, so the result is always a matrix.
37  static constexpr int tensor_dim = rv::tensor_dim;
 // NOTE(review): the meaning of tensor_iterator_dim is defined by
 // Expression_Base / the evaluator, which are not visible here -- confirm.
38  static constexpr int tensor_iterator_dim = 1;
39

 // Operands, held by value.
40  lv left;
41  rv right;
42

 // Stores the operands and asserts left.cols() == right.rows().
 // The trailing gemm-tag parameter (source line 44, missing from this
 // listing) defaults to a default-constructed oper::gemm<system_tag>.
 // Note: the parameters shadow the members, so `left`/`right` inside the
 // body refer to the parameters. They hold the same values as the
 // just-initialized members.
43  BCHOT Bin_Op(lv left, rv right,
45  left(left),
46  right(right)
47  {
48  BC_ASSERT(left.cols() == right.rows(),
49  "gemm requires left.cols() == right.rows()");
50  }
51

 // Body of get_operation() (its signature is source line 52, missing from
 // this listing): returns a default-constructed gemm operation tag.
53  return oper::gemm<SystemTag>();
54  }
55

 // Number of elements in the result: left.rows() * right.cols().
56  BCINLINE bc::size_t size() const { return left.rows() * right.cols(); }
 // Result extent along dimension i: left.rows() for 0, right.cols() for 1,
 // and 1 for any other index.
57  BCINLINE bc::size_t dim(int i) const {
58  return i == 0 ? left.rows() : i == 1 ? right.cols() : 1;
59  }
60

61  BCINLINE bc::size_t rows() const { return dim(0); }
62  BCINLINE bc::size_t cols() const { return dim(1); }
63

 /**
  * Evaluates left * right into `output`.
  *
  * Steps:
  *  1. Asserts that the destination is a matrix of shape
  *     left.rows() x right.cols().
  *  2. Decomposes this expression with
  *     blas_expression_traits<self_t>::parse_expression<Alpha, Beta>.
  *     This yields operand matrices A and B, scalars alpha and beta, and
  *     the transposition flags.
  *  3. Issues a gemm-style call with M = out.rows(), N = out.cols(),
  *     K = left.cols(). Leading dimensions come from leading_dim(1) of A,
  *     B and out, and out.data() is the result buffer.
  *  4. Passes the parsed contents to post_parse_expression_evaluation
  *     (presumably to release temporaries created while parsing -- confirm
  *     in blas_expression_template_traits.h).
  *
  * @tparam Core   destination tensor type; must have tensor_dim == 2
  * @tparam Alpha  compile-time value forwarded to parse_expression
  *                (presumably the result scaling factor -- confirm)
  * @tparam Beta   compile-time value forwarded to parse_expression
  *                (presumably the accumulation factor -- confirm)
  * @param output  destination wrapper; output.data() supplies the result matrix
  * @param stream  forwarded to parsing, the gemm call and post-parse cleanup
  */
64  template<class Core, int Alpha, int Beta, class Stream>
65  void eval(Output_Data<Core, Alpha, Beta> output, Stream stream) const
66  {
67  auto& out = output.data();
68

69  static_assert(Core::tensor_dim == 2,
70  "Gemm out must be a matrix");
71  BC_ASSERT(out.rows() == left.rows(),
72  "Output dim (rows) mismatch for GEMM");
73  BC_ASSERT(out.cols() == right.cols(),
74  "Output dim (cols) mismatch for GEMM");
75

76  using self_t = Bin_Op<oper::gemm<SystemTag>, lv, rv>;
77  using traits = blas_expression_traits<self_t>;
78

 // Split the expression into BLAS-ready pieces: operands, scalars, and
 // whether each operand should be read transposed.
79  auto contents = traits::template parse_expression<Alpha, Beta>(stream, *this);
80  auto A = contents.left;
81  auto B = contents.right;
82  auto alpha = contents.alpha;
83  auto beta = contents.beta;
84  auto transA = contents.lv_is_transposed;
85  auto transB = contents.rv_is_transposed;
86

 // NOTE(review): source line 87, which names the function receiving the
 // arguments below, is missing from this listing (likely a BLAS gemm
 // wrapper, cf. blas.h). Argument order follows the classic
 // (trans, trans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) layout.
88  stream, transA, transB, out.rows(), out.cols(), left.cols(),
89  alpha.data(), A.data(), A.leading_dim(1),
90  B.data(), B.leading_dim(1),
91  beta.data(), out.data(), out.leading_dim(1));
92

 // NOTE(review): the 'template' keyword here is not followed by a template
 // argument list. [temp.names] requires one when naming a function
 // template, although some compilers accept this spelling. If the template
 // arguments are deducible, dropping 'template' is the portable form.
93  traits::template post_parse_expression_evaluation(stream, contents);
94  }
95 };
96
97
98 } //ns exprs
99 } //ns tensors
100 } //ns bc
101
102
103 #endif /* BC_EXPRESSION_TEMPLATES_FUNCTION_GEMM_H_ */
Definition: tree_output_data.h:18
#define BCINLINE
Definition: common.h:96
Definition: device.h:17
__host__ __device__ bc::size_t cols() const
Definition: function_gemm.h:62
__host__ __device__ bc::size_t dim(int i) const
Definition: function_gemm.h:57
Bin_Op(lv left, rv right, oper::gemm< system_tag > op=oper::gemm< system_tag >())
Definition: function_gemm.h:43
Definition: blas_expression_template_traits.h:126
Definition: blas.h:22
Definition: common.h:18
int size_t
Definition: common.h:283
static oper::gemm< SystemTag > get_operation()
Definition: function_gemm.h:52
void eval(Output_Data< Core, Alpha, Beta > output, Stream stream) const
Definition: function_gemm.h:65
Definition: expression_template_base.h:77
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
#define BC_ASSERT(condition, message)
Definition: common.h:185
#define BCHOT
Definition: common.h:97
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
typename lv::value_type value_type
Definition: function_gemm.h:34
static constexpr int tensor_dim
Definition: expression_binary.h:30
const Tensor & data() const
Definition: tree_output_data.h:26
Definition: device.h:27
__host__ __device__ bc::size_t rows() const
Definition: function_gemm.h:61
__host__ __device__ bc::size_t size() const
Definition: function_gemm.h:56
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
SystemTag system_tag
Definition: function_gemm.h:35