BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
expression_operations.h
Go to the documentation of this file.
/*
 * BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor)
 *
 * Defines the two scalar overloads of a coefficient-wise binary operator:
 *   - `tensor op scalar` as a const member function, and
 *   - `scalar op tensor` as a friend function.
 * Both overloads are enabled only when ScalarType is implicitly convertible
 * to value_type. The scalar is converted with static_cast<value_type> and
 * wrapped by exprs::make_scalar_constant<system_tag>, then combined with the
 * tensor through oper::op_functor. The friend keeps the scalar as the LEFT
 * operand, so operand order is preserved (e.g. `2 - t` is scalar minus tensor).
 *
 * No comments are placed inside the macro body: a `//` comment there would
 * swallow the trailing line-continuation backslash.
 */
#define BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor) \
    template< \
        class ScalarType, \
        class=std::enable_if_t< \
            std::is_convertible<ScalarType, value_type>::value>> \
    auto op (const ScalarType& param) const \
    { \
        return bi_expr(oper::op_functor(), \
                exprs::make_scalar_constant<system_tag>( \
                        static_cast<value_type>(param))); \
    } \
    \
    template< \
        class ScalarType, \
        class=std::enable_if_t< \
            std::is_convertible<ScalarType, value_type>::value>> \
    friend auto op (const ScalarType& param, const Expression_Base& tensor) \
    { \
        auto scalar = exprs::make_scalar_constant<system_tag>( \
                static_cast<value_type>(param)); \
        return make_expression(scalar).bi_expr(oper::op_functor(), tensor); \
    }
22 
/*
 * BC_COEFFICIENTWISE_DEF(op, op_functor)
 *
 * Defines the tensor-by-tensor overload of a coefficient-wise binary operator.
 * It also defines the two scalar overloads by expanding
 * BC_SCALAR_COEFFICIENTWISE_DEF.
 *
 * The tensor overload calls assert_valid(param) before building the
 * expression. assert_valid checks two things:
 *   - at compile time, that both operands share the same system_tag;
 *   - at run time, that the shapes are compatible. If they are not, it
 *     throws std::invalid_argument.
 *
 * The final line intentionally has no trailing backslash, so whatever
 * follows the macro can never be spliced into its definition.
 */
#define BC_COEFFICIENTWISE_DEF(op, op_functor) \
    template<class Xpr> \
    auto op (const Expression_Base<Xpr>& param) const \
    { \
        assert_valid(param); \
        return bi_expr(oper::op_functor(), param); \
    } \
    \
    BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor)
