8 #ifndef BLACKCAT_TENSORS_NEURALNETWORKS_LSTM_H_ 9 #define BLACKCAT_TENSORS_NEURALNETWORKS_LSTM_H_ 11 #include "../layer_cache.h" 19 template<
class SystemTag,
21 class Optimizer=Stochastic_Gradient_Descent,
23 class WriteGateNonlinearity=
bc::Tanh,
26 class CellStateNonLinearity=
bc::Tanh>
33 ForgetGateNonlinearity,
34 WriteGateNonlinearity,
35 InputGateNonlinearity,
36 OutputGateNonlinearity,
37 CellStateNonLinearity>,
55 ForgetGateNonlinearity,
56 WriteGateNonlinearity,
57 InputGateNonlinearity,
58 OutputGateNonlinearity,
59 CellStateNonLinearity>,
82 using mat_opt_t =
typename Optimizer::template Optimizer<mat>;
83 using vec_opt_t =
typename Optimizer::template Optimizer<vec>;
85 CellStateNonLinearity c_g;
86 ForgetGateNonlinearity f_g;
87 WriteGateNonlinearity z_g;
88 InputGateNonlinearity i_g;
89 OutputGateNonlinearity o_g;
92 mat wf_gradients, wz_gradients, wi_gradients, wo_gradients;
95 mat rf_gradients, rz_gradients, ri_gradients, ro_gradients;
98 vec bf_gradients, bz_gradients, bi_gradients, bo_gradients;
100 mat_opt_t wf_opt, wz_opt, wi_opt, wo_opt;
101 mat_opt_t rf_opt, rz_opt, ri_opt, ro_opt;
102 vec_opt_t bf_opt, bz_opt, bi_opt, bo_opt;
104 mat dc, df, dz, di, do_, dy;
128 wf_gradients(outputs, inputs),
129 wz_gradients(outputs, inputs),
130 wi_gradients(outputs, inputs),
131 wo_gradients(outputs, inputs),
133 rf(outputs, outputs),
134 rz(outputs, outputs),
135 ri(outputs, outputs),
136 ro(outputs, outputs),
138 rf_gradients(outputs, outputs),
139 rz_gradients(outputs, outputs),
140 ri_gradients(outputs, outputs),
141 ro_gradients(outputs, outputs),
148 bf_gradients(outputs),
149 bz_gradients(outputs),
150 bi_gradients(outputs),
151 bo_gradients(outputs),
153 wf_opt(outputs, inputs),
154 wz_opt(outputs, inputs),
155 wi_opt(outputs, inputs),
156 wo_opt(outputs, inputs),
158 rf_opt(outputs, outputs),
159 rz_opt(outputs, outputs),
160 ri_opt(outputs, outputs),
161 ro_opt(outputs, outputs),
190 template<
class X,
class Y>
197 mat& c = cache.
load(
cell_key(), default_tensor_factory());
206 template<
class X,
class Y>
209 mat f = f_g(wf * x + rf * y + bf);
210 mat z = z_g(wz * x + rz * y + bz);
211 mat i = i_g(wi * x + ri * y + bi);
212 mat o = o_g(wo * x + ro * y + bo);
213 mat& c = cache.
load(
cell_key(), default_tensor_factory());
222 template<
class X,
class Y>
225 vec f = f_g(wf * x + rf * y + bf);
226 vec z = z_g(wz * x + rz * y + bz);
227 vec i = i_g(wi * x + ri * y + bi);
228 vec o = o_g(wo * x + ro * y + bo);
235 template<
class X,
class Y,
class Delta>
237 const Delta& delta_outputs,
class Cache& cache)
243 rz_gradients -= dz * y.
t();
244 rf_gradients -= df * y.
t();
245 ri_gradients -= di * y.
t();
246 ro_gradients -= do_ * y.
t();
253 auto& cm1 = cache.
load(
cell_key(), -1, default_tensor_factory());
254 auto& c = cache.
load(
cell_key(), default_tensor_factory());
262 do_ = dy % c_g(c) % o_g.cached_dx(o);
265 auto& fp1 = cache.
load(
forget_key(), 1, default_tensor_factory());
266 dc = dy % o % c_g.dx(c) + dc % fp1;
268 dc = dy % o % c_g.dx(c);
271 df = dc % cm1 % f_g.cached_dx(f);
272 di = dc % z % i_g.cached_dx(i);
273 dz = dc % i % z_g.cached_dx(z);
275 wz_gradients -= dz * x.
t();
276 wf_gradients -= df * x.
t();
277 wi_gradients -= di * x.
t();
278 wo_gradients -= do_ * x.
t();
293 wz_opt.update(wz, wz_gradients);
294 wf_opt.update(wf, wf_gradients);
295 wi_opt.update(wi, wi_gradients);
296 wo_opt.update(wo, wo_gradients);
298 rz_opt.update(rz, rz_gradients);
299 rf_opt.update(rf, rf_gradients);
300 ri_opt.update(ri, ri_gradients);
301 ro_opt.update(ro, ro_gradients);
303 bz_opt.update(bz, bz_gradients);
304 bf_opt.update(bf, bf_gradients);
305 bi_opt.update(bi, bi_gradients);
306 bo_opt.update(bo, bo_gradients);
318 wz_opt, wf_opt, wi_opt, wo_opt,
319 rz_opt, rf_opt, ri_opt, ro_opt);
322 bf_opt, bz_opt, bi_opt, bo_opt);
324 for (
auto& optimizer : optimizers)
325 optimizer.set_learning_rate(batched_lr);
327 for (
auto& optimizer : bias_optimizers)
328 optimizer.set_learning_rate(batched_lr);
334 for (
auto& tensor:
enumerate(dc, df, dz, di, do_, dy)) {
341 for (
auto& delta :
enumerate(dc, df, di, dz, do_, dy)) {
349 wf_gradients, wz_gradients,
350 wi_gradients, wo_gradients,
351 rf_gradients, rz_gradients,
352 ri_gradients, ro_gradients)) {
357 bf_gradients, bz_gradients,
358 bi_gradients, bo_gradients)) {
391 wf_opt.save(loader,
"wf_opt");
392 wz_opt.save(loader,
"wz_opt");
393 wi_opt.save(loader,
"wi_opt");
394 wo_opt.save(loader,
"wo_opt");
396 rf_opt.save(loader,
"rf_opt");
397 rz_opt.save(loader,
"rz_opt");
398 ri_opt.save(loader,
"ri_opt");
399 ro_opt.save(loader,
"ro_opt");
401 bf_opt.save(loader,
"bf_opt");
402 bz_opt.save(loader,
"bz_opt");
403 bi_opt.save(loader,
"bi_opt");
404 bo_opt.save(loader,
"bo_opt");
414 auto& c = cache.
load(
cell_key(), default_tensor_factory());
423 auto& pc = cache.
load(
425 default_predict_tensor_factory());
449 wf_opt.load(loader,
"wf_opt");
450 wz_opt.load(loader,
"wz_opt");
451 wi_opt.load(loader,
"wi_opt");
452 wo_opt.load(loader,
"wo_opt");
454 rf_opt.load(loader,
"rf_opt");
455 rz_opt.load(loader,
"rz_opt");
456 ri_opt.load(loader,
"ri_opt");
457 ro_opt.load(loader,
"ro_opt");
459 bf_opt.load(loader,
"bf_opt");
460 bz_opt.load(loader,
"bz_opt");
461 bi_opt.load(loader,
"bi_opt");
462 bo_opt.load(loader,
"bo_opt");
472 auto& c = cache.
load(
cell_key(), default_tensor_factory());
481 auto& pc = cache.
load(
483 default_predict_tensor_factory());
491 auto& c = cache.
load(
cell_key(), default_tensor_factory());
497 auto default_tensor_factory()
const 504 auto default_predict_tensor_factory()
const 513 template<
class SystemTag,
class Optimizer=nn_default_optimizer_type>
514 auto lstm(SystemTag
system_tag,
int inputs,
int outputs, Optimizer=Optimizer()) {
517 typename SystemTag::default_floating_point_type,
518 Optimizer>(inputs, outputs);
521 template<
class Optimizer=nn_default_optimizer_type>
522 auto lstm(
int inputs,
int outputs, Optimizer=Optimizer()) {
525 typename BLACKCAT_DEFAULT_SYSTEM_T::default_floating_point_type,
526 Optimizer>(inputs, outputs);
void randomize(value_type lb=0, value_type ub=1)
Definition: tensor_base.h:36
ValueType value_type
Definition: lstm.h:44
auto forward_propagation(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:191
self_type & zero()
Definition: tensor_iteralgos.h:12
virtual void set_learning_rate_hook(value_type lr) override final
Definition: lstm.h:312
Definition: constexpr_int.h:14
auto lstm(SystemTag system_tag, int inputs, int outputs, Optimizer=Optimizer())
Definition: lstm.h:514
SystemTag system_tag
Definition: lstm.h:43
Definition: layer_base.h:86
#define BLACKCAT_DEFAULT_SYSTEM_T
Definition: common.h:49
A Dictionary designed to store any type using the 'store' and 'load' functions.
Definition: layer_cache.h:46
Definition: layer_loader.h:19
void zero_deltas()
Definition: lstm.h:339
void save_variable(const T &tensor, string variable_name)
Definition: layer_loader.h:44
Definition: layer_cache.h:33
virtual void set_batch_size_hook(int bs) override final
Definition: lstm.h:332
std::true_type greedy_evaluate_delta
Definition: lstm.h:65
std::true_type is_recurrent
Definition: lstm.h:69
void copy_training_data_to_single_predict(Cache &cache, int batch_index)
Definition: lstm.h:488
std::true_type defines_single_predict
Definition: lstm.h:75
std::true_type forward_requires_outputs
Definition: lstm.h:66
void clear_bp_storage(Cache &m_cache)
Definition: lstm.h:363
bool file_exists(int dim, string filename)
Definition: layer_loader.h:100
void randomize_weights()
Definition: lstm.h:172
bc::nn::Layer_Base< LSTM< SystemTag, ValueType, Optimizer, ForgetGateNonlinearity, WriteGateNonlinearity, InputGateNonlinearity, OutputGateNonlinearity, CellStateNonLinearity >, Tensor_Descriptor< ValueType, SystemTag, Integer< 1 > > >::output_size bc::size_t output_size() const
Definition: layer_base.h:148
int size_t
Definition: common.h:283
Optimizer optimizer_type
Definition: lstm.h:63
void clear_bp_storage(key_type< K, V, cache_key_type::always_forward > key)
Definition: layer_cache.h:191
auto single_predict(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:223
const auto t() const
Definition: expression_base.h:94
std::true_type defines_predict
Definition: lstm.h:72
auto & store(key_type< K, V, cache_key_type::inherit > key, U &&expression)
Definition: layer_cache.h:104
virtual void load(Layer_Loader &loader) override
Definition: lstm.h:431
void load_variable(T &tensor, string variable_name)
Definition: layer_loader.h:50
auto back_propagation(const X &x, const Y &y, const Delta &delta_outputs, class Cache &cache)
Definition: lstm.h:236
bc::nn::Layer_Base< LSTM< SystemTag, ValueType, Optimizer, ForgetGateNonlinearity, WriteGateNonlinearity, InputGateNonlinearity, OutputGateNonlinearity, CellStateNonLinearity >, Tensor_Descriptor< ValueType, SystemTag, Integer< 1 > > >::batch_size bc::size_t batch_size() const
Definition: layer_base.h:149
std::true_type requires_extra_cache
Definition: lstm.h:68
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier=0) const
Definition: layer_cache.h:80
Definition: layer_cache.h:28
ReferenceList< T > enumerate(T &t, Ts &... ts)
Definition: reference_iterator.h:56
bc::nn::Layer_Base< LSTM< SystemTag, ValueType, Optimizer, ForgetGateNonlinearity, WriteGateNonlinearity, InputGateNonlinearity, OutputGateNonlinearity, CellStateNonLinearity >, Tensor_Descriptor< ValueType, SystemTag, Integer< 1 > > >::get_batched_learning_rate auto get_batched_learning_rate() const
Definition: layer_base.h:171
std::true_type backward_requires_outputs
Definition: lstm.h:67
void update_weights()
Definition: lstm.h:291
void set_learning_rate(value_type learning_rate)
Definition: layer_base.h:162
auto predict(const X &x, const Y &y, Cache &cache)
Definition: lstm.h:207
virtual void save_from_cache(Layer_Loader &loader, const Cache &cache) const override
Definition: lstm.h:408
int get_time_index() const
Definition: layer_cache.h:184
void zero_gradients()
Definition: lstm.h:346
bool contains(key_type< K, V, R > key) const
Definition: layer_cache.h:75
virtual void load_to_cache(Layer_Loader &loader, const Cache &cache) override
Definition: lstm.h:466
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Definition: recycle_allocator.h:57
LSTM(int inputs, bc::size_t outputs)
Definition: lstm.h:121
virtual void save(Layer_Loader &loader) const
Definition: lstm.h:373