BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
layer_descriptor.h
Go to the documentation of this file.
1 #ifndef BLACKCAT_NEURALNETWORKS_POLYMORHIC_LAYERS_LAYER_DESCRIPTOR_H_
2 #define BLACKCAT_NEURALNETWORKS_POLYMORHIC_LAYERS_LAYER_DESCRIPTOR_H_
3 
4 namespace bc {
5 namespace nn {
6 
// Forward declaration: Layer_Output_Base needs to name the consuming layer
// type (next_layer_type) before Layer_Input_Base is defined.
template<class InputTensorDescriptor>
class Layer_Input_Base;

/**
 * @brief Base for a layer whose output is consumed by a following layer.
 *
 * Re-exports the traits of @p OutputTensorDescriptor under `output_*` names
 * and holds a non-owning link to the next layer in the chain.
 *
 * @tparam OutputTensorDescriptor descriptor type providing `value_type`,
 *         `system_tag`, `allocator_type`, `tensor_dim`, `shape_type`,
 *         `tensor_type` and `batched_type` member types.
 */
template<class OutputTensorDescriptor>
struct Layer_Output_Base
{
	using output_value_type = typename OutputTensorDescriptor::value_type;
	using output_system_tag = typename OutputTensorDescriptor::system_tag;
	using output_allocator_type = typename OutputTensorDescriptor::allocator_type;
	using output_tensor_dim = typename OutputTensorDescriptor::tensor_dim;
	using output_shape_type = typename OutputTensorDescriptor::shape_type;

	// NOTE(review): Layer_Input_Base reads `::type` for the equivalent alias;
	// confirm which member name the tensor descriptor actually provides.
	using output_tensor_type = typename OutputTensorDescriptor::tensor_type;
	using batched_output_tensor_type = typename OutputTensorDescriptor::batched_type;

	/// Type of the layer that consumes this layer's output; the mirror of
	/// Layer_Input_Base::prev_layer_type.
	using next_layer_type = Layer_Input_Base<OutputTensorDescriptor>;

protected:
	/// Non-owning link to the next layer; nullptr until set_next() is called.
	next_layer_type* m_next_layer = nullptr;

public:
	/// Links @p next as the consumer of this layer's output.
	/// Only the address is stored, so the caller must keep @p next alive
	/// while it is linked.
	void set_next(next_layer_type& next) { m_next_layer = &next; }

	/// @pre set_next() has been called; the stored pointer is dereferenced
	///      without a null check.
	const next_layer_type& next_layer() const { return *m_next_layer; }

	/// Mutable overload (same precondition), mirroring
	/// Layer_Input_Base::prev_layer().
	next_layer_type& next_layer() { return *m_next_layer; }

	virtual ~Layer_Output_Base() = default;
};
37 
38 template<class InputTensorDescriptor>
39 class Layer_Input_Base
40 {
41  using input_value_type = typename InputTensorDescriptor::value_type;
42  using input_system_tag = typename InputTensorDescriptor::system_tag;
43  using input_allocator_type = typename InputTensorDescriptor::allocator_type;
44  using input_tensor_dim = typename InputTensorDescriptor::tensor_dim;
45  using input_shape_type = bc::Dim<input_tensor_dim::value>;
46  using input_tensor_type = typename InputTensorDescriptor::type;
47  using batched_input_tensor_type = typename InputTensorDescriptor::batched_type;
48  using prev_layer_type = Layer_Output_Base<InputTensorDescriptor>;
49 private:
51 
52 protected:
53  prev_layer_type* m_prev_layer;
54  input_shape_type m_input_shape;
55 
56 public:
57  void set_prev(prev_layer_type& prev) { m_prev_layer = &prev; }
58  input_shape_type input_shape() const { return m_input_shape; }
59  const prev_layer_type& prev_layer() const { return *m_prev_layer; }
60  prev_layer_type& prev_layer() { return *m_prev_layer; }
61  virtual ~Layer_Input_Base() {}
62 };
63 
64 }
65 }
66 
67 #endif
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::tensor_dim output_tensor_dim
Definition: layer_base.h:32
const next_layer_type & next_layer() const
Definition: layer_descriptor.h:33
virtual ~Layer_Input_Base()
Definition: layer_descriptor.h:61
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::type output_tensor_type
Definition: layer_base.h:34
next_layer_type & next_layer()
Definition: layer_descriptor.h:34
prev_layer_type & prev_layer()
Definition: layer_descriptor.h:60
void set_next(next_layer_type &next)
Definition: layer_descriptor.h:31
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::allocator_type output_allocator_type
Definition: layer_base.h:31
input_shape_type input_shape() const
Definition: layer_descriptor.h:58
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::system_tag output_system_tag
Definition: layer_base.h:30
output_shape_type output_shape() const
Definition: layer_descriptor.h:32
virtual ~Layer_Output_Base()
Definition: layer_descriptor.h:35
next_layer_type * m_next_layer
Definition: layer_base.h:42
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::value_type output_value_type
Definition: layer_base.h:29
const prev_layer_type & prev_layer() const
Definition: layer_descriptor.h:59
void set_prev(prev_layer_type &prev)
Definition: layer_descriptor.h:57
bc::Dim< output_tensor_dim::value > m_output_shape
Definition: layer_base.h:43
typename Tensor_Descriptor< ValueType, SystemTag, Integer< 3 > > ::batched_type batched_output_tensor_type
Definition: layer_base.h:35
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22