BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
tree_evaluator_optimizer.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: Joseph Jaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef PTE_ARRAY_H_
10 #define PTE_ARRAY_H_
11 
12 #include "tree_output_data.h"
13 #include "array.h"
14 #include "expression_binary.h"
15 #include "expression_unary.h"
16 
17 namespace bc {
18 namespace tensors {
19 namespace exprs {
20 
26 
// Primary template of the expression-tree optimizer. It is only declared
// here; the partial specializations further down this header provide the
// definitions, using the defaulted second parameter as a std::enable_if_t
// selector.
27 template<class T, class voider=void>
28 struct optimizer;
29 
// Default optimizer behaviour for a node of type T: nothing is rewritten.
// NOTE(review): listing line 31 (the struct header) is absent from this
// extract; the name `optimizer_default` is inferred from the base-class
// references later in the header -- confirm against the original file.
30 template<class T>
32 {
    // Classification flags that other optimizers query on their operands;
    // all three are off by default.
48  static constexpr bool entirely_blas_expr = false;
49  static constexpr bool partial_blas_expr = false;
50  static constexpr bool requires_greedy_eval = false;
51 
    // Returns the branch unchanged; the output and stream arguments are
    // accepted but not used.
52  template<class OutputData, class Stream>
53  static auto linear_eval(T branch, OutputData, Stream) { return branch; }
54 
    // Returns the branch unchanged; the output and stream arguments are
    // accepted but not used.
55  template<class OutputData, class Stream>
56  static auto injection(T branch, OutputData, Stream) { return branch; }
57 
    // Returns the branch unchanged; no temporary is created.
58  template<class Stream>
59  static auto temporary_injection(T branch, Stream) { return branch; }
60 
    // No-op: the default node has nothing to release.
61  template<class Stream>
62  static void deallocate_temporaries(T, Stream) { return; }
63 };
64 
// Temporary handling shared by the unary-expression optimizers.
// NOTE(review): listing line 66 (the struct header) is absent from this
// extract; the name `unary_optimizer_default` is inferred from the
// base-class references later in the header -- confirm.
65 template<class Op, class Array>
67 {
    // Runs temporary_injection on the operand (branch.array) through
    // optimizer<Array>, then rebuilds the unary expression around the
    // result using the branch's own operation object.
68  template<class Stream>
69  static auto temporary_injection(Un_Op<Op, Array> branch, Stream stream) {
70  auto expr = optimizer<Array>::temporary_injection(branch.array, stream);
71  return make_un_expr(expr, branch.get_operation());
72  }
73 
    // NOTE(review): listing line 76 (the only body line) is absent from this
    // extract; presumably it forwards branch.array to
    // optimizer<Array>::deallocate_temporaries -- confirm.
74  template<class Stream>
75  static void deallocate_temporaries(Un_Op<Op, Array> branch, Stream stream) {
77  }
78 };
79 
// Temporary handling shared by the binary-expression optimizers.
// NOTE(review): listing line 81 (the struct header) is absent from this
// extract; the name `binary_optimizer_default` is inferred from the
// base-class references later in the header -- confirm.
80 template<class op, class lv, class rv>
82 {
    // Runs temporary_injection on each operand through its own optimizer
    // (left first, then right) and rebuilds the binary expression from the
    // two results using the branch's own operation object.
83  template<class Stream>
84  static auto temporary_injection(Bin_Op<op, lv, rv> branch, Stream stream)
85  {
86  auto left = optimizer<lv>::temporary_injection(branch.left, stream);
87  auto right = optimizer<rv>::temporary_injection(branch.right, stream);
88  return make_bin_expr<op>(left, right, branch.get_operation());
89  }
90 
    // NOTE(review): listing lines 94-95 (the body) are absent from this
    // extract; presumably they forward branch.left / branch.right to the
    // operand optimizers' deallocate_temporaries -- confirm.
91  template<class Stream>
92  static void deallocate_temporaries(Bin_Op<op, lv, rv> branch, Stream stream)
93  {
96  }
97 };
98 
99 // -------------------------------- Array -------------------------------- //
// Specialization chosen when expression_traits<T> reports is_array and not
// is_temporary. It adds no members of its own, so everything comes from
// optimizer_default<T>.
100 template<class T>
101 struct optimizer<T, std::enable_if_t<
102  expression_traits<T>::is_array::value &&
103  !expression_traits<T>::is_temporary::value>>:
104  optimizer_default<T> {};
105 
106 // -------------------------------- Temp -------------------------------- //
107 
// Specialization chosen when expression_traits<Array>::is_temporary is true.
// Inherits the defaults from optimizer_default<Array> and replaces only
// deallocate_temporaries.
108 template<class Array>
109 struct optimizer<
110  Array,
111  std::enable_if_t<expression_traits<Array>::is_temporary::value>>:
112  optimizer_default<Array>
113 {
    // Releases the temporary by calling tmp.deallocate with an allocator
    // obtained from the stream, rebound to the array's value_type.
114  template<class Stream>
115  static void deallocate_temporaries(Array tmp, Stream stream)
116  {
117  using value_type = typename Array::value_type;
118  tmp.deallocate(stream.template get_allocator_rebound<value_type>());
119  }
120 };
121 
122 // -------------------------------- Blas -------------------------------- //
// Optimizer members for an expression Xpr that is evaluated eagerly through
// its own eval() member.
// NOTE(review): listing line 124 (the struct header) is absent from this
// extract; the name `optimizer_greedy_evaluations` is inferred from the
// base-class references later in the header -- confirm.
123 template<class Xpr>
125 {
    // All three classification flags are on for this kind of node.
126  static constexpr bool entirely_blas_expr = true;
127  static constexpr bool partial_blas_expr = true;
128  static constexpr bool requires_greedy_eval = true;
129 
130 private:
131 
    // Tag-dispatched helper, std::true_type overload: evaluates the branch
    // into the supplied output and returns tensor.data().
132  template<class OutputData, class Stream>
133  static auto evaluate_impl(Xpr branch, OutputData tensor, Stream stream,
134  std::true_type valid_injection)
135  {
136  branch.eval(tensor, stream);
137  return tensor.data();
138  }
139 
    // std::false_type overload: performs no evaluation and returns the
    // branch as-is.
140  template<class OutputData, class Stream>
141  static auto evaluate_impl(Xpr branch, OutputData tensor, Stream stream,
142  std::false_type valid_injection)
143  {
144  return branch;
145  }
146 
147 public:
148 
    // Evaluates into `tensor` only when Xpr::tensor_dim equals
    // OutputData::tensor_dim; otherwise the branch is returned unchanged.
149  template<class OutputData, class Stream>
150  static auto linear_eval(Xpr branch, OutputData tensor, Stream stream) {
151  return evaluate_impl(branch, tensor, stream,
152  truth_type<Xpr::tensor_dim == OutputData::tensor_dim>());
153  }
154 
    // Same dispatch as linear_eval above (the two bodies are identical).
155  template<class OutputData, class Stream>
156  static auto injection(Xpr branch, OutputData tensor, Stream stream) {
157  return evaluate_impl(branch, tensor, stream,
158  truth_type<Xpr::tensor_dim == OutputData::tensor_dim>());
159  }
160 
    // Builds a kernel array with the branch's shape, using an allocator taken
    // from the stream and tagged with temporary_tag, evaluates the branch
    // into it and returns that array in place of the branch.
161  template<class Stream>
162  static auto temporary_injection(Xpr branch, Stream stream)
163  {
164  using value_type = typename Xpr::value_type;
165 
166  auto temporary = make_kernel_array(
167  branch.get_shape(),
168  stream.template get_allocator_rebound<value_type>(),
169  temporary_tag());
170 
    // NOTE(review): the <1, 0> arguments presumably select alpha = 1 and
    // beta = 0 for the output wrapper -- confirm in tree_output_data.h.
171  branch.eval(make_output_data<1, 0>(temporary), stream);
172  return temporary;
173  }
174 };
175 
176 
// Binary expressions whose expression_traits report
// requires_greedy_evaluation. Combines binary_optimizer_default with
// optimizer_greedy_evaluations for the same Bin_Op type.
// NOTE(review): listing line 186 (the only body line) is absent from this
// extract. Both bases declare temporary_injection, so the missing line
// presumably picks one with a using-declaration -- confirm.
177 template<class op, class lv, class rv>
178 struct optimizer<
179  Bin_Op<op, lv, rv>,
180  std::enable_if_t<
181  expression_traits<Bin_Op<op, lv, rv>>
182  ::requires_greedy_evaluation::value>>:
183  binary_optimizer_default<op, lv, rv>,
184  optimizer_greedy_evaluations<Bin_Op<op, lv, rv>>
185 {
187 };
188 
// Unary expressions whose expression_traits report
// requires_greedy_evaluation. Combines unary_optimizer_default with
// optimizer_greedy_evaluations for the same Un_Op type.
// NOTE(review): listing line 198 (the only body line) is absent from this
// extract. Both bases declare temporary_injection, so the missing line
// presumably picks one with a using-declaration -- confirm.
189 template<class op, class value>
190 struct optimizer<
191  Un_Op<op, value>,
192  std::enable_if_t<
193  expression_traits<Un_Op<op, value>>
194  ::requires_greedy_evaluation::value>>:
195  unary_optimizer_default<op, value>,
196  optimizer_greedy_evaluations<Un_Op<op, value>>
197 {
199 };
200 
201 
202 // ------------------------------ Linear ------------------------------//
203 
// Binary expressions whose operation has
// oper::operation_traits<op>::is_linear_operation set. Temporary handling
// comes from binary_optimizer_default; this specialization supplies the
// classification flags plus linear_eval and injection.
// NOTE(review): this listing is missing original lines 211-212, 216-217,
// 220-221, 225, 230, 258, 261, 280, 299-301 and 313, so several expressions
// below are shown without their beginning. Comments describe only what is
// visible.
204 template<class op, class lv, class rv>
205 struct optimizer<
206  Bin_Op<op, lv, rv>,
207  std::enable_if_t<oper::operation_traits<op>::is_linear_operation>>:
208  binary_optimizer_default<op, lv, rv>
209 {
    // Only the last term of this initializer survives in the listing: it
    // requires both operands to have the same tensor_dim.
210  static constexpr bool entirely_blas_expr =
213  lv::tensor_dim == rv::tensor_dim;
214 
    // Initializer (lines 216-217) not present in this listing.
215  static constexpr bool partial_blas_expr =
218 
    // Initializer (lines 220-221) not present in this listing.
219  static constexpr bool requires_greedy_eval =
222 
    // linear_eval (its declaration line 225 is missing here). Unlike
    // injection below, the visible parameter list takes the branch by
    // non-const reference.
223  template<class OutputData, class Stream>
224  static
226  Bin_Op<op, lv, rv>& branch, OutputData tensor, Stream stream)
227  {
    // Helper that processes the right operand. The callee (line 230) is not
    // visible; the arguments are branch.right, the output wrapped by
    // update_alpha_beta_modifiers<op, update_beta>, and the stream.
228  auto rv_eval = [&](auto update_beta=std::true_type()) {
229  using update_beta_t = std::decay_t<decltype(update_beta)>;
231  branch.right,
232  update_alpha_beta_modifiers<op, update_beta_t::value>(tensor),
233  stream);
234  };
235 
    // The left operand is always processed first, before any branch of the
    // compile-time chain below is taken.
236  auto left = optimizer<lv>::linear_eval(branch.left, tensor, stream);
237 
238  return
    // Case 1: this whole node is a BLAS expression -- process the right
    // operand too and hand back the output's data.
239  constexpr_if<entirely_blas_expr>(
240  [&](){
241  rv_eval(std::true_type());
242  return tensor.data();
243  },
    // Case 2: only the left operand is entirely BLAS -- what remains is the
    // right operand, wrapped in oper::Negation when op is oper::Sub.
244  constexpr_else_if<optimizer<lv>::entirely_blas_expr>(
245  [&]() {
246  auto right = rv_eval(std::true_type());
247  return constexpr_ternary<std::is_same<op, oper::Sub>::value>(
248  [&]() {
249  return make_un_expr<oper::Negation>(right);
250  },
251  [&]() {
252  return right;
253  }
254  );
255  },
    // Case 3: only the right operand is entirely BLAS -- what remains is
    // the left result. (Line 258, inside this lambda, is not visible.)
256  constexpr_else_if<optimizer<rv>::entirely_blas_expr>(
257  [&]() {
259  return left;
260  },
    // Fallback (its opening line 261 is not visible): rebuild the binary
    // expression from both processed operands.
    // NOTE(review): this tests `OutputData::BETA` directly while
    // injection's basic_eval tests `OutputData::BETA != 0` -- confirm the
    // two are meant to be equivalent.
262  [&]() {
263  using left_evaluated = truth_type<
264  (optimizer<lv>::partial_blas_expr || OutputData::BETA)>;
265  auto right = rv_eval(left_evaluated());
266  return make_bin_expr<op>(left, right);
267  }
268  ))));
269  }
270 
    // injection: takes the branch by value.
271  template<class OutputData, class Stream> static
272  auto injection(Bin_Op<op, lv, rv> branch, OutputData tensor, Stream stream)
273  {
    // Deferred linear_eval of the left operand.
274  auto lv_eval = [&]() {
275  return optimizer<lv>::linear_eval(branch.left, tensor, stream);
276  };
277 
    // Same shape as the rv_eval helper in linear_eval; the callee
    // (line 280) is not visible in this listing.
278  auto rv_eval = [&](auto update_beta=std::true_type()) {
279  using update_beta_t = std::decay_t<decltype(update_beta)>;
281  branch.right,
282  update_alpha_beta_modifiers<op, update_beta_t::value>(tensor),
283  stream);
284  };
285 
    // Processes both operands (left first) and rebuilds the binary
    // expression from the results.
286  auto basic_eval = [&]()
287  {
288  using left_evaluated = truth_type<
289  optimizer<lv>::partial_blas_expr || OutputData::BETA != 0>;
290  return make_bin_expr<op>(lv_eval(), rv_eval(left_evaluated()));
291  };
292 
    // Case 1: the whole node is a BLAS expression -- process both operands
    // and hand back the output's data.
293  return constexpr_if<entirely_blas_expr>(
294  [&](){
295  lv_eval();
296  rv_eval(std::true_type());
297  return tensor.data();
298  },
    // Lines 299-301 (the condition guarding this first basic_eval) are not
    // visible in this listing.
302  basic_eval,
    // Left operand needs greedy evaluation: inject into the left side only
    // and keep the right operand as it was.
303  constexpr_else_if<optimizer<lv>::requires_greedy_eval>(
304  [&]() {
305  auto left = optimizer<lv>::injection(branch.left, tensor, stream);
306  return make_bin_expr<op>(left, branch.right);
307  },
    // Right operand needs greedy evaluation: inject into the right side
    // only and keep the left operand as it was.
308  constexpr_else_if<optimizer<rv>::requires_greedy_eval>(
309  [&]() {
310  auto right = optimizer<rv>::injection(branch.right, tensor, stream);
311  return make_bin_expr<op>(branch.left, right);
312  },
    // Final fallback (line 313, presumably its constexpr_else wrapper, is
    // not visible).
314  basic_eval
315  )))));
316  }
317 };
318 
319 // ------------------------------ Non-linear ------------------------------//
320 
// Binary expressions whose operation has is_nonlinear_operation set and
// whose expression_traits do not report requires_greedy_evaluation.
// Temporary handling comes from binary_optimizer_default.
// NOTE(review): original lines 333-334 and 345-346 are missing from this
// listing.
321 template<class op, class lv, class rv>
322 struct optimizer<
323  Bin_Op<op, lv, rv>,
324  std::enable_if_t<
325  oper::operation_traits<op>::is_nonlinear_operation &&
326  !expression_traits<Bin_Op<op, lv, rv>>
327  ::requires_greedy_evaluation::value>>:
328  binary_optimizer_default<op, lv, rv>
329 {
330  static constexpr bool entirely_blas_expr = false;
331  static constexpr bool partial_blas_expr = false;
    // Initializer (lines 333-334) not present in this listing.
332  static constexpr bool requires_greedy_eval =
335 
    // Returns the branch unchanged; the output and stream are not used.
336  template<class OutputData, class Stream> static
337  auto linear_eval(Bin_Op<op, lv, rv> branch, OutputData tensor, Stream) {
338  return branch;
339  }
340 
    // Injects the output into exactly one operand and rebuilds the
    // expression with the branch's own operation object. The compile-time
    // condition that picks the side (lines 345-346) is not visible here:
    // the first lambda injects into the left operand, the second into the
    // right.
341  template<class OutputData, class Stream> static
342  auto injection(Bin_Op<op, lv, rv> branch, OutputData tensor, Stream stream)
343  {
344  return constexpr_ternary<
347  [&]() {
348  auto left = optimizer<lv>::injection(branch.left, tensor, stream);
349  auto right = branch.right;
350  return make_bin_expr<op>(left, right, branch.get_operation());
351  }, [&]() {
352  auto left = branch.left;
353  auto right = optimizer<rv>::injection(branch.right, tensor, stream);
354  return make_bin_expr<op>(left, right, branch.get_operation());
355  });
356  }
357 };
358 
359 // ------------------------------ Un_Op ------------------------------//
// Unary expressions whose expression_traits do not report
// requires_greedy_evaluation (the complementary condition to the greedy
// Un_Op specialization earlier in this header). Temporary handling comes
// from unary_optimizer_default.
360 template<class Op, class Array>
361 struct optimizer<
362  Un_Op<Op, Array>,
363  std::enable_if_t<!expression_traits<Un_Op<Op, Array>>
364  ::requires_greedy_evaluation::value>>:
365  unary_optimizer_default<Op, Array>
366 {
367  static constexpr bool entirely_blas_expr = false;
368  static constexpr bool partial_blas_expr = false;
    // NOTE(review): line 369 is missing from this listing; presumably it
    // declares requires_greedy_eval, as the sibling specializations do --
    // confirm.
370 
    // Returns the branch unchanged; the output and stream are not used.
371  template<class OutputData, class Stream> static
372  auto linear_eval(Un_Op<Op, Array> branch, OutputData tensor, Stream) {
373  return branch;
374  }
375 
    // Applies injection to the operand (branch.array) through
    // optimizer<Array>, then rebuilds the unary expression around the result
    // with the branch's own operation object.
376  template<class OutputData, class Stream> static
377  auto injection(Un_Op<Op, Array> branch, OutputData tensor, Stream stream)
378  {
379  auto array = optimizer<Array>::injection(branch.array, tensor, stream);
380  return make_un_expr(array, branch.get_operation());
381  }
382 };
383 
384 } //ns exprs
385 } //ns tensors
386 } //ns bc
387 
388 
389 #endif
Definition: tree_evaluator_optimizer.h:124
auto make_kernel_array(Shape< N > shape, Allocator allocator, Tags...)
Definition: array_kernel_array.h:134
static void deallocate_temporaries(Bin_Op< op, lv, rv > branch, Stream stream)
Definition: tree_evaluator_optimizer.h:92
auto constexpr_ternary(f1 true_path, f2 false_path)
C++ 11/14 version of constexpr if.
Definition: constexpr_if.h:36
static auto temporary_injection(Bin_Op< op, lv, rv > branch, Stream stream)
Definition: tree_evaluator_optimizer.h:84
static constexpr bool requires_greedy_eval
Definition: tree_evaluator_optimizer.h:50
BCHOT auto make_un_expr(Expression expression, Operation operation=Operation())
Definition: expression_unary.h:73
BCINLINE const Operation & get_operation() const
Definition: expression_unary.h:35
static auto linear_eval(Un_Op< Op, Array > branch, OutputData tensor, Stream)
Definition: tree_evaluator_optimizer.h:372
auto constexpr_else(Function function)
Definition: constexpr_if.h:61
static auto injection(Bin_Op< op, lv, rv > branch, OutputData tensor, Stream stream)
Definition: tree_evaluator_optimizer.h:272
Rv right
Definition: expression_binary.h:62
Operation get_operation() const
Definition: expression_binary.h:64
Definition: expression_template_traits.h:48
auto constexpr_else_if(Function function)
Definition: constexpr_if.h:51
void deallocate()
Definition: array.h:161
static auto linear_eval(Bin_Op< op, lv, rv > &branch, OutputData tensor, Stream stream)
Definition: tree_evaluator_optimizer.h:225
static constexpr bool partial_blas_expr
Definition: tree_evaluator_optimizer.h:49
Lv left
Definition: expression_binary.h:61
auto constexpr_if(Function function)
Definition: constexpr_if.h:41
Definition: tree_evaluator_optimizer.h:31
static auto injection(T branch, OutputData, Stream)
Definition: tree_evaluator_optimizer.h:56
Definition: array.h:25
ArrayType array
Definition: expression_unary.h:32
Definition: common.h:18
Definition: tree_evaluator_optimizer.h:66
static void deallocate_temporaries(Array tmp, Stream stream)
Definition: tree_evaluator_optimizer.h:115
static auto temporary_injection(Un_Op< Op, Array > branch, Stream stream)
Definition: tree_evaluator_optimizer.h:69
static constexpr bool entirely_blas_expr
entirely_blas_expr if we may replace this branch entirely with a temporary/cache expression is +/- op...
Definition: tree_evaluator_optimizer.h:48
static auto linear_eval(Xpr branch, OutputData tensor, Stream stream)
Definition: tree_evaluator_optimizer.h:150
Definition: tree_evaluator_optimizer.h:28
static auto linear_eval(T branch, OutputData, Stream)
Definition: tree_evaluator_optimizer.h:53
conditional_t< Bool, true_type, false_type > truth_type
Definition: type_traits.h:49
static void deallocate_temporaries(Un_Op< Op, Array > branch, Stream stream)
Definition: tree_evaluator_optimizer.h:75
Scalar value_type
Definition: array.h:34
static auto temporary_injection(T branch, Stream)
Definition: tree_evaluator_optimizer.h:59
static auto injection(Un_Op< Op, Array > branch, OutputData tensor, Stream stream)
Definition: tree_evaluator_optimizer.h:377
Definition: tree_evaluator_optimizer.h:81
Definition: device.h:27
static auto injection(Xpr branch, OutputData tensor, Stream stream)
Definition: tree_evaluator_optimizer.h:156
static void deallocate_temporaries(T, Stream)
Definition: tree_evaluator_optimizer.h:62
static auto temporary_injection(Xpr branch, Stream stream)
Definition: tree_evaluator_optimizer.h:162
Definition: expression_template_traits.h:19
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22