8 #ifndef BLACKCATTENSORS_NEURALNETWORKS_LAYER_CACHE_H_ 9 #define BLACKCATTENSORS_NEURALNETWORKS_LAYER_CACHE_H_ 12 #include <type_traits> 32 template<
class K,
class V, cache_key_type CacheKeyOverr
ider=inherit>
48 template<
class K,
class V, cache_key_type R>
54 bool is_recurrent =
false;
58 template<
class K,
class V>
63 template<
class K,
class V, cache_key_type R>
71 is_recurrent = enable;
74 template<
class K,
class V, cache_key_type R>
79 template<
class K,
class V>
88 template<
class K,
class V,
class Factory>
97 template<
class K,
class V,
class Factory>
100 return load(key, 0, factory);
103 template<
class K,
class V,
class U>
112 template<
class K,
class V>
114 std::vector<V>& history = cache[hash(key)];
115 unsigned index = history.
size()- 1 - m_time_index + t_modifier;
117 BC_ASSERT((
int)index < (
int)history.size(),
118 "Load recurrent_variable index out of bounds" 122 return history[index];
125 template<
class K,
class V>
127 BC_ASSERT(t_modifier==0,
"Nonrecurrent keys cannot have a time_offset");
128 return cache[hash(key)];
131 template<
class K,
class V,
class DefaultFactory>
134 DefaultFactory
function)
const 136 std::vector<V>& history = cache[hash(key)];
138 unsigned index = history.
size()- 1 - m_time_index + t_modifier;
139 if (index >= history.size()) {
140 history.push_back(
function());
141 return history.back();
144 BC_ASSERT((
int)index < (
int)history.size(),
145 "Load recurrent_variable index out of bounds" 149 return history[index];
152 template<
class K,
class V,
class DefaultFactory>
154 return load(key, 0,
function);
158 template<
class K,
class V,
class DefaultFactory>
160 auto hkey = hash(key);
165 return cache[hkey] =
function();
169 template<
class K,
class V,
class U>
171 cache[hash(key)].push_back(std::forward<U>(expression));
172 return cache[hash(key)].back();
175 template<
class K,
class V,
class U>
178 return cache[hash(key)] = std::forward<U>(expression);
180 return cache[hash(key)] = V(std::forward<U>(expression));
190 template<
class K,
class V>
193 template<
class K,
class V>
202 template<
class K,
class V>
204 auto& storage = cache[hash(key)];
206 if (storage.size() > 1) {
207 auto last = std::move(storage.back());
209 storage.push_back(std::move(last));
bool contains(Any_Key< K, V > key) const
Definition: any_map.h:73
void clear_bp_storage(key_type< K, V, cache_key_type::always_recurrent > key)
Definition: layer_cache.h:203
Definition: layer_cache.h:29
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier, Factory factory) const
Definition: layer_cache.h:89
void decrement_time_index()
Definition: layer_cache.h:186
auto & store(key_type< K, V, cache_key_type::always_forward > key, U &&expression)
Definition: layer_cache.h:176
int size() const
Definition: any_map.h:122
A Dictionary designed to store any type using the 'store' and 'load' functions.
Definition: layer_cache.h:46
cache_key_type
A type designed to act as a key to the Cache object.
Definition: layer_cache.h:26
Definition: layer_cache.h:33
void clear_bp_storage(key_type< K, V, cache_key_type::inherit > key)
Definition: layer_cache.h:194
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, int t_modifier, DefaultFactory function) const
Definition: layer_cache.h:132
std::string to_string(int precision=8, bool pretty=true, bool sparse=false) const
Definition: tensor_utility.h:34
void set_time_index(int idx)
Definition: layer_cache.h:188
void zero_time_index()
Definition: layer_cache.h:187
auto & load(key_type< K, V, cache_key_type::always_forward > key, int t_modifier=0) const
Definition: layer_cache.h:126
static constexpr cache_key_type cache_override_type
Definition: layer_cache.h:34
Definition: layer_cache.h:27
void increment_time_index()
Definition: layer_cache.h:185
void clear_bp_storage(key_type< K, V, cache_key_type::always_forward > key)
Definition: layer_cache.h:191
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, int t_modifier=0) const
Loads the current value at the current time_index.
Definition: layer_cache.h:113
auto & store(key_type< K, V, cache_key_type::inherit > key, U &&expression)
Definition: layer_cache.h:104
auto & load(key_type< K, V, cache_key_type::inherit > key, int t_modifier=0) const
Definition: layer_cache.h:80
Definition: layer_cache.h:28
auto & load(key_type< K, V, cache_key_type::always_forward > key, DefaultFactory function) const
Definition: layer_cache.h:159
Any_Map stores a buck of std::shared_ptr<void>.
Definition: any_map.h:60
#define BC_ASSERT(condition, message)
Definition: common.h:185
auto & store(key_type< K, V, cache_key_type::always_recurrent > key, U &&expression)
Definition: layer_cache.h:170
void enable_recurrent_caching(bool enable=true)
Definition: layer_cache.h:70
int get_time_index() const
Definition: layer_cache.h:184
auto & load(key_type< K, V, cache_key_type::always_recurrent > key, DefaultFactory function) const
Definition: layer_cache.h:153
auto & load(key_type< K, V, cache_key_type::inherit > key, Factory factory) const
Definition: layer_cache.h:98
bool contains(key_type< K, V, R > key) const
Definition: layer_cache.h:75
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22