BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
expression_binary.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_EXPRESSION_BINARY_H_
10 #define BC_EXPRESSION_TEMPLATES_EXPRESSION_BINARY_H_
11 
12 #include <type_traits>
14 
15 namespace bc {
16 namespace tensors {
17 namespace exprs {
18 
19 template<class Operation, class Lv, class Rv>
20 struct Bin_Op:
21  Expression_Base<Bin_Op<Operation, Lv, Rv>>,
22  Operation
23 {
24  using system_tag = typename Lv::system_tag;
25  using value_type = std::decay_t<decltype(
26  std::declval<Operation>().operator()(
27  std::declval<typename Lv::value_type>(),
28  std::declval<typename Rv::value_type>()))>;
29 
30  static constexpr int tensor_dim =
31  bc::traits::max(Lv::tensor_dim, Rv::tensor_dim);
32 
33 private:
34 
35  static constexpr bool is_broadcast_expression =
36  Lv::tensor_dim != Rv::tensor_dim &&
37  Lv::tensor_dim != 0 &&
38  Rv::tensor_dim != 0;
39 
40  static constexpr int max_dim = bc::traits::max(
41  Lv::tensor_iterator_dim,
42  Rv::tensor_iterator_dim,
43  Lv::tensor_dim,
44  Rv::tensor_dim);
45 
46  static constexpr int max_iterator = bc::traits::max(
47  Lv::tensor_iterator_dim,
48  Rv::tensor_iterator_dim);
49 
50  static constexpr bool continuous_mem_layout =
51  Lv::tensor_iterator_dim <= 1 &&
52  Rv::tensor_iterator_dim <= 1;
53 
54 public:
55 
56  static constexpr int tensor_iterator_dim =
57  is_broadcast_expression || !continuous_mem_layout ?
58  max_dim :
59  max_iterator;
60 
61  Lv left;
62  Rv right;
63 
64  Operation get_operation() const {
65  return static_cast<const Operation&>(*this);
66  }
67 
68  template<class... Args> BCHOT
69  Bin_Op(Lv lv, Rv rv, const Args&... args):
70  Operation(args...),
71  left(lv),
72  right(rv) {}
73 
74  BCINLINE
75  auto operator [](bc::size_t index) const {
76  return Operation::operator()(left[index], right[index]);
77  }
78 
79  BCINLINE
80  auto operator [](bc::size_t index) {
81  return Operation::operator()(left[index], right[index]);
82  }
83 
84  template<
85  class... Integers,
86  class=std::enable_if_t<
87  (sizeof...(Integers)>=tensor_iterator_dim)>>
88  BCINLINE
89  auto operator ()(Integers... ints) const {
90  return Operation::operator()(left(ints...), right(ints...));
91  }
92 
93  template<
94  class... Integers,
95  class=std::enable_if_t<(
96  sizeof...(Integers)>=tensor_iterator_dim)>>
97  BCINLINE
98  auto operator ()(Integers... ints) {
99  return Operation::operator()(left(ints...), right(ints...));
100  }
101 
102 private:
103 
104  BCINLINE
105  const auto& shape() const {
106  constexpr int max_dim = Lv::tensor_dim >= Rv::tensor_dim;
107  return traits::get<max_dim>(right, left);
108  }
109 
110 public:
111 
112  BCINLINE bc::size_t size() const { return shape().size(); }
113  BCINLINE bc::size_t rows() const { return shape().rows(); }
114  BCINLINE bc::size_t cols() const { return shape().cols(); }
115  BCINLINE bc::size_t dim(int i) const { return shape().dim(i); }
116  BCINLINE auto inner_shape() const { return shape().inner_shape(); }
117 };
118 
119 
120 //forward declare
121 template<class T>
123 
124 namespace detail {
125 
126 using bc::oper::Add;
127 using bc::oper::Sub;
130 using bc::oper::Negation;
132 
133 template<class Op, class Lv, class Rv, class... Args> BCHOT
134 auto mk_bin_op(Lv left, Rv right, Args&&... args) {
135  return Bin_Op<Op,Lv, Rv>(left, right, std::forward<Args>(args)...);
136 }
137 
138 template<class Op, class Lv, class Rv, class=void>
140 {
141  template<class... Args>
142  static auto make(Lv lv, Rv rv, Args&&... args) {
143  return Bin_Op<Op, Lv, Rv>(lv, rv, std::forward<Args>(args)...);
144  }
145 };
146 
148 template<class Lv, class Rv>
150 {
151  template<class... Args>
152  static auto make(Lv lv, Un_Op<Negation, Rv> rv, Args&&... args) {
153  return mk_bin_op<Sub>(lv, rv.array, std::forward<Args>(args)...);
154  }
155 };
156 
158 template<class Lv, class Rv>
160 {
161  template<class... Args>
162  static auto make(Lv lv, Un_Op<Negation, Rv> rv, Args&&... args) {
163  return mk_bin_op<Add>(lv, rv.array, std::forward<Args>(args)...);
164  }
165 };
166 
168 template<class Lv, class Rv>
170 {
171  template<class... Args>
172  static auto make(Un_Op<Negation, Lv> lv, Rv rv, Args&&... args) {
173  return mk_bin_op<Sub>(rv, lv.array, std::forward<Args>(args)...);
174  }
175 };
176 
178 template<class Lv, class Rv>
180 {
181  template<class... Args>
182  static auto make(
185  Args&&... args)
186  {
187  return mk_bin_op<Sub>(rv.array, lv.array, std::forward<Args>(args)...);
188  }
189 };
190 
192 template<class Lv, class Rv>
194 {
195  template<class... Args>
196  static auto make(
197  Un_Op<Negation, Lv> lv, Un_Op<Negation, Rv> rv, Args&&... args)
198  {
199  return detail::mk_bin_op<Sub>(
200  lv, rv.array, std::forward<Args>(args)...);
201  }
202 };
203 
205 template<class Lv, class Rv>
207 {
208  template<class... Args>
209  static auto make(Lv lv, Un_Op<Negation, Rv> rv, Args&&... args) {
210  return mk_bin_op<Sub_Assign>(lv, rv.array, std::forward<Args>(args)...);
211  }
212 };
213 
215 template<class Lv, class Rv>
217 {
218  template<class... Args>
219  static auto make(Lv lv, Un_Op<Negation, Rv> rv, Args&&... args) {
220  return mk_bin_op<Add_Assign>(lv, rv.array, std::forward<Args>(args)...);
221  }
222 };
223 
225 template<class Lv, class Rv>
227  bc::oper::Scalar_Mul, Lv, Rv, std::enable_if_t<
228  expression_traits<Lv>::is_blas_expression::value !=
229  expression_traits<Rv>::is_blas_expression::value>>
230 {
231  static auto make(Lv lv, Rv rv, Scalar_Mul op=Scalar_Mul())
232  {
233  constexpr bool left_scalar = Lv::tensor_dim==0;
234  auto scalar_expr = bc::traits::constexpr_ternary<left_scalar>(
235  [=]() { return std::make_pair(lv, rv); },
236  [=]() { return std::make_pair(rv, lv); }
237  );
238 
239  auto scalar = scalar_expr.first;
240  auto expr = scalar_expr.second;
241 
242  using expr_lv_t = std::decay_t<decltype(expr.left)>;
243  using expr_rv_t = std::decay_t<decltype(expr.right)>;
244  using expr_op_t = std::decay_t<decltype(expr.get_operation())>;
245 
246  constexpr bool expr_left_is_scalar_multiplied =
248 
249  constexpr bool expr_right_is_scalar_multiplied =
251 
252  //TODO add support for when both left and right are scalar multiplied
253  static_assert(
254  !(expr_left_is_scalar_multiplied
255  && expr_right_is_scalar_multiplied),
256  "Cannot apply scalar_multiplication to a blas_expression where"
257  "both the left and right expressions of the blas expression"
258  "are already scalar multiplied");
259 
260  return bc::traits::constexpr_ternary<!expr_left_is_scalar_multiplied>(
261  [=]() {
262  auto newexpr = mk_bin_op<Scalar_Mul>(scalar, expr.left);
263  return mk_bin_op<expr_op_t>(newexpr, expr.right);
264  },
265  [=]() {
266  auto newexpr = mk_bin_op<Scalar_Mul>(scalar, expr.right);
267  return mk_bin_op<expr_op_t>(newexpr, expr.left);
268  });
269  }
270 };
271 
272 }
273 
274 template<class Op, class Lv, class Rv> BCHOT
275 auto make_bin_expr(Lv left, Rv right, Op oper) {
276  return detail::bin_expr_factory<Op, Lv, Rv>::make(left, right, oper);
277 }
278 
279 template<class Op, class Lv, class Rv, class... Args> BCHOT
280 auto make_bin_expr(Lv left, Rv right, Args&&... args)
281 {
283  left, right, std::forward<Args>(args)...);
284 }
285 
286 
287 } //ns exprs
288 } //ns tensors
289 } //ns bc
290 
292 
293 #endif /* BC_EXPRESSION_TEMPLATES_EXPRESSION_BINARY_H_ */
#define BCINLINE
Definition: common.h:96
BCHOT auto mk_bin_op(Lv left, Rv right, Args &&... args)
Definition: expression_binary.h:134
static auto make(Lv lv, Rv rv, Args &&... args)
Definition: expression_binary.h:142
BCINLINE bc::size_t size() const
Definition: expression_binary.h:112
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:209
BCHOT auto make_bin_expr(Lv left, Rv right, Op oper)
Definition: expression_binary.h:275
Definition: unary.h:18
BCINLINE auto operator[](bc::size_t index) const
Definition: expression_binary.h:75
static auto make(Un_Op< Negation, Lv > lv, Rv rv, Args &&... args)
Definition: expression_binary.h:172
Rv right
Definition: expression_binary.h:62
Definition: binary.h:100
Operation get_operation() const
Definition: expression_binary.h:64
static auto make(Un_Op< Negation, Lv > lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:196
BCINLINE auto inner_shape() const
Definition: expression_binary.h:116
Definition: blas_expression_template_traits.h:126
BCHOT Bin_Op(Lv lv, Rv rv, const Args &... args)
Definition: expression_binary.h:69
Lv left
Definition: expression_binary.h:61
BCINLINE bc::size_t cols() const
Definition: expression_binary.h:114
Definition: binary.h:96
ArrayType array
Definition: expression_unary.h:32
static auto make(Un_Op< Negation, Lv > lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:182
Definition: common.h:18
int size_t
Definition: common.h:283
BCINLINE auto operator()(Integers... ints) const
Definition: expression_binary.h:89
Definition: binary.h:60
Definition: expression_binary.h:139
Definition: binary.h:109
std::decay_t< decltype(std::declval< Operation >().operator()(std::declval< typename Lv::value_type >(), std::declval< typename Rv::value_type >()))> value_type
Definition: expression_binary.h:28
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:219
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:152
typename Lv::system_tag system_tag
Definition: expression_binary.h:24
#define BCHOT
Definition: common.h:97
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
static constexpr int tensor_dim
Definition: expression_binary.h:30
BCINLINE bc::size_t rows() const
Definition: expression_binary.h:113
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:162
Definition: expression_template_traits.h:19
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Definition: binary.h:54