BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
layer_cache.h
Go to the documentation of this file.
1 /*
2  * Layer_Cache.h
3  *
4  * Created on: Aug 31, 2019
5  * Author: joseph
6  */
7 
8 #ifndef BLACKCATTENSORS_NEURALNETWORKS_LAYER_CACHE_H_
9 #define BLACKCATTENSORS_NEURALNETWORKS_LAYER_CACHE_H_
10 
11 #include <vector>
12 #include <type_traits>
13 
14 namespace bc {
15 namespace nn {
16 
30 };
31 
32 template<class K, class V, cache_key_type CacheKeyOverrider=inherit>
34  static constexpr cache_key_type cache_override_type = CacheKeyOverrider;
35 };
36 
37 
46 struct Cache {
47 
48  template<class K, class V, cache_key_type R>
50 
51 private:
52 
53  int m_time_index = 0;
54  bool is_recurrent = false;
55 
56  mutable bc::utility::Any_Map cache;
57 
58  template<class K, class V>
61  }
62 
63  template<class K, class V, cache_key_type R>
64  auto hash(key_type<K, V, R> key) const {
66  }
67 
68 public:
69 
70  void enable_recurrent_caching(bool enable=true) {
71  is_recurrent = enable;
72  }
73 
74  template<class K, class V, cache_key_type R>
75  bool contains(key_type<K,V,R> key) const {
76  return cache.contains(key);
77  }
78 
79  template<class K, class V>
80  auto& load(key_type<K, V, cache_key_type::inherit> key, int t_modifier=0) const
81  {
82  if (is_recurrent)
83  return load(key_type<K,V, always_recurrent>(), t_modifier);
84  else
85  return load(key_type<K,V, always_forward>(), t_modifier);
86  }
87 
88  template<class K, class V, class Factory>
89  auto& load(key_type<K, V, cache_key_type::inherit> key, int t_modifier, Factory factory) const
90  {
91  if (is_recurrent)
92  return load(key_type<K,V, cache_key_type::always_recurrent>(), t_modifier, factory);
93  else
94  return load(key_type<K,V, cache_key_type::always_forward>(), factory);
95  }
96 
97  template<class K, class V, class Factory>
98  auto& load(key_type<K, V, cache_key_type::inherit> key, Factory factory) const
99  {
100  return load(key, 0, factory);
101  }
102 
103  template<class K, class V, class U>
104  auto& store(key_type<K, V, cache_key_type::inherit> key, U&& expression) {
105  if (is_recurrent)
106  return store(key_type<K,V, cache_key_type::always_recurrent>(), expression);
107  else
108  return store(key_type<K,V, cache_key_type::always_forward>(), expression);
109  }
110 
112  template<class K, class V>
113  auto& load(key_type<K, V, cache_key_type::always_recurrent> key, int t_modifier=0) const {
114  std::vector<V>& history = cache[hash(key)];
115  unsigned index = history.size()- 1 - m_time_index + t_modifier;
116 
117  BC_ASSERT((int)index < (int)history.size(),
118  "Load recurrent_variable index out of bounds"
119  "\nHistory size: " + std::to_string(history.size()) +
120  "\nIndex:" + std::to_string(index));
121 
122  return history[index];
123  }
124 
125  template<class K, class V>
126  auto& load(key_type<K, V, cache_key_type::always_forward> key, int t_modifier=0) const {
127  BC_ASSERT(t_modifier==0, "Nonrecurrent keys cannot have a time_offset");
128  return cache[hash(key)];
129  }
130 
131  template<class K, class V, class DefaultFactory>
133  int t_modifier,
134  DefaultFactory function) const
135  {
136  std::vector<V>& history = cache[hash(key)];
137 
138  unsigned index = history.size()- 1 - m_time_index + t_modifier;
139  if (index >= history.size()) {
140  history.push_back(function());
141  return history.back();
142  }
143 
144  BC_ASSERT((int)index < (int)history.size(),
145  "Load recurrent_variable index out of bounds"
146  "\nHistory size: " + std::to_string(history.size()) +
147  "\nIndex:" + std::to_string(index));
148 
149  return history[index];
150  }
151 
152  template<class K, class V, class DefaultFactory>
153  auto& load(key_type<K, V, cache_key_type::always_recurrent> key, DefaultFactory function) const {
154  return load(key, 0, function);
155  }
156 
157 
158  template<class K, class V, class DefaultFactory>
159  auto& load(key_type<K, V, cache_key_type::always_forward> key, DefaultFactory function) const {
160  auto hkey = hash(key);
161 
162  if (cache.contains(hkey)) {
163  return cache[hkey];
164  } else {
165  return cache[hkey] = function();
166  }
167  }
168 
169  template<class K, class V, class U>
171  cache[hash(key)].push_back(std::forward<U>(expression));
172  return cache[hash(key)].back();
173  }
174 
175  template<class K, class V, class U>
177  if (cache.contains(hash(key))) {
178  return cache[hash(key)] = std::forward<U>(expression);
179  } else {
180  return cache[hash(key)] = V(std::forward<U>(expression));
181  }
182  }
183 
184  int get_time_index() const { return m_time_index; }
185  void increment_time_index() { m_time_index++; }
186  void decrement_time_index() { m_time_index--; }
187  void zero_time_index() { set_time_index(0); }
188  void set_time_index(int idx) { m_time_index = idx; }
189 
190  template<class K, class V>
192 
193  template<class K, class V>
195  {
196  if (is_recurrent) {
198  clear_bp_storage(k);
199  }
200  }
201 
202  template<class K, class V>
204  auto& storage = cache[hash(key)];
205 
206  if (storage.size() > 1) {
207  auto last = std::move(storage.back());
208  storage.clear();
209  storage.push_back(std::move(last));
210  }
211  }
212 };
213 
214 
215 }
216 }
217 
218 
219 #endif /* LAYER_CACHE_H_ */
bool contains(Any_Key< K, V > key) const
Definition: any_map.h:73
void clear_bp_storage(key_type< K, V, cache_key_type::always_recurrent > key)
Definition: layer_cache.h:203
Definition: layer_cache.h:29
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier, Factory factory) const
Definition: layer_cache.h:89
void decrement_time_index()
Definition: layer_cache.h:186
Definition: any_map.h:36
auto & store(key_type< K, V, cache_key_type::always_forward > key, U &&expression)
Definition: layer_cache.h:176
int size() const
Definition: any_map.h:122
A dictionary designed to store any type using the 'store' and 'load' functions.
Definition: layer_cache.h:46
cache_key_type
A type designed to act as a key to the Cache object.
Definition: layer_cache.h:26
Definition: layer_cache.h:33
void clear_bp_storage(key_type< K, V, cache_key_type::inherit > key)
Definition: layer_cache.h:194
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, int t_modifier, DefaultFactory function) const
Definition: layer_cache.h:132
std::string to_string(int precision=8, bool pretty=true, bool sparse=false) const
Definition: tensor_utility.h:34
void set_time_index(int idx)
Definition: layer_cache.h:188
void zero_time_index()
Definition: layer_cache.h:187
auto & load(key_type< K, V, cache_key_type::always_forward > key, int t_modifier=0) const
Definition: layer_cache.h:126
static constexpr cache_key_type cache_override_type
Definition: layer_cache.h:34
Definition: layer_cache.h:27
void increment_time_index()
Definition: layer_cache.h:185
void clear_bp_storage(key_type< K, V, cache_key_type::always_forward > key)
Definition: layer_cache.h:191
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, int t_modifier=0) const
Loads the current value at the current time_index.
Definition: layer_cache.h:113
auto & store(key_type< K, V, cache_key_type::inherit > key, U &&expression)
Definition: layer_cache.h:104
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier=0) const
Definition: layer_cache.h:80
Definition: layer_cache.h:28
auto & load(key_type< K, V, cache_key_type::always_forward > key, DefaultFactory function) const
Definition: layer_cache.h:159
Any_Map stores a collection of std::shared_ptr<void>.
Definition: any_map.h:60
#define BC_ASSERT(condition, message)
Definition: common.h:185
auto & store(key_type< K, V, cache_key_type::always_recurrent > key, U &&expression)
Definition: layer_cache.h:170
void enable_recurrent_caching(bool enable=true)
Definition: layer_cache.h:70
int get_time_index() const
Definition: layer_cache.h:184
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, DefaultFactory function) const
Definition: layer_cache.h:153
auto & load(key_type< K, V, cache_key_type::inherit > key, Factory factory) const
Definition: layer_cache.h:98
bool contains(key_type< K, V, R > key) const
Definition: layer_cache.h:75
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22