32 
33 
34  BC_COEFFICIENTWISE_DEF(operator +, Add)
35  BC_COEFFICIENTWISE_DEF(operator -, Sub)
36  BC_COEFFICIENTWISE_DEF(operator %, Mul)
37  BC_COEFFICIENTWISE_DEF(operator /, Div)
38  BC_COEFFICIENTWISE_DEF(operator == , Equal )
39  BC_COEFFICIENTWISE_DEF(operator > , Greater)
40  BC_COEFFICIENTWISE_DEF(operator < , Lesser)
41  BC_COEFFICIENTWISE_DEF(operator >= , Greater_Equal)
42  BC_COEFFICIENTWISE_DEF(operator <= , Lesser_Equal )
43  BC_COEFFICIENTWISE_DEF(operator && , And )
44  BC_COEFFICIENTWISE_DEF(operator || , Or )
45 
49  BC_SCALAR_COEFFICIENTWISE_DEF(operator *, Scalar_Mul)
50 
51 #undef BC_SCALAR_COEFFICIENTWISE_DEF
52 #undef BC_COEFFICIENTWISE_DEF
53 
54  template<class Xpr>
55  auto operator *(const Expression_Base<Xpr>& param) const {
56 
57  using blas_traits = exprs::blas_expression_traits<expression_type>;
58  using rv_blas_traits = exprs::blas_expression_traits<Xpr>;
59 
60  constexpr bool lv_trans = blas_traits::is_transposed::value;
61  constexpr bool rv_trans = rv_blas_traits::is_transposed::value;
62 
63  constexpr bool scalmul = tensor_dim == 0 || Xpr::tensor_dim == 0;
64  constexpr bool gemm = tensor_dim == 2 && Xpr::tensor_dim == 2;
65  constexpr bool gemv = tensor_dim == 2 && Xpr::tensor_dim == 1;
66  constexpr bool ger = tensor_dim == 1 && Xpr::tensor_dim == 1 &&
67  !lv_trans && rv_trans;
68  constexpr bool dot = tensor_dim == 1 && Xpr::tensor_dim == 1 &&
69  !lv_trans && !rv_trans;
70 
71  using matmul_t =
72  std::conditional_t<scalmul, oper::Scalar_Mul,
73  std::conditional_t<gemm, oper::gemm<system_tag>,
74  std::conditional_t<gemv, oper::gemv<system_tag>,
75  std::conditional_t<ger, oper::ger<system_tag>,
76  std::conditional_t<dot, oper::dot<system_tag>, void>>>>>;
77 
78  static_assert(!std::is_void<matmul_t>::value,
79  "INVALID USE OF OPERATOR *");
80  return bi_expr(matmul_t(), param);
81  }
82 
83  // ---- Unary Expressions ---- //
84 
85  const auto transpose() const {
86  return make_expression(make_transpose(this->expression_template()));
87  }
88 
89  auto transpose() {
90  return make_expression(make_transpose(this->expression_template()));
91  }
92 
93  const auto t() const { return this->transpose(); }
94  auto t() { return this->transpose(); }
95 
96  auto operator - () const {
97  return un_expr(oper::negation);
98  }
99 
100  // ---- expression_factory ---- //
101 
102  template<class functor>
103  auto un_expr(functor f) const
104  {
105  return make_expression(
107  this->expression_template(), f));
108  }
109 
110  template<
111  class Functor,
112  class Xpr,
113  class=std::enable_if_t<
114  exprs::expression_traits<Xpr>::is_expression_template::value>>
115  auto bi_expr(Functor func, const Xpr& rv) const
116  {
117  return make_expression(
119  this->expression_template(),
120  rv.expression_template(),
121  func));
122  }
123 
124  template<class Xpr>
125  void assert_valid(const Expression_Base<Xpr>& tensor) const
126  {
127  static_assert(std::is_same<system_tag, typename Xpr::system_tag>::value,
128  "Tensor arguments must have compatible (same) system_tags");
129 
130  bool same_dim = tensor_dim == Xpr::tensor_dim;
131  bool same_shape = this->inner_shape() == tensor.inner_shape();
132  bool cwise_op = same_dim && same_shape;
133 
134  bool scalar_op = tensor_dim == 0 || Xpr::tensor_dim == 0;
135  bool valid_broadcast_op = !same_dim && !cwise_op && valid_slice(tensor);
136  bool valid_cwise_op = (same_dim && same_shape);
137 
138  if (!scalar_op && !valid_broadcast_op && !valid_cwise_op) {
139  throw std::invalid_argument(
140  "Tensor by Tensor operation error: shape mismatch."
141  "\nthis->tensor_dim = " + std::to_string(tensor_dim) +
142  "\nthis->size() = " + std::to_string(this->size()) +
143  "\nthis_dims = " + this->inner_shape().to_string() +
144  "\nparam->tensor_dim = " + std::to_string(Xpr::tensor_dim) +
145  "\nparam.size() = " + std::to_string(tensor.size()) +
146  "\nparam_dims = " + tensor.inner_shape().to_string()
147  );
148  }
149  }
150 
151  template<class Xpr>
152  bool valid_slice(const Expression_Base<Xpr>& tensor) const {
153  constexpr bc::size_t min_dim = traits::min(tensor_dim, Xpr::tensor_dim);
154 
155  for (int i = 0; i < min_dim; ++i)
156  if (tensor.dim(i) != this->dim(i))
157  return false;
158  return true;
159  }
const auto t() const
Definition: expression_operations.h:93
auto make_transpose(expr_t expr)
Definition: function_transpose.h:75
BCHOT auto make_bin_expr(Lv left, Rv right, Op oper)
Definition: expression_binary.h:275
BCHOT auto make_un_expr(Expression expression, Operation operation=Operation())
Definition: expression_unary.h:73
auto max_value(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:47
auto operator-(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:35
bool valid_slice(const Expression_Base< Xpr > &tensor) const
Definition: expression_operations.h:152
BCINLINE auto dim(const Integers &... ints)
Definition: dim.h:336
std::string to_string(int precision=8, bool pretty=true, bool sparse=false) const
Definition: tensor_utility.h:34
void assert_valid(const Expression_Base< Xpr > &tensor) const
Definition: expression_operations.h:125
auto make_expression(ExpressionTemplate expression)
Definition: common.h:42
#define BC_COEFFICIENTWISE_DEF(op, op_functor)
Definition: expression_operations.h:23
struct bc::oper::Min min
auto approx_equal(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:46
int size_t
Definition: common.h:283
auto bi_expr(Functor func, const Xpr &rv) const
Definition: expression_operations.h:115
const auto transpose() const
Definition: expression_operations.h:85
#define BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor)
Definition: expression_operations.h:1
auto un_expr(functor f) const
Definition: expression_operations.h:103
auto operator*(const ScalarType &param) const
Definition: expression_operations.h:49
struct bc::oper::Negation negation
auto min_value(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:48