BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
function_sum.h
Go to the documentation of this file.
1 /* Project: BlackCat_Scalars
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BC_EXPRESSION_TEMPLATES_FUNCTION_SUM_H_
10 #define BC_EXPRESSION_TEMPLATES_FUNCTION_SUM_H_
11 
13 #include "tree_evaluator.h"
14 
16 
17 
18 namespace bc {
19 namespace tensors {
20 namespace exprs {
21 
22 template<class SystemTag>
23 struct Sum {};
24 
25 template<class ArrayType, class SystemTag>
26 struct Un_Op<Sum<SystemTag>, ArrayType>:
27  Expression_Base<Un_Op<Sum<SystemTag>, ArrayType>>,
28  Shape<0>,
29  Sum<SystemTag> {
30 
31  using value_type = typename ArrayType::value_type;
32  using system_tag = SystemTag;
33  using requires_greedy_evaluation = std::true_type;
34 
35  static constexpr int tensor_dim = 0;
36  static constexpr int tensor_iterator_dim = 0;
37 
38  ArrayType array;
39 
41 
42  Un_Op(ArrayType array, Sum<system_tag> op=Sum<system_tag>()):
43  array(array) {}
44 
46  return Sum<SystemTag>();
47  }
48 
49  template<class Scalar, int Alpha, int Beta, class Stream>
50  void eval(Output_Data<Scalar, Alpha, Beta> output, Stream stream) const {
51  static_assert(Scalar::tensor_dim==0, "Output must be a scalar");
52 
53  //TODO handle alpha/beta scalars
55  stream,
56  output.data(),
57  array);
58  }
59 };
60 
61 
62 } //ns BC
63 } //ns exprs
64 } //ns tensors
65 
66 
67 
68 #endif /* FUNCTION_DOT_H_ */
Definition: tree_output_data.h:18
Definition: shape.h:17
Definition: function_sum.h:23
size_t value_type
Definition: shape.h:120
std::true_type requires_greedy_evaluation
Definition: function_sum.h:33
void eval(Output_Data< Scalar, Alpha, Beta > output, Stream stream) const
Definition: function_sum.h:50
auto sum(const Expression_Base< Expression > &tensor)
Definition: tensor_static_functions.h:15
SystemTag system_tag
Definition: function_sum.h:32
static constexpr int tensor_dim
Definition: tensor_base.h:38
Definition: expression_template_base.h:77
const Tensor & data() const
Definition: tree_output_data.h:26
Definition: device.h:27
ArrayType array
Definition: function_sum.h:38
static Sum< SystemTag > get_operation()
Definition: function_sum.h:45
Definition: expression_template_traits.h:19
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Un_Op(ArrayType array, Sum< system_tag > op=Sum< system_tag >())
Definition: function_sum.h:42