9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_GEMM_H_ 10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_GEMM_H_ 21 template<
class lv,
class rv,
class SystemTag>
22 struct Bin_Op<oper::gemm<SystemTag>, lv, rv>:
26 static_assert(std::is_same<
27 typename lv::value_type,
28 typename rv::value_type>::value,
29 "GEMM arguments must have the same value_type");
31 static_assert(lv::tensor_dim==2 && rv::tensor_dim==2,
32 "Error: GEMM Expression initialized with non matrix tensor");
49 "gemm requires left.cols() == right.rows()");
58 return i == 0 ? left.rows() : i == 1 ? right.cols() : 1;
64 template<
class Core,
int Alpha,
int Beta,
class Stream>
67 auto& out = output.
data();
69 static_assert(Core::tensor_dim == 2,
70 "Gemm out must be a matrix");
72 "Output dim (rows) mismatch for GEMM");
74 "Output dim (cols) mismatch for GEMM");
79 auto contents = traits::template parse_expression<Alpha, Beta>(stream, *
this);
80 auto A = contents.left;
81 auto B = contents.right;
82 auto alpha = contents.alpha;
83 auto beta = contents.beta;
84 auto transA = contents.lv_is_transposed;
85 auto transB = contents.rv_is_transposed;
88 stream, transA, transB, out.rows(), out.cols(), left.cols(),
89 alpha.data(), A.data(), A.leading_dim(1),
90 B.data(), B.leading_dim(1),
91 beta.data(), out.data(), out.leading_dim(1));
93 traits::template post_parse_expression_evaluation(stream, contents);
Definition: tree_output_data.h:18
#define BCINLINE
Definition: common.h:96
__host__ __device__ bc::size_t cols() const
Definition: function_gemm.h:62
__host__ __device__ bc::size_t dim(int i) const
Definition: function_gemm.h:57
Bin_Op(lv left, rv right, oper::gemm< system_tag > op=oper::gemm< system_tag >())
Definition: function_gemm.h:43
Definition: blas_expression_template_traits.h:126
int size_t
Definition: common.h:283
lv left
Definition: function_gemm.h:40
static oper::gemm< SystemTag > get_operation()
Definition: function_gemm.h:52
void eval(Output_Data< Core, Alpha, Beta > output, Stream stream) const
Definition: function_gemm.h:65
Definition: expression_template_base.h:77
rv right
Definition: function_gemm.h:41
BCINLINE bc::size_t dim(int i) const
Definition: expression_binary.h:115
#define BC_ASSERT(condition, message)
Definition: common.h:185
#define BCHOT
Definition: common.h:97
static constexpr int tensor_iterator_dim
Definition: expression_binary.h:56
typename lv::value_type value_type
Definition: function_gemm.h:34
static constexpr int tensor_dim
Definition: expression_binary.h:30
const Tensor & data() const
Definition: tree_output_data.h:26
__host__ __device__ bc::size_t rows() const
Definition: function_gemm.h:61
__host__ __device__ bc::size_t size() const
Definition: function_gemm.h:56
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
SystemTag system_tag
Definition: function_gemm.h:35