9 #ifndef BC_EXPRESSION_TEMPLATES_EXPRESSION_BINARY_H_ 10 #define BC_EXPRESSION_TEMPLATES_EXPRESSION_BINARY_H_ 12 #include <type_traits> 19 template<
class Operation,
class Lv,
class Rv>
21 Expression_Base<Bin_Op<Operation, Lv, Rv>>,
26 std::declval<Operation>().
operator()(
27 std::declval<typename Lv::value_type>(),
28 std::declval<typename Rv::value_type>()))>;
31 bc::traits::max(Lv::tensor_dim, Rv::tensor_dim);
35 static constexpr
bool is_broadcast_expression =
36 Lv::tensor_dim != Rv::tensor_dim &&
37 Lv::tensor_dim != 0 &&
40 static constexpr
int max_dim = bc::traits::max(
41 Lv::tensor_iterator_dim,
42 Rv::tensor_iterator_dim,
46 static constexpr
int max_iterator = bc::traits::max(
47 Lv::tensor_iterator_dim,
48 Rv::tensor_iterator_dim);
50 static constexpr
bool continuous_mem_layout =
51 Lv::tensor_iterator_dim <= 1 &&
52 Rv::tensor_iterator_dim <= 1;
57 is_broadcast_expression || !continuous_mem_layout ?
65 return static_cast<const Operation&
>(*this);
68 template<
class... Args>
BCHOT 69 Bin_Op(Lv lv, Rv rv,
const Args&... args):
76 return Operation::operator()(left[index], right[index]);
81 return Operation::operator()(left[index], right[index]);
86 class=std::enable_if_t<
87 (
sizeof...(Integers)>=tensor_iterator_dim)>>
90 return Operation::operator()(
left(ints...), right(ints...));
95 class=std::enable_if_t<(
96 sizeof...(Integers)>=tensor_iterator_dim)>>
99 return Operation::operator()(
left(ints...), right(ints...));
105 const auto& shape()
const {
106 constexpr
int max_dim = Lv::tensor_dim >= Rv::tensor_dim;
107 return traits::get<max_dim>(
right,
left);
133 template<
class Op,
class Lv,
class Rv,
class... Args>
BCHOT 138 template<
class Op,
class Lv,
class Rv,
class=
void>
141 template<
class... Args>
142 static auto make(Lv lv, Rv rv, Args&&... args) {
148 template<
class Lv,
class Rv>
151 template<
class... Args>
153 return mk_bin_op<Sub>(lv, rv.
array, std::forward<Args>(args)...);
158 template<
class Lv,
class Rv>
161 template<
class... Args>
163 return mk_bin_op<Add>(lv, rv.
array, std::forward<Args>(args)...);
168 template<
class Lv,
class Rv>
171 template<
class... Args>
173 return mk_bin_op<Sub>(rv, lv.
array, std::forward<Args>(args)...);
178 template<
class Lv,
class Rv>
181 template<
class... Args>
187 return mk_bin_op<Sub>(rv.
array, lv.
array, std::forward<Args>(args)...);
192 template<
class Lv,
class Rv>
195 template<
class... Args>
199 return detail::mk_bin_op<Sub>(
200 lv, rv.
array, std::forward<Args>(args)...);
205 template<
class Lv,
class Rv>
208 template<
class... Args>
210 return mk_bin_op<Sub_Assign>(lv, rv.
array, std::forward<Args>(args)...);
215 template<
class Lv,
class Rv>
218 template<
class... Args>
220 return mk_bin_op<Add_Assign>(lv, rv.
array, std::forward<Args>(args)...);
225 template<
class Lv,
class Rv>
228 expression_traits<Lv>::is_blas_expression::value !=
229 expression_traits<Rv>::is_blas_expression::value>>
233 constexpr
bool left_scalar = Lv::tensor_dim==0;
234 auto scalar_expr = bc::traits::constexpr_ternary<left_scalar>(
235 [=]() {
return std::make_pair(lv, rv); },
236 [=]() {
return std::make_pair(rv, lv); }
239 auto scalar = scalar_expr.first;
240 auto expr = scalar_expr.second;
242 using expr_lv_t = std::decay_t<decltype(expr.left)>;
243 using expr_rv_t = std::decay_t<decltype(expr.right)>;
244 using expr_op_t = std::decay_t<decltype(expr.get_operation())>;
246 constexpr
bool expr_left_is_scalar_multiplied =
249 constexpr
bool expr_right_is_scalar_multiplied =
254 !(expr_left_is_scalar_multiplied
255 && expr_right_is_scalar_multiplied),
256 "Cannot apply scalar_multiplication to a blas_expression where" 257 "both the left and right expressions of the blas expression" 258 "are already scalar multiplied");
260 return bc::traits::constexpr_ternary<!expr_left_is_scalar_multiplied>(
262 auto newexpr = mk_bin_op<Scalar_Mul>(scalar, expr.left);
263 return mk_bin_op<expr_op_t>(newexpr, expr.right);
266 auto newexpr = mk_bin_op<Scalar_Mul>(scalar, expr.right);
267 return mk_bin_op<expr_op_t>(newexpr, expr.left);
274 template<
class Op,
class Lv,
class Rv>
BCHOT 279 template<
class Op,
class Lv,
class Rv,
class... Args>
BCHOT 283 left, right, std::forward<Args>(args)...);
#define BCINLINE
Definition: common.h:96
BCHOT auto mk_bin_op(Lv left, Rv right, Args &&... args)
Definition: expression_binary.h:134
static auto make(Lv lv, Rv rv, Args &&... args)
Definition: expression_binary.h:142
BCINLINE bc::size_t size() const
Definition: expression_binary.h:112
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:209
BCHOT auto make_bin_expr(Lv left, Rv right, Op oper)
Definition: expression_binary.h:275
BCINLINE auto operator[](bc::size_t index) const
Definition: expression_binary.h:75
static auto make(Un_Op< Negation, Lv > lv, Rv rv, Args &&... args)
Definition: expression_binary.h:172
Rv right
Definition: expression_binary.h:62
static auto make(Lv lv, Rv rv, Scalar_Mul op=Scalar_Mul())
Definition: expression_binary.h:231
Operation get_operation() const
Definition: expression_binary.h:64
static auto make(Un_Op< Negation, Lv > lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:196
BCINLINE auto inner_shape() const
Definition: expression_binary.h:116
Definition: blas_expression_template_traits.h:126
BCHOT Bin_Op(Lv lv, Rv rv, const Args &... args)
Definition: expression_binary.h:69
Lv left
Definition: expression_binary.h:61
BCINLINE bc::size_t cols() const
Definition: expression_binary.h:114
ArrayType array
Definition: expression_unary.h:32
static auto make(Un_Op< Negation, Lv > lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:182
int size_t
Definition: common.h:283
BCINLINE auto operator()(Integers... ints) const
Definition: expression_binary.h:89
Definition: expression_binary.h:139
std::decay_t< decltype(std::declval< Operation >().operator()(std::declval< typename Lv::value_type >(), std::declval< typename Rv::value_type >()))> value_type
Definition: expression_binary.h:28
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:219
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:152
typename Lv::system_tag system_tag
Definition: expression_binary.h:24
#define BCHOT
Definition: common.h:97
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
static constexpr int tensor_dim
Definition: expression_binary.h:30
BCINLINE bc::size_t rows() const
Definition: expression_binary.h:113
static auto make(Lv lv, Un_Op< Negation, Rv > rv, Args &&... args)
Definition: expression_binary.h:162
Definition: expression_template_traits.h:19
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22