BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
lstm.h
Go to the documentation of this file.
1 /*
2  * LSTM.h
3  *
4  * Created on: Aug 3, 2019
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCAT_TENSORS_NEURALNETWORKS_LSTM_H_
9 #define BLACKCAT_TENSORS_NEURALNETWORKS_LSTM_H_
10 
11 #include "../layer_cache.h"
12 #include "layer_base.h"
13 
14 namespace bc {
15 namespace nn {
16 
18 
19 template<class SystemTag,
20  class ValueType,
21  class Optimizer=Stochastic_Gradient_Descent,
22  class ForgetGateNonlinearity=bc::Logistic,
23  class WriteGateNonlinearity=bc::Tanh,
24  class InputGateNonlinearity=bc::Logistic,
25  class OutputGateNonlinearity=bc::Logistic,
26  class CellStateNonLinearity=bc::Tanh>
27 struct LSTM:
28  public Layer_Base<
29  LSTM<
30  SystemTag,
31  ValueType,
32  Optimizer,
33  ForgetGateNonlinearity,
34  WriteGateNonlinearity,
35  InputGateNonlinearity,
36  OutputGateNonlinearity,
37  CellStateNonLinearity>,
38  Tensor_Descriptor<
39  ValueType,
40  SystemTag,
41  Integer<1>>>
42 {
43  using system_tag = SystemTag;
44  using value_type = ValueType;
46  ValueType,
47  SystemTag,
49 
50  using parent_type = Layer_Base<
51  LSTM<
52  SystemTag,
53  ValueType,
54  Optimizer,
55  ForgetGateNonlinearity,
56  WriteGateNonlinearity,
57  InputGateNonlinearity,
58  OutputGateNonlinearity,
59  CellStateNonLinearity>,
61 
63  using optimizer_type = Optimizer;
64 
65  using greedy_evaluate_delta = std::true_type;
66  using forward_requires_outputs = std::true_type;
67  using backward_requires_outputs = std::true_type;
68  using requires_extra_cache = std::true_type;
69  using is_recurrent = std::true_type;
70 
71 #ifndef _MSC_VER
72  using defines_predict = std::true_type;
73 #endif
74 
75  using defines_single_predict = std::true_type;
76 
77 private:
78 
81 
82  using mat_opt_t = typename Optimizer::template Optimizer<mat>;
83  using vec_opt_t = typename Optimizer::template Optimizer<vec>;
84 
85  CellStateNonLinearity c_g;
86  ForgetGateNonlinearity f_g;
87  WriteGateNonlinearity z_g;
88  InputGateNonlinearity i_g;
89  OutputGateNonlinearity o_g;
90 
91  mat wf, wz, wi, wo;
92  mat wf_gradients, wz_gradients, wi_gradients, wo_gradients;
93 
94  mat rf, rz, ri, ro;
95  mat rf_gradients, rz_gradients, ri_gradients, ro_gradients;
96 
97  vec bf, bz, bi, bo;
98  vec bf_gradients, bz_gradients, bi_gradients, bo_gradients;
99 
100  mat_opt_t wf_opt, wz_opt, wi_opt, wo_opt;
101  mat_opt_t rf_opt, rz_opt, ri_opt, ro_opt;
102  vec_opt_t bf_opt, bz_opt, bi_opt, bo_opt;
103 
104  mat dc, df, dz, di, do_, dy;
105 
106  template<char C>
107  using key_type = bc::nn::cache_key<
109 
110  using cell_key = key_type<'c'>;
111  using forget_key = key_type<'f'>;
112  using input_key = key_type<'i'>;
113  using write_key = key_type<'z'>;
114  using output_key = key_type<'o'>;
115 
117  bc::utility::Name<'p','c'>, vec, cache_key_type::always_recurrent>;
118 
119 public:
120 
121  LSTM(int inputs, bc::size_t outputs):
122  parent_type(__func__, {inputs}, {outputs}),
123  wf(outputs, inputs),
124  wz(outputs, inputs),
125  wi(outputs, inputs),
126  wo(outputs, inputs),
127 
128  wf_gradients(outputs, inputs),
129  wz_gradients(outputs, inputs),
130  wi_gradients(outputs, inputs),
131  wo_gradients(outputs, inputs),
132 
133  rf(outputs, outputs),
134  rz(outputs, outputs),
135  ri(outputs, outputs),
136  ro(outputs, outputs),
137 
138  rf_gradients(outputs, outputs),
139  rz_gradients(outputs, outputs),
140  ri_gradients(outputs, outputs),
141  ro_gradients(outputs, outputs),
142 
143  bf(outputs),
144  bz(outputs),
145  bi(outputs),
146  bo(outputs),
147 
148  bf_gradients(outputs),
149  bz_gradients(outputs),
150  bi_gradients(outputs),
151  bo_gradients(outputs),
152 
153  wf_opt(outputs, inputs),
154  wz_opt(outputs, inputs),
155  wi_opt(outputs, inputs),
156  wo_opt(outputs, inputs),
157 
158  rf_opt(outputs, outputs),
159  rz_opt(outputs, outputs),
160  ri_opt(outputs, outputs),
161  ro_opt(outputs, outputs),
162 
163  bf_opt(outputs),
164  bz_opt(outputs),
165  bi_opt(outputs),
166  bo_opt(outputs)
167  {
169  zero_gradients();
170  }
171 
173  {
174  wf.randomize(-.1, .1);
175  wz.randomize(-.1, .1);
176  wi.randomize(-.1, .1);
177  wo.randomize(-.1, .1);
178 
179  rf.randomize(-.1, .1);
180  rz.randomize(-.1, .1);
181  ri.randomize(-.1, .1);
182  ro.randomize(-.1, .1);
183 
184  bf.randomize(-.1, .1);
185  bz.randomize(-.1, .1);
186  bi.randomize(-.1, .1);
187  bo.randomize(-.1, .1);
188  }
189 
	/// Training-mode forward pass for one time-step.
	/// \param x input activations for this time-step
	/// \param y this layer's output from the previous time-step
	///          (forward_requires_outputs is true_type)
	/// \param cache recurrent storage; gate activations are stored so
	///              back_propagation can reuse them
	/// \return the layer output: c_g(cellstate) % output_gate
	template<class X, class Y>
	auto forward_propagation(const X& x, const Y& y, Cache& cache)
	{
		// Gate activations are computed AND persisted in the cache.
		mat& f = cache.store(forget_key(), f_g(wf * x + rf * y + bf));
		mat& z = cache.store(write_key(), z_g(wz * x + rz * y + bz));
		mat& i = cache.store(input_key(), i_g(wi * x + ri * y + bi));
		mat& o = cache.store(output_key(), o_g(wo * x + ro * y + bo));
		// Previous cell state (zero-initialized on first use).
		mat& c = cache.load(cell_key(), default_tensor_factory());
		c = c % f + z % i; //% element-wise multiplication

		// Store the updated cell state for the next time-step / backprop.
		mat& c_ = cache.store(cell_key(), c);
		return c_g(c_) % o;
	}
203 
204 #ifndef _MSC_VER
205 
206  template<class X, class Y>
207  auto predict(const X& x, const Y& y, Cache& cache)
208  {
209  mat f = f_g(wf * x + rf * y + bf);
210  mat z = z_g(wz * x + rz * y + bz);
211  mat i = i_g(wi * x + ri * y + bi);
212  mat o = o_g(wo * x + ro * y + bo);
213  mat& c = cache.load(cell_key(), default_tensor_factory());
214  c = c % f + z % i; //% element-wise multiplication
215 
216  mat& c_ = cache.store(cell_key(), c);
217  return c_g(c_) % o;
218  }
219 
220 #endif
221 
222  template<class X, class Y>
223  auto single_predict(const X& x, const Y& y, Cache& cache)
224  {
225  vec f = f_g(wf * x + rf * y + bf);
226  vec z = z_g(wz * x + rz * y + bz);
227  vec i = i_g(wi * x + ri * y + bi);
228  vec o = o_g(wo * x + ro * y + bo);
229  vec& c = cache.load(predict_cell_key(), default_predict_tensor_factory());
230 
231  c = c % f + z % i; //% element-wise multiplication
232  return c_g(c) % o;
233  }
234 
235  template<class X, class Y, class Delta>
236  auto back_propagation(const X& x, const Y& y,
237  const Delta& delta_outputs, class Cache& cache)
238  {
239  //LSTM Backprop reference
240  //Reference: https://arxiv.org/pdf/1503.04069.pdf
241 
242  if (cache.get_time_index() != 0) {
243  rz_gradients -= dz * y.t();
244  rf_gradients -= df * y.t();
245  ri_gradients -= di * y.t();
246  ro_gradients -= do_ * y.t();
247  }
248 
249  auto& z = cache.load(write_key(), default_tensor_factory());
250  auto& i = cache.load(input_key(), default_tensor_factory());
251  auto& f = cache.load(forget_key(), default_tensor_factory());
252  auto& o = cache.load(output_key(), default_tensor_factory());
253  auto& cm1 = cache.load(cell_key(), -1, default_tensor_factory());
254  auto& c = cache.load(cell_key(), default_tensor_factory());
255 
256  dy = delta_outputs +
257  rz.t() * dz +
258  ri.t() * di +
259  rf.t() * df +
260  ro.t() * do_;
261 
262  do_ = dy % c_g(c) % o_g.cached_dx(o);
263 
264  if (cache.get_time_index() != 0) {
265  auto& fp1 = cache.load(forget_key(), 1, default_tensor_factory());
266  dc = dy % o % c_g.dx(c) + dc % fp1;
267  } else {
268  dc = dy % o % c_g.dx(c);
269  }
270 
271  df = dc % cm1 % f_g.cached_dx(f);
272  di = dc % z % i_g.cached_dx(i);
273  dz = dc % i % z_g.cached_dx(z);
274 
275  wz_gradients -= dz * x.t();
276  wf_gradients -= df * x.t();
277  wi_gradients -= di * x.t();
278  wo_gradients -= do_ * x.t();
279 
280  bz_gradients -= dz;
281  bf_gradients -= df;
282  bi_gradients -= di;
283  bo_gradients -= do_;
284 
285  return wz.t() * dz +
286  wi.t() * dz +
287  wf.t() * df +
288  wo.t() * do_;
289  }
290 
292  {
293  wz_opt.update(wz, wz_gradients);
294  wf_opt.update(wf, wf_gradients);
295  wi_opt.update(wi, wi_gradients);
296  wo_opt.update(wo, wo_gradients);
297 
298  rz_opt.update(rz, rz_gradients);
299  rf_opt.update(rf, rf_gradients);
300  ri_opt.update(ri, ri_gradients);
301  ro_opt.update(ro, ro_gradients);
302 
303  bz_opt.update(bz, bz_gradients);
304  bf_opt.update(bf, bf_gradients);
305  bi_opt.update(bi, bi_gradients);
306  bo_opt.update(bo, bo_gradients);
307 
308  zero_gradients();
309  }
310 
311  virtual
312  void set_learning_rate_hook(value_type lr) override final
313  {
315  value_type batched_lr = this->get_batched_learning_rate();
316 
317  auto optimizers = enumerate(
318  wz_opt, wf_opt, wi_opt, wo_opt,
319  rz_opt, rf_opt, ri_opt, ro_opt);
320 
321  auto bias_optimizers = enumerate(
322  bf_opt, bz_opt, bi_opt, bo_opt);
323 
324  for (auto& optimizer : optimizers)
325  optimizer.set_learning_rate(batched_lr);
326 
327  for (auto& optimizer : bias_optimizers)
328  optimizer.set_learning_rate(batched_lr);
329  }
330 
	/// Resize all delta buffers to (output_size x bs) and zero them
	/// whenever the batch size changes.
	virtual
	void set_batch_size_hook(int bs) override final
	{
		for (auto& tensor: enumerate(dc, df, dz, di, do_, dy)) {
			// std::move is intentional: zero() returns self_type& (an
			// lvalue referring to the temporary), so without the cast
			// this would copy-assign instead of move-assign.
			tensor = std::move(mat(this->output_size(), bs).zero());
		}
	}
338 
339  void zero_deltas()
340  {
341  for (auto& delta : enumerate(dc, df, di, dz, do_, dy)) {
342  delta.zero();
343  }
344  }
345 
347  {
348  for (auto& grad : enumerate(
349  wf_gradients, wz_gradients,
350  wi_gradients, wo_gradients,
351  rf_gradients, rz_gradients,
352  ri_gradients, ro_gradients)) {
353  grad.zero();
354  }
355 
356  for (auto& grad : enumerate(
357  bf_gradients, bz_gradients,
358  bi_gradients, bo_gradients)) {
359  grad.zero();
360  }
361  }
362 
	/// Drop the per-time-step gate/cell history kept for back-propagation
	/// (e.g. between sequences).
	void clear_bp_storage(Cache& m_cache)
	{
		m_cache.clear_bp_storage(cell_key());
		m_cache.clear_bp_storage(write_key());
		m_cache.clear_bp_storage(input_key());
		m_cache.clear_bp_storage(forget_key());
		m_cache.clear_bp_storage(output_key());
	}
371 
	/// Serialize all parameters and optimizer state.
	/// NOTE(review): unlike load/load_to_cache/save_from_cache this is
	/// not marked `override` — confirm against Layer_Base whether that
	/// is intentional.
	virtual
	void save(Layer_Loader& loader) const
	{
		loader.save_variable(wf, "wf");
		loader.save_variable(rf, "rf");
		loader.save_variable(bf, "bf");

		loader.save_variable(wz, "wz");
		loader.save_variable(rz, "rz");
		loader.save_variable(bz, "bz");

		loader.save_variable(wi, "wi");
		loader.save_variable(ri, "ri");
		loader.save_variable(bi, "bi");

		loader.save_variable(wo, "wo");
		loader.save_variable(ro, "ro");
		loader.save_variable(bo, "bo");

		wf_opt.save(loader, "wf_opt");
		wz_opt.save(loader, "wz_opt");
		wi_opt.save(loader, "wi_opt");
		wo_opt.save(loader, "wo_opt");

		rf_opt.save(loader, "rf_opt");
		rz_opt.save(loader, "rz_opt");
		ri_opt.save(loader, "ri_opt");
		ro_opt.save(loader, "ro_opt");

		bf_opt.save(loader, "bf_opt");
		bz_opt.save(loader, "bz_opt");
		bi_opt.save(loader, "bi_opt");
		bo_opt.save(loader, "bo_opt");
	}
406 
407  virtual
408  void save_from_cache(Layer_Loader& loader, const Cache& cache) const override
409  {
410  auto& z = cache.load(write_key(), default_tensor_factory());
411  auto& i = cache.load(input_key(), default_tensor_factory());
412  auto& f = cache.load(forget_key(), default_tensor_factory());
413  auto& o = cache.load(output_key(), default_tensor_factory());
414  auto& c = cache.load(cell_key(), default_tensor_factory());
415 
416  loader.save_variable(z, "write_gate_values");
417  loader.save_variable(i, "input_gate_values");
418  loader.save_variable(f, "forget_gate_values");
419  loader.save_variable(o, "output_gate_values");
420  loader.save_variable(c, "cellstate");
421 
422  if (cache.contains(predict_cell_key())) {
423  auto& pc = cache.load(
425  default_predict_tensor_factory());
426  loader.save_variable(pc, "predict_cellstate");
427  }
428  }
429 
	/// Restore all parameters and optimizer state saved by save().
	/// Variable names must match save() exactly.
	virtual
	void load(Layer_Loader& loader) override
	{
		loader.load_variable(wf, "wf");
		loader.load_variable(rf, "rf");
		loader.load_variable(bf, "bf");

		loader.load_variable(wz, "wz");
		loader.load_variable(rz, "rz");
		loader.load_variable(bz, "bz");

		loader.load_variable(wi, "wi");
		loader.load_variable(ri, "ri");
		loader.load_variable(bi, "bi");

		loader.load_variable(wo, "wo");
		loader.load_variable(ro, "ro");
		loader.load_variable(bo, "bo");

		wf_opt.load(loader, "wf_opt");
		wz_opt.load(loader, "wz_opt");
		wi_opt.load(loader, "wi_opt");
		wo_opt.load(loader, "wo_opt");

		rf_opt.load(loader, "rf_opt");
		rz_opt.load(loader, "rz_opt");
		ri_opt.load(loader, "ri_opt");
		ro_opt.load(loader, "ro_opt");

		bf_opt.load(loader, "bf_opt");
		bz_opt.load(loader, "bz_opt");
		bi_opt.load(loader, "bi_opt");
		bo_opt.load(loader, "bo_opt");
	}
464 
465  virtual
466  void load_to_cache(Layer_Loader& loader, const Cache& cache) override
467  {
468  auto& z = cache.load(write_key(), default_tensor_factory());
469  auto& i = cache.load(input_key(), default_tensor_factory());
470  auto& f = cache.load(forget_key(), default_tensor_factory());
471  auto& o = cache.load(output_key(), default_tensor_factory());
472  auto& c = cache.load(cell_key(), default_tensor_factory());
473 
474  loader.load_variable(z, "write_gate_values");
475  loader.load_variable(i, "input_gate_values");
476  loader.load_variable(f, "forget_gate_values");
477  loader.load_variable(o, "output_gate_values");
478  loader.load_variable(c, "cellstate");
479 
480  if (loader.file_exists(1, "predict_cellstate")) {
481  auto& pc = cache.load(
483  default_predict_tensor_factory());
484  loader.load_variable(pc, "predict_cellstate");
485  }
486  }
487 
	/// Seed the single-sample predict cell state from the batched
	/// training cell state.
	/// \param batch_index which sample of the batch to copy
	///        (c[batch_index] — presumably selects that batch's column
	///        of the cell-state matrix; confirm against mat::operator[])
	void copy_training_data_to_single_predict(Cache& cache, int batch_index)
	{
		auto& pc = cache.load(predict_cell_key(), default_predict_tensor_factory());
		auto& c = cache.load(cell_key(), default_tensor_factory());
		pc = c[batch_index];
	}
494 
495 private:
496 
497  auto default_tensor_factory() const
498  {
499  return [&]() {
500  return mat(this->output_size(), this->batch_size()).zero();
501  };
502  }
503 
504  auto default_predict_tensor_factory() const
505  {
506  return [&]() {
507  return vec(this->output_size()).zero();
508  };
509  }
510 
511 };
512 
513 template<class SystemTag, class Optimizer=nn_default_optimizer_type>
514 auto lstm(SystemTag system_tag, int inputs, int outputs, Optimizer=Optimizer()) {
515  return LSTM<
516  SystemTag,
517  typename SystemTag::default_floating_point_type,
518  Optimizer>(inputs, outputs);
519 }
520 
521 template<class Optimizer=nn_default_optimizer_type>
522 auto lstm(int inputs, int outputs, Optimizer=Optimizer()) {
523  return LSTM<
525  typename BLACKCAT_DEFAULT_SYSTEM_T::default_floating_point_type,
526  Optimizer>(inputs, outputs);
527 }
528 
529 
530 }
531 }
532 
533 
534 
#endif /* BLACKCAT_TENSORS_NEURALNETWORKS_LSTM_H_ */
void randomize(value_type lb=0, value_type ub=1)
Definition: tensor_base.h:36
ValueType value_type
Definition: lstm.h:44
Definition: lstm.h:27
auto forward_propagation(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:191
self_type & zero()
Definition: tensor_iteralgos.h:12
virtual void set_learning_rate_hook(value_type lr) override final
Definition: lstm.h:312
Definition: constexpr_int.h:14
auto lstm(SystemTag system_tag, int inputs, int outputs, Optimizer=Optimizer())
Definition: lstm.h:514
SystemTag system_tag
Definition: lstm.h:43
Definition: layer_base.h:86
#define BLACKCAT_DEFAULT_SYSTEM_T
Definition: common.h:49
A Dictionary designed to store any type using the &#39;store&#39; and &#39;load&#39; functions.
Definition: layer_cache.h:46
Definition: layer_loader.h:19
void zero_deltas()
Definition: lstm.h:339
void save_variable(const T &tensor, string variable_name)
Definition: layer_loader.h:44
Definition: layer_cache.h:33
virtual void set_batch_size_hook(int bs) override final
Definition: lstm.h:332
std::true_type greedy_evaluate_delta
Definition: lstm.h:65
std::true_type is_recurrent
Definition: lstm.h:69
void copy_training_data_to_single_predict(Cache &cache, int batch_index)
Definition: lstm.h:488
std::true_type defines_single_predict
Definition: lstm.h:75
std::true_type forward_requires_outputs
Definition: lstm.h:66
void clear_bp_storage(Cache &m_cache)
Definition: lstm.h:363
bool file_exists(int dim, string filename)
Definition: layer_loader.h:100
void randomize_weights()
Definition: lstm.h:172
int size_t
Definition: common.h:283
Optimizer optimizer_type
Definition: lstm.h:63
void clear_bp_storage(key_type< K, V, cache_key_type::always_forward > key)
Definition: layer_cache.h:191
auto single_predict(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:223
const auto t() const
Definition: expression_base.h:94
std::true_type defines_predict
Definition: lstm.h:72
auto & store(key_type< K, V, cache_key_type::inherit > key, U &&expression)
Definition: layer_cache.h:104
virtual void load(Layer_Loader &loader) override
Definition: lstm.h:431
void load_variable(T &tensor, string variable_name)
Definition: layer_loader.h:50
auto back_propagation(const X &x, const Y &y, const Delta &delta_outputs, class Cache &cache)
Definition: lstm.h:236
std::true_type requires_extra_cache
Definition: lstm.h:68
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier=0) const
Definition: layer_cache.h:80
Definition: layer_cache.h:28
ReferenceList< T > enumerate(T &t, Ts &... ts)
Definition: reference_iterator.h:56
std::true_type backward_requires_outputs
Definition: lstm.h:67
void update_weights()
Definition: lstm.h:291
void set_learning_rate(value_type learning_rate)
Definition: layer_base.h:162
auto predict(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:207
virtual void save_from_cache(Layer_Loader &loader, const Cache &cache) const override
Definition: lstm.h:408
Definition: cmath.h:73
int get_time_index() const
Definition: layer_cache.h:184
Definition: any_map.h:22
void zero_gradients()
Definition: lstm.h:346
bool contains(key_type< K, V, R > key) const
Definition: layer_cache.h:75
virtual void load_to_cache(Layer_Loader &loader, const Cache &cache) override
Definition: lstm.h:466
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Definition: recycle_allocator.h:57
LSTM(int inputs, bc::size_t outputs)
Definition: lstm.h:121
virtual void save(Layer_Loader &loader) const
Definition: lstm.h:373