BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
cmath.h
Go to the documentation of this file.
1 /*
2  * Tensor_CMath.h
3  *
4  * Created on: Oct 30, 2018
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCAT_TENSOR_CMATH_H_
9 #define BLACKCAT_TENSOR_CMATH_H_
10 
11 #include <functional>
12 #include <cmath>
13 
14 namespace bc {
// Forward declarations of the tensor class templates that the functor
// overloads in this header accept by const reference.
15 namespace tensors {
16 template<class> class Expression_Base;
17 template<class> class Tensor_Base;
18 }
19 
20 namespace oper {
21 namespace cmath_functions {
22 
/*
 * BLACKCAT_FUNCTOR_DEF(funcName, instance_name, math_function, ...)
 *
 * Defines `struct funcName` and declares an object of that type named
 * `instance_name`.
 *
 * `math_function` is an expression written in terms of a variable named `x`;
 * it is pasted verbatim into the bodies of the scalar overloads.
 *
 * Members generated:
 *   - operator()(const value_type& x) const
 *         returns math_function, converted to value_type.
 *   - static apply(const value_type& x)
 *         returns math_function with a deduced (auto) return type.
 *   - operator()(const Expression_Base<Xpr>&) and
 *     operator()(const Tensor_Base<Xpr>&)
 *         return tensor.un_expr(funcName()). These two are not
 *         const-qualified, unlike the scalar operator().
 *
 * Any trailing arguments (__VA_ARGS__) are pasted inside the struct body just
 * before its closing brace; this file uses that slot to nest DERIVATIVE_DEF /
 * DERIVATIVE_CACHED_DEF members.
 */
