BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
blas_expression_template_traits.h
Go to the documentation of this file.
1 /*
2  * blas_expression_template_traits.h
3  *
4  * Created on: Oct 10, 2019
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCATTENSORS_TENSORS_EXPRS_BLAS_EXPRESSION_TRAITS_H_
9 #define BLACKCATTENSORS_TENSORS_EXPRS_BLAS_EXPRESSION_TRAITS_H_
10 
12 #include "array.h"
13 #include "array_kernel_array.h"
14 #include "array_scalar_constant.h"
15 #include "tree_evaluator.h"
17 
18 namespace bc {
19 namespace tensors {
20 namespace exprs {
21 namespace detail {
22 
// Primary template of detail::remove_scalar_mul: the identity case, used when T
// is not a scalar multiplication, so nothing is stripped.
// NOTE(review): this listing omits the struct head (source line 24) and the
// `scalar_type` alias (source line 27). Per the generated tooltip, that alias is
// `bc::traits::conditional_detected_t<query_value_type, T, T>*` — presumably a
// pointer to T::value_type when T defines one, otherwise T*.
23 template<class T>
25  using type = T;
28 
// rm: returns the expression unchanged; there is no scalar factor to remove.
29  static T rm(T expression) {
30  return expression;
31  }
32 
// get_scalar: T carries no scalar factor, so a null pointer of scalar_type is
// returned. The trailing `;` after this body is a redundant (harmless) empty
// declaration.
33  static scalar_type get_scalar(const T& expression) {
34  return nullptr;
35  };
36 };
37 
// detail::remove_scalar_mul specialization for a scalar multiplication
// Bin_Op<oper::Scalar_Mul, lv, rv>. The operand with tensor_dim == 0 is taken as
// the scalar:
//   - if lv is rank 0, `type` (the operand that remains) is rv and `scalar_type`
//     is lv;
//   - otherwise `type` is lv and `scalar_type` is rv.
// NOTE(review): when lv::tensor_dim != 0, rv is assumed to be the rank-0
// operand; nothing in this block checks that.
// NOTE(review): unlike the primary template, scalar_type here is the operand's
// expression type rather than a pointer, so get_scalar's return type differs
// between the scalar-multiplied and plain cases.
38 template<class lv, class rv>
39 struct remove_scalar_mul<Bin_Op<oper::Scalar_Mul, lv, rv>> {
40  using type = std::conditional_t<lv::tensor_dim == 0, rv, lv>;
41  using scalar_type = std::conditional_t<lv::tensor_dim == 0, lv ,rv>;
42 
// rm: returns the non-scalar operand.
// This listing omits its signature (source line 43); per the generated tooltip
// it is `static type rm(Bin_Op<oper::Scalar_Mul, lv, rv> expression)`.
// constexpr_ternary<cond> presumably invokes the first lambda when cond is true.
// That reading matches the `type` alias above — confirm against bc::traits.
44  return bc::traits::constexpr_ternary<lv::tensor_dim==0>(
45  [&]() { return expression.right; },
46  [&]() { return expression.left; }
47  );
48  }
49 
// get_scalar: returns the scalar operand (the counterpart of rm).
// This listing omits its signature (source lines 50-51); per the generated
// tooltip it is `static scalar_type get_scalar(Bin_Op<oper::Scalar_Mul, lv, rv>
// expression)`.
52  return bc::traits::constexpr_ternary<lv::tensor_dim==0>(
53  [&]() { return expression.left; },
54  [&]() { return expression.right; }
55  );
56  }
57 
58 };
59 
// Primary template, presumably detail::remove_transpose; its struct head (source
// line 61) is missing from this listing.
// Identity case: T is not a transpose, so `type` is T and rm returns the
// expression unchanged.
60 template<class T>
62  using type = T;
63  static T rm(T expression) {
64  return expression;
65  }
66 };
// detail::remove_transpose specialization for a bare transpose
// Un_Op<oper::transpose<SystemTag>, Array>.
// `type` is the un-transposed operand Array, and rm unwraps the transpose by
// returning the wrapped operand `expression.array`.
67 template<class Array, class SystemTag>
68 struct remove_transpose<Un_Op<oper::transpose<SystemTag>, Array>> {
69  using type = Array;
70 
71  static type rm(
72  Un_Op<oper::transpose<SystemTag>, Array> expression) {
73  return expression.array;
74  }
75 };
76 
// detail::remove_transpose specialization for
// Bin_Op<oper::Scalar_Mul, Un_Op<oper::transpose<SystemTag>, Array>, Rv>,
// i.e. a transposed LEFT operand multiplied by Rv.
// NOTE(review): this listing omits source line 78 (presumably
// `struct remove_transpose<`) and lines 89-90 of rm's parameter list. Per the
// generated tooltip, the full signature is
// `static type rm(Bin_Op<oper::Scalar_Mul, Un_Op<oper::transpose<SystemTag>,
// Array>, Rv> expression)`.
// `type` is Array and rm returns `expression.left.array`, so BOTH the transpose
// and the scalar multiply are stripped, and the Rv operand is dropped here.
// Effect on expression_traits: its is_transposed is
// `!is_same<remove_transpose<T>::type, T>`. This specialization therefore makes
// a scalar-multiplied transpose count as transposed. The scalar is presumably
// recovered separately via remove_scalar_mul / get_scalar — confirm.
77 template<class Array, class SystemTag, class Rv>
79  Bin_Op<
80  oper::Scalar_Mul,
81  Un_Op<oper::transpose<SystemTag>,
82  Array>,
83  Rv>>
84 {
85  using type = Array;
86 
87  static type rm(
88  Bin_Op<
91  Array>,
92  Rv> expression) {
93  return expression.left.array;
94  }
95 };
96 
// Mirror of the previous specialization: a detail::remove_transpose
// specialization for
// Bin_Op<oper::Scalar_Mul, Lv, Un_Op<oper::transpose<SystemTag>, Array>>,
// i.e. Lv multiplied by a transposed RIGHT operand.
// NOTE(review): this listing omits source line 98 (presumably
// `struct remove_transpose<`) and lines 109/111 of rm's parameter list. Per the
// generated tooltip, the full signature is
// `static type rm(Bin_Op<oper::Scalar_Mul, Lv, Un_Op<oper::transpose<SystemTag>,
// Array>> expression)`.
// `type` is Array and rm returns `expression.right.array`, stripping both the
// transpose and the scalar multiply; the Lv operand is dropped here.
97 template<class Array, class SystemTag, class Lv>
99  Bin_Op<
100  oper::Scalar_Mul,
101  Lv,
102  Un_Op<oper::transpose<SystemTag>,
103  Array>>>
104 {
105  using type = Array;
106 
107  static type rm(
108  Bin_Op<
110  Lv,
112  Array>> expression) {
113  return expression.right.array;
114  }
115 };
116 
117 } //end of ns detail
118 
119 
// Forward declaration of the BLAS expression parser, templated on a system tag.
// Its definition is not in this file. It is presumably specialized per system
// (e.g. host/device) and used by parse_expression /
// post_parse_expression_evaluation in the traits below, keyed on
// T::system_tag — confirm.
120 namespace blas_expression_parser {
121 template<class SystemTag>
122 struct Blas_Expression_Parser;
123 }
124 
// BLAS-facing traits for an expression T, extending expression_traits<T>.
// NOTE(review): this listing omits the struct head (source line 126, presumably
// `struct blas_expression_traits:`) and the alias declarations on source lines
// 129-130 and 132-135. Per the generated tooltips, those aliases are:
//   requires_greedy_evaluation = bc::traits::conditional_detected_t<
//       detail::query_requires_greedy_evaluation, T, std::false_type>
//   remove_scalar_mul_type     = detail::remove_scalar_mul<T>::type
//   remove_transpose_type      = detail::remove_transpose<T>::type
//   remove_blas_features_type  = detail::remove_transpose<remove_scalar_mul_type>::type
//   scalar_multiplier_type     = detail::remove_scalar_mul<T>::scalar_type
125 template<class T>
127  expression_traits<T> {
128 
131 
136  using value_type = typename T::value_type;
137 
// is_scalar_multiplied: true iff detail::remove_scalar_mul changes the type,
// i.e. T is a Scalar_Mul Bin_Op. (Its alias head, source line 138, is missing.)
139  !std::is_same<remove_scalar_mul_type, T>::value>;
140 
// is_transposed: true iff detail::remove_transpose changes the type. That covers
// a bare transpose and a transpose directly under a Scalar_Mul.
// (Its alias head, source line 141, is missing.)
142  !std::is_same<remove_transpose_type, T>::value>;
143 
// remove_transpose(T expression): returns detail::remove_transpose<T>::rm,
// i.e. the expression with such a transpose unwrapped, or unchanged.
// (Its signature, source line 144, is missing.)
145  return detail::remove_transpose<T>::rm(expression);
146  }
147 
// remove_scalar_mul(T expression): returns the non-scalar operand of a
// Scalar_Mul, or the expression unchanged.
// (Its signature, source line 148, is missing.)
149  return detail::remove_scalar_mul<T>::rm(expression);
150  }
151 
// remove_blas_modifiers(T expression): strips a scalar multiply and then a
// transpose, yielding remove_blas_features_type.
// (Source lines 152-153 are missing; they hold the signature and presumably the
// detail::remove_transpose<remove_scalar_mul_type>::rm( call.)
154  remove_scalar_mul(expression));
155  }
156 
157  // If T is a scalar multiplication, returns its scalar operand;
158  // otherwise returns a null pointer (pointer to T::value_type when T defines
159  // value_type, else T*), as produced by detail::remove_scalar_mul<T>::get_scalar.
160  static auto get_scalar(const T& expression)
161  -> decltype(detail::remove_scalar_mul<T>::get_scalar(expression)) {
162  return detail::remove_scalar_mul<T>::get_scalar(expression);
163  }
164 
// parse_expression: passes the expression's left and right operands, with the
// compile-time Alpha/Beta factors, to the parser's parse_expression for this
// system_tag. Source line 168 is missing (presumably
// `return blas_expression_parser::Blas_Expression_Parser<system_tag>::`).
// NOTE(review): reads expression.left / expression.right, so T is presumably a
// Bin_Op here.
// NOTE(review): `typename T::system_tag` is required to exist (hard error
// otherwise). The inherited expression_traits<T>::system_tag instead falls back
// to host_tag, per its tooltip.
165  template<int Alpha, int Beta, class Stream>
166  static auto parse_expression(Stream stream, T expression) {
167  using system_tag = typename T::system_tag;
169  template parse_expression<Alpha, Beta>(
170  stream, expression.left, expression.right);
171  }
172 
// post_parse_expression_evaluation: forwards `contents` to the same system's
// parser for post-processing. `contents` is presumably the value returned by
// parse_expression. Source line 176, presumably the
// Blas_Expression_Parser<system_tag>:: qualifier, is missing.
173  template<class Stream, class Contents>
174  static void post_parse_expression_evaluation(Stream stream, Contents contents) {
175  using system_tag = typename T::system_tag;
177  template post_parse_expression_evaluation(stream, contents);
178  }
179 };
180 
181 
182 }
183 }
184 }
185 
187 
188 #endif /* BLACKCATTENSORS_TENSORS_EXPRS_BLAS_EXPRESSION_TRAITS_H_ */
static remove_transpose_type remove_transpose(T expression)
Definition: blas_expression_template_traits.h:144
Definition: blas_expression_template_traits.h:24
static type rm(Bin_Op< oper::Scalar_Mul, Un_Op< oper::transpose< SystemTag >, Array >, Rv > expression)
Definition: blas_expression_template_traits.h:87
static T rm(T expression)
Definition: blas_expression_template_traits.h:29
Definition: blas.h:19
typename T::requires_greedy_evaluation query_requires_greedy_evaluation
Definition: expression_template_traits.h:34
static auto parse_expression(Stream stream, T expression)
Definition: blas_expression_template_traits.h:166
static remove_scalar_mul_type remove_scalar_mul(T expression)
Definition: blas_expression_template_traits.h:148
typename detail::remove_transpose< remove_scalar_mul_type >::type remove_blas_features_type
Definition: blas_expression_template_traits.h:134
Rv right
Definition: expression_binary.h:62
std::conditional_t< lv::tensor_dim==0, lv,rv > scalar_type
Definition: blas_expression_template_traits.h:41
Definition: blas_expression_template_traits.h:126
Definition: blas_expression_template_traits.h:61
Lv left
Definition: expression_binary.h:61
static type rm(Bin_Op< oper::Scalar_Mul, lv, rv > expression)
Definition: blas_expression_template_traits.h:43
bc::traits::conditional_detected_t< query_value_type, T, T > * scalar_type
Definition: blas_expression_template_traits.h:27
Definition: array.h:25
typename detail::remove_transpose< T >::type remove_transpose_type
Definition: blas_expression_template_traits.h:133
static type rm(Un_Op< oper::transpose< SystemTag >, Array > expression)
Definition: blas_expression_template_traits.h:71
Definition: binary.h:96
static auto get_scalar(const T &expression) -> decltype(detail::remove_scalar_mul< T >::get_scalar(expression))
Definition: blas_expression_template_traits.h:160
Definition: common.h:18
T type
Definition: blas_expression_template_traits.h:62
bc::traits::truth_type< !std::is_same< remove_transpose_type, T >::value > is_transposed
Definition: blas_expression_template_traits.h:142
Definition: expression_template_traits.h:76
static type rm(Bin_Op< oper::Scalar_Mul, Lv, Un_Op< oper::transpose< SystemTag >, Array >> expression)
Definition: blas_expression_template_traits.h:107
T type
Definition: blas_expression_template_traits.h:25
static remove_blas_features_type remove_blas_modifiers(T expression)
Definition: blas_expression_template_traits.h:152
static scalar_type get_scalar(Bin_Op< oper::Scalar_Mul, lv, rv > expression)
Definition: blas_expression_template_traits.h:50
static void post_parse_expression_evaluation(Stream stream, Contents contents)
Definition: blas_expression_template_traits.h:174
std::conditional_t< lv::tensor_dim==0, rv, lv > type
Definition: blas_expression_template_traits.h:40
typename conditional_detected< func, TestType, DefaultType >::type conditional_detected_t
Definition: type_traits.h:87
typename detail::remove_scalar_mul< T >::scalar_type scalar_multiplier_type
Definition: blas_expression_template_traits.h:135
typename detail::remove_scalar_mul< T >::type remove_scalar_mul_type
Definition: blas_expression_template_traits.h:132
typename T::value_type query_value_type
Definition: expression_template_traits.h:28
typename T::value_type value_type
Definition: blas_expression_template_traits.h:136
static scalar_type get_scalar(const T &expression)
Definition: blas_expression_template_traits.h:33
conditional_t< Bool, true_type, false_type > truth_type
Definition: type_traits.h:49
const auto transpose() const
Definition: expression_operations.h:85
static T rm(T expression)
Definition: blas_expression_template_traits.h:63
bc::traits::conditional_detected_t< detail::query_requires_greedy_evaluation, T, std::false_type > requires_greedy_evaluation
Definition: blas_expression_template_traits.h:130
Definition: device.h:27
bc::traits::truth_type< !std::is_same< remove_scalar_mul_type, T >::value > is_scalar_multiplied
Definition: blas_expression_template_traits.h:139
Definition: expression_template_traits.h:19
bc::traits::conditional_detected_t< detail::query_system_tag, T, host_tag > system_tag
Definition: expression_template_traits.h:79
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22