1 #define BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor) \ 4 class=std::enable_if_t< \ 5 std::is_convertible<ScalarType, value_type>::value>> \ 6 auto op (const ScalarType& param) const \ 8 return bi_expr(oper::op_functor(), \ 9 exprs::make_scalar_constant<system_tag>((value_type)param)); \ 14 class=std::enable_if_t< \ 15 std::is_convertible<ScalarType, value_type>::value>> \ 16 friend auto op (const ScalarType& param, const Expression_Base& tensor) \ 18 value_type value = param; \ 19 auto scalar = exprs::make_scalar_constant<system_tag>(value); \ 20 return make_expression(scalar).bi_expr(oper:: op_functor (), tensor); \ 23 #define BC_COEFFICIENTWISE_DEF(op, op_functor) \ 25 auto op (const Expression_Base<Xpr>& param) const \ 27 assert_valid(param); \ 28 return bi_expr(oper::op_functor(), param); \ 31 BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor) \ 51 #undef BC_SCALAR_COEFFICIENTWISE_DEF 52 #undef BC_COEFFICIENTWISE_DEF 55 auto operator *(
const Expression_Base<Xpr>& param)
const {
57 using blas_traits = exprs::blas_expression_traits<expression_type>;
58 using rv_blas_traits = exprs::blas_expression_traits<Xpr>;
60 constexpr
bool lv_trans = blas_traits::is_transposed::value;
61 constexpr
bool rv_trans = rv_blas_traits::is_transposed::value;
63 constexpr
bool scalmul = tensor_dim == 0 || Xpr::tensor_dim == 0;
64 constexpr
bool gemm = tensor_dim == 2 && Xpr::tensor_dim == 2;
65 constexpr
bool gemv = tensor_dim == 2 && Xpr::tensor_dim == 1;
66 constexpr
bool ger = tensor_dim == 1 && Xpr::tensor_dim == 1 &&
67 !lv_trans && rv_trans;
68 constexpr
bool dot = tensor_dim == 1 && Xpr::tensor_dim == 1 &&
69 !lv_trans && !rv_trans;
72 std::conditional_t<scalmul, oper::Scalar_Mul,
73 std::conditional_t<gemm, oper::gemm<system_tag>,
74 std::conditional_t<gemv, oper::gemv<system_tag>,
75 std::conditional_t<ger, oper::ger<system_tag>,
76 std::conditional_t<dot, oper::dot<system_tag>,
void>>>>>;
78 static_assert(!std::is_void<matmul_t>::value,
79 "INVALID USE OF OPERATOR *");
80 return bi_expr(matmul_t(), param);
102 template<
class functor>
107 this->expression_template(), f));
113 class=std::enable_if_t<
114 exprs::expression_traits<Xpr>::is_expression_template::value>>
115 auto bi_expr(Functor func,
const Xpr& rv)
const 119 this->expression_template(),
120 rv.expression_template(),
127 static_assert(std::is_same<system_tag, typename Xpr::system_tag>::value,
128 "Tensor arguments must have compatible (same) system_tags");
130 bool same_dim = tensor_dim == Xpr::tensor_dim;
131 bool same_shape = this->inner_shape() == tensor.inner_shape();
132 bool cwise_op = same_dim && same_shape;
134 bool scalar_op = tensor_dim == 0 || Xpr::tensor_dim == 0;
135 bool valid_broadcast_op = !same_dim && !cwise_op &&
valid_slice(tensor);
136 bool valid_cwise_op = (same_dim && same_shape);
138 if (!scalar_op && !valid_broadcast_op && !valid_cwise_op) {
139 throw std::invalid_argument(
140 "Tensor by Tensor operation error: shape mismatch." 143 "\nthis_dims = " + this->inner_shape().
to_string() +
146 "\nparam_dims = " + tensor.inner_shape().to_string()
155 for (
int i = 0; i < min_dim; ++i)
156 if (tensor.dim(i) != this->
dim(i))
const auto t() const
Definition: expression_operations.h:93
auto make_transpose(expr_t expr)
Definition: function_transpose.h:75
BCHOT auto make_bin_expr(Lv left, Rv right, Op oper)
Definition: expression_binary.h:275
BCHOT auto make_un_expr(Expression expression, Operation operation=Operation())
Definition: expression_unary.h:73
auto max_value(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:47
auto operator-(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:35
bool valid_slice(const Expression_Base< Xpr > &tensor) const
Definition: expression_operations.h:152
BCINLINE auto dim(const Integers &... ints)
Definition: dim.h:336
std::string to_string(int precision=8, bool pretty=true, bool sparse=false) const
Definition: tensor_utility.h:34
void assert_valid(const Expression_Base< Xpr > &tensor) const
Definition: expression_operations.h:125
auto make_expression(ExpressionTemplate expression)
Definition: common.h:42
#define BC_COEFFICIENTWISE_DEF(op, op_functor)
Definition: expression_operations.h:23
auto approx_equal(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:46
int size_t
Definition: common.h:283
auto bi_expr(Functor func, const Xpr &rv) const
Definition: expression_operations.h:115
const auto transpose() const
Definition: expression_operations.h:85
#define BC_SCALAR_COEFFICIENTWISE_DEF(op, op_functor)
Definition: expression_operations.h:1
auto un_expr(functor f) const
Definition: expression_operations.h:103
auto operator*(const ScalarType &param) const
Definition: expression_operations.h:49
struct bc::oper::Negation negation
auto min_value(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:48