23 #define BLACKCAT_FUNCTOR_DEF(funcName, instance_name, math_function, ...) \
24  \
25  struct funcName { \
26  template<class value_type> BCINLINE \
27  value_type operator () (const value_type& x) const { \
28  return math_function; \
29  } \
30  template<class value_type> BCINLINE \
31  static auto apply(const value_type& x) { \
32  return math_function; \
33  } \
34  template<class Xpr> \
35  auto operator() (const bc::tensors::Expression_Base<Xpr>& tensor) { \
36  return tensor.un_expr(funcName()); \
37  } \
38  template<class Xpr> \
39  auto operator() (const bc::tensors::Tensor_Base<Xpr>& tensor) { \
40  return tensor.un_expr(funcName()); \
41  } \
42  __VA_ARGS__ \
43 } instance_name; \
44 
// Expands to a functor `struct Derivative` with an instance named `dx`.
// When passed as the trailing argument of BLACKCAT_FUNCTOR_DEF, the expansion
// lands inside the enclosing struct, so `dx` becomes a member of that functor.
45 #define DERIVATIVE_DEF(...)\
46 BLACKCAT_FUNCTOR_DEF(Derivative, dx, __VA_ARGS__)
47 
// Expands to a functor `struct Cached_Derivative` with an instance named
// `cached_dx`, using the same mechanism as DERIVATIVE_DEF.
// NOTE(review): presumably `x` here is the already-computed forward output
// rather than the original input -- confirm against callers.
48 #define DERIVATIVE_CACHED_DEF(...)\
49 BLACKCAT_FUNCTOR_DEF(Cached_Derivative, cached_dx, __VA_ARGS__)
50 
// Shorthand for functors that wrap the std:: function whose name equals the
// instance name: the math expression is fixed to std::instanceName(x).
// Extra arguments are forwarded into the struct body (e.g. a DERIVATIVE_DEF).
51 #define BLACKCAT_MATH_DEF(funcName, instanceName, ...) \
52 BLACKCAT_FUNCTOR_DEF(funcName, instanceName, std::instanceName(x), __VA_ARGS__)
53 
54 //UTILITY 'just returns x'
56 
57 //COMMON
60 BLACKCAT_MATH_DEF( Sqrt , sqrt, DERIVATIVE_DEF((std::pow(x, -1/2)/2)))
61 
62 //Trig
// tan(x); dx(x) = (1 / cos(x))^2, i.e. sec^2(x), evaluated at the input x.
65 BLACKCAT_MATH_DEF( Tan , tan, DERIVATIVE_DEF(std::pow(1/std::cos(x), 2)))
// sec(x) = 1 / cos(x). No derivative member is supplied for this functor.
66 BLACKCAT_FUNCTOR_DEF( Sec, sec, 1/std::cos(x) )
67 
68 //Hyperbolic
72  DERIVATIVE_DEF(1 - std::pow(std::tanh(x), 2))
73  DERIVATIVE_CACHED_DEF(1 - std::pow(x,2)))
74 //Arc
// asin(x); dx(x) = 1 / sqrt(1 - x^2).
75 BLACKCAT_MATH_DEF( Asin , asin, DERIVATIVE_DEF(1/std::sqrt(1-std::pow(x,2))))
// acos(x); dx(x) = -1 / sqrt(1 - x^2).
76 BLACKCAT_MATH_DEF( Acos , acos, DERIVATIVE_DEF(-1/std::sqrt(1-std::pow(x,2))))
79 
80 //Arc Hyperbolic
84 
115 
116 struct Pow {
117 
118  template<class ValueType, class Exp> BCINLINE
119  ValueType operator () (const ValueType& x, Exp exp) const {
120  return std::pow(x, exp);
121  }
122 
123  template<class ValueType, class Exp> BCINLINE
124  static auto apply(const ValueType& x, Exp exp) {
125  return std::pow(x, exp);
126  }
127 
128  template<class Xpr, class Exp>
129  auto operator() (const bc::tensors::Expression_Base<Xpr>& tensor, Exp exp)
130  {
131  struct FunctorPow
132  {
133  typename Xpr::value_type exp;
134  auto operator() (const typename Xpr::value_type value) const {
135  return std::pow(value, exp);
136  }
137  };
138 
139  return tensor.un_expr(FunctorPow {exp});
140  }
141 
142 } pow;
143 
153 
155 BLACKCAT_FUNCTOR_DEF(Pow3, pow3, (std::pow(x, 3)), DERIVATIVE_DEF(3));
156 
// Logistic (sigmoid): 1 / (1 + exp(-x)).
//   dx(x)        = logistic(x) * (1 - logistic(x)), computed from the input x.
//   cached_dx(x) = x * (1 - x).
// NOTE(review): cached_dx presumably expects an already-computed logistic
// output as its argument -- confirm against callers.
157 BLACKCAT_FUNCTOR_DEF(Logistic, logistic, (1 / (1 + std::exp(-x))),
158  DERIVATIVE_DEF(Logistic::apply(x) * (1 - Logistic::apply(x)))
159  DERIVATIVE_CACHED_DEF(x * (1 - x)));
160 
161 BLACKCAT_FUNCTOR_DEF(Relu, relu, bc::traits::max(0, x),
162  DERIVATIVE_DEF(x > 0 ? 1 : 0)
164 
166 
167 BLACKCAT_FUNCTOR_DEF(SoftPlus, softplus, std::log(1 + std::exp(x)),
168  DERIVATIVE_DEF(Logistic::apply(x)))
169 
171  x * std::tanh(SoftPlus::apply(x)),
172  DERIVATIVE_DEF(
173  std::exp(x)
174  * (4*(x+1)
175  + 4*(std::exp(2*x))
176  + std::exp(3*x)
177  + std::exp(x)*(4*x+6))
178  / std::pow((2*std::exp(x) + std::exp(2*x) + 2),2)))
179 
180 #undef BLACKCAT_FUNCTOR_DEF
181 #undef BLACKCAT_MATH_DEF
182 #undef DERIVATIVE_DEF
183 #undef DERIVATIVE_CACHED_DEF
184 
185 } //end of ns cmath_functions
186 } //end of ns oper
187 
// Using-directive inside namespace bc, so the names declared in
// bc::oper::cmath_functions can also be found by lookup in bc.
188 using namespace bc::oper::cmath_functions;
189 
190 } //end of ns bc
191 
192 
193 
194 #endif /* BLACKCAT_TENSOR_CMATH_H_ */
#define BLACKCAT_FUNCTOR_DEF(funcName, instance_name, math_function,...)
Definition: cmath.h:23
struct bc::oper::cmath_functions::Mish mish
Definition: cmath.h:148
Definition: cmath.h:151
struct bc::oper::cmath_functions::Asin asin
Definition: cmath.h:155
struct bc::oper::cmath_functions::Rint rint
struct bc::oper::cmath_functions::Remquo remquo
struct bc::oper::cmath_functions::Nan nan
#define BCINLINE
Definition: common.h:96
struct bc::oper::cmath_functions::Nexttoward nexttoward
Definition: cmath.h:109
Definition: cmath.h:60
struct bc::oper::cmath_functions::Log log
Definition: cmath.h:145
struct bc::oper::cmath_functions::Llround llround
Definition: cmath.h:55
Definition: cmath.h:178
struct bc::oper::cmath_functions::Floor floor
struct bc::oper::cmath_functions::Logb logb
struct bc::oper::cmath_functions::Hypot hypot
Definition: cmath.h:94
auto un_expr(functor f) const
Definition: expression_base.h:104
Definition: cmath.h:108
struct bc::oper::cmath_functions::Relu relu
Definition: cmath.h:93
struct bc::oper::cmath_functions::Asinh asinh
Definition: cmath.h:58
struct bc::oper::cmath_functions::Cbrt cbrt
struct bc::oper::cmath_functions::Nearbyint nearbyint
Definition: cmath.h:110
struct bc::oper::cmath_functions::Log1P log1p
struct bc::oper::cmath_functions::Fma fma
Definition: cmath.h:99
Definition: cmath.h:105
Definition: cmath.h:165
struct bc::oper::cmath_functions::Copysign copysign
Definition: cmath.h:96
Definition: cmath.h:76
struct bc::oper::cmath_functions::Isinf isinf
struct bc::oper::cmath_functions::Fmax fmax
struct bc::oper::cmath_functions::Sec sec
struct bc::oper::cmath_functions::Lrint lrint
Definition: cmath.h:106
Definition: cmath.h:147
Definition: cmath.h:97
struct bc::oper::cmath_functions::Logistic logistic
Definition: cmath.h:104
Definition: cmath.h:101
struct bc::oper::cmath_functions::Isnan isnan
Definition: cmath.h:82
struct bc::oper::cmath_functions::Lround lround
static BCINLINE auto apply(const ValueType &x, Exp exp)
Definition: cmath.h:124
struct bc::oper::cmath_functions::Fmin fmin
struct bc::oper::cmath_functions::Scalbn scalbn
struct bc::oper::cmath_functions::Acos acos
struct bc::oper::cmath_functions::Abs abs
struct bc::oper::cmath_functions::Llrint llrint
Definition: cmath.h:78
struct bc::oper::cmath_functions::Sin sin
Definition: cmath.h:146
struct bc::oper::cmath_functions::Fdim fdim
struct bc::oper::cmath_functions::Sqrt sqrt
Definition: cmath.h:63
struct bc::oper::cmath_functions::Nextafter nextafter
Definition: cmath.h:87
Definition: cmath.h:90
struct bc::oper::cmath_functions::Fabs fabs
struct bc::oper::cmath_functions::Tanh tanh
Definition: cmath.h:16
Definition: cmath.h:83
struct bc::oper::cmath_functions::Expm1 expm1
Definition: cmath.h:81
struct bc::oper::cmath_functions::Log10 log10
struct bc::oper::cmath_functions::Atanh atanh
Definition: cmath.h:107
Definition: cmath.h:85
Definition: cmath.h:77
Definition: cmath.h:111
struct bc::oper::cmath_functions::Modf modf
struct bc::oper::cmath_functions::Ilogb ilogb
Definition: cmath.h:150
Definition: cmath.h:149
Definition: cmath.h:91
struct bc::oper::cmath_functions::Acosh acosh
struct bc::oper::cmath_functions::Exp2 exp2
#define DERIVATIVE_CACHED_DEF(...)
Definition: cmath.h:48
Definition: cmath.h:152
struct bc::oper::cmath_functions::Fmod fmod
struct bc::oper::cmath_functions::Pow pow
Definition: cmath.h:116
Definition: cmath.h:66
struct bc::oper::cmath_functions::Atan atan
Definition: cmath.h:70
Definition: cmath.h:17
struct bc::oper::cmath_functions::Pass pass
Definition: cmath.h:21
Definition: cmath.h:103
Definition: cmath.h:69
struct bc::oper::cmath_functions::Logical logical
Definition: cmath.h:64
Definition: cmath.h:86
Definition: cmath.h:100
struct bc::oper::cmath_functions::Pow3 pow3
struct bc::oper::cmath_functions::Scalbln scalbln
struct bc::oper::cmath_functions::Ceil ceil
struct bc::oper::cmath_functions::Round round
Definition: cmath.h:154
Definition: cmath.h:73
Definition: cmath.h:98
struct bc::oper::cmath_functions::Ldexp ldexp
Definition: cmath.h:92
struct bc::oper::cmath_functions::SoftPlus softplus
struct bc::oper::cmath_functions::Trunc trunc
Definition: cmath.h:65
struct bc::oper::cmath_functions::Atan2 atan2
struct bc::oper::cmath_functions::Tan tan
Definition: cmath.h:89
struct bc::oper::cmath_functions::Cosh cosh
struct bc::oper::cmath_functions::Cos cos
Definition: cmath.h:102
struct bc::oper::cmath_functions::Exp exp
struct bc::oper::cmath_functions::Sinh sinh
Definition: cmath.h:163
struct bc::oper::cmath_functions::Log2 log2
#define DERIVATIVE_DEF(...)
Definition: cmath.h:45
Definition: cmath.h:75
struct bc::oper::cmath_functions::Frexp frexp
struct bc::oper::cmath_functions::Remainder remainder
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Definition: cmath.h:59
struct bc::oper::cmath_functions::Pow2 pow2
Definition: cmath.h:95
#define BLACKCAT_MATH_DEF(funcName, instanceName,...)
Definition: cmath.h:51