BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
momentum.h
Go to the documentation of this file.
1 /*
2  * Momentum.h
3  *
4  * Created on: Dec 3, 2019
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCATTENSORS_NEURALNETWORKS_OPTIMIZERS_MOMENTUM_H_
9 #define BLACKCATTENSORS_NEURALNETWORKS_OPTIMIZERS_MOMENTUM_H_
10 
11 #include "optimizer_base.h"
12 
13 namespace bc {
14 namespace nn {
15 
16 struct Momentum {
17 
18  template<class Tensor>
20 
21  using value_type = typename Tensor::value_type;
22 
25 
27 
28  template<class... Args>
29  Optimizer(Args&&... args):
30  momentum(std::forward<Args>(args)...) {
31  momentum.zero();
32  }
33 
34  template<class TensorX, class Gradients>
35  void update(TensorX& tensor, Gradients&& delta)
36  {
37  momentum = alpha * momentum + delta * learning_rate;
38  tensor += momentum;
39  }
40 
42  learning_rate = lr;
43  }
44 
45  void save(Layer_Loader& loader, std::string name) const {
46  loader.save_variable(momentum, name);
47  }
48 
49  void load(Layer_Loader& loader, std::string name) {
50  loader.load_variable(momentum, name);
51  }
52  };
53 
54 } momentum;
55 
56 }
57 }
58 
59 
60 
61 #endif /* BLACKCATTENSORS_NEURALNETWORKS_OPTIMIZERS_MOMENTUM_H_ */
void load(Layer_Loader &loader, std::string name)
Definition: momentum.h:49
self_type & zero()
Definition: tensor_base.h:13
value_type learning_rate
Definition: momentum.h:24
void save(Layer_Loader &loader, std::string name) const
Definition: momentum.h:45
bc::nn::Layer_Loader
Definition: layer_loader.h:19
void save_variable(const T &tensor, string variable_name)
Definition: layer_loader.h:44
void set_learning_rate(value_type lr)
Definition: momentum.h:41
Optimizer(Args &&... args)
Definition: momentum.h:29
bc::nn::Momentum::Optimizer
Definition: momentum.h:19
Definition: optimizer_base.h:16
void update(TensorX &tensor, Gradients &&delta)
Definition: momentum.h:35
typename Tensor::value_type value_type
Definition: momentum.h:21
value_type alpha
Definition: momentum.h:23
void load_variable(T &tensor, string variable_name)
Definition: layer_loader.h:50
Tensor momentum
Definition: momentum.h:26
bc::nn::Momentum
Definition: momentum.h:16
Definition: common.h:19
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22