BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
adam.h
Go to the documentation of this file.
1 /*
2  * Adam.h
3  *
4  * Created on: Dec 11, 2019
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCAT_TENSORS_NEURALNETWORKS_OPTIMIZERS_ADAM_H_
9 #define BLACKCAT_TENSORS_NEURALNETWORKS_OPTIMIZERS_ADAM_H_
10 
11 #include "optimizer_base.h"
12 
13 namespace bc {
14 namespace nn {
15 
/**
 * Tag type that selects the Adam optimizer.
 *
 * The per-tensor state and the update rule live in the nested
 * Optimizer<Tensor> template, which is only declared here.
 *
 * `adam` is a ready-made instance of the tag. It is given internal linkage
 * (`static`) because this is a header: a plain namespace-scope variable
 * definition would be emitted by every translation unit that includes the
 * file and collide at link time. The tag carries no state, so one copy per
 * translation unit is harmless.
 */
static struct Adam {

    template<class Tensor>
    struct Optimizer;

} adam;
22 
23 
24 template<class Tensor>
26 
27  using value_type = typename Tensor::value_type; // scalar type used for every hyper-parameter below
28  using system_tag = typename Tensor::system_tag; // re-exported from Tensor; not used elsewhere in this listing
29 
30  value_type alpha = bc::nn::default_learning_rate; // step size; overwritten by set_learning_rate
31  value_type beta_1 = 0.9; // decay factor of the running average m_t (see update)
32  value_type beta_2 = 0.999; // decay factor of the running average v_t (see update)
33  value_type epsilon = 1e-8; // added to sqrt(v_cap) in update so the divisor is never exactly zero
34  value_type time_stamp = 0; // number of update() calls so far; exponent of the bias-correction terms
35 
38 
39  template<class... Args>
40  Optimizer(Args&&... args):
41  m_t(std::forward<Args>(args)...),
42  v_t(std::forward<Args>(args)...) {
43 
44  m_t.zero();
45  v_t.zero();
46  }
47 
48  template<class TensorX, class Gradients>
49  void update(TensorX& tensor, Gradients&& delta)
50  {
51  time_stamp++;
52  m_t = beta_1 * m_t + (1-beta_1) * delta;
53  v_t = beta_2 * v_t + (1-beta_2) * bc::pow2(delta);
54 
55  auto m_cap = m_t/(1-(bc::pow(beta_1, time_stamp)));
56  auto v_cap = v_t/(1-(bc::pow(beta_2, time_stamp)));
57 
58  tensor += (alpha*m_cap)/(bc::sqrt(v_cap)+epsilon);
59  }
60 
61 
63  alpha = lr;
64  }
65 
66  void save(Layer_Loader& loader, std::string name) const { // currently a no-op: `loader` and `name` are unused
67  //TODO add support for loader saving primitives (until then nothing from this optimizer is written out)
68  }
69 
70  void load(Layer_Loader& loader, std::string name) { // currently a no-op: `loader` and `name` are unused
71  //TODO add support for loader loading primitives (until then no optimizer state is restored)
72  }
73 };
74 
75 }
76 }
77 
78 
79 
80 
81 #endif /* BLACKCAT_TENSORS_NEURALNETWORKS_OPTIMIZERS_ADAM_H_ */
Definition: adam.h:19
self_type & zero()
Definition: tensor_base.h:13
void load(Layer_Loader &loader, std::string name)
Definition: adam.h:70
void save(Layer_Loader &loader, std::string name) const
Definition: adam.h:66
Definition: layer_loader.h:19
void set_learning_rate(value_type lr)
Definition: adam.h:62
typename Tensor::value_type value_type
Definition: adam.h:27
struct bc::oper::cmath_functions::Sqrt sqrt
Optimizer(Args &&... args)
Definition: adam.h:40
typename Tensor::system_tag system_tag
Definition: adam.h:28
Definition: optimizer_base.h:16
Definition: adam.h:16
Tensor m_t
Definition: adam.h:36
struct bc::oper::cmath_functions::Pow pow
struct bc::nn::Adam adam
void update(TensorX &tensor, Gradients &&delta)
Definition: adam.h:49
Definition: common.h:19
Tensor v_t
Definition: adam.h:37
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
struct bc::oper::cmath_functions::Pow2 pow2