BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
device.h
Go to the documentation of this file.
1 /*
2  * Device.h
3  *
4  * Created on: Jan 24, 2019
5  * Author: joseph
6  */
7 
8 #ifdef __CUDACC__
9 #ifndef BC_CONTEXT_DEVICE_H_
10 #define BC_CONTEXT_DEVICE_H_
11 
12 #include "stream_synchronization.h"
13 #include "host_stream.h"
14 
15 #include <cuda.h>
16 #include <cuda_runtime.h>
17 #include <cublas_v2.h>
18 #include <cublas.h>
19 
20 #include <memory>
21 #include <future>
22 
23 namespace bc {
24 namespace streams {
25 
26 template<class SystemTag>
27 class Stream;
28 
29 template<>
31 
32  struct Device_Stream_Contents {
33 
34  HostStream m_host_stream;
35  cublasHandle_t m_cublas_handle;
36  cudaStream_t m_stream_handle=nullptr;
37  cudaEvent_t m_event_handle=nullptr;
38 
40 
41  Device_Stream_Contents(bool init_stream=true) {
42  BC_CUDA_ASSERT(cublasCreate(&m_cublas_handle));
43  BC_CUDA_ASSERT(cudaEventCreate(&m_event_handle));
44  BC_CUDA_ASSERT((cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE)));
45 
46  if (init_stream) {
47  BC_CUDA_ASSERT(cudaStreamCreate(&m_stream_handle));
48  BC_CUDA_ASSERT(cublasSetStream(m_cublas_handle, m_stream_handle));
49  }
50  }
51 
52  ~Device_Stream_Contents() {
53  BC_CUDA_ASSERT(cublasDestroy(m_cublas_handle));
54  BC_CUDA_ASSERT(cudaEventDestroy(m_event_handle));
55 
56  if (m_stream_handle)
57  BC_CUDA_ASSERT(cudaStreamDestroy(m_stream_handle));
58  }
59  };
60 
61  using contents_handle_t = std::shared_ptr<Device_Stream_Contents>;
62 
63  static contents_handle_t get_default_contents() {
64  thread_local contents_handle_t default_contents =
65  contents_handle_t(new Device_Stream_Contents(false));
66  return default_contents;
67  }
68 
69  contents_handle_t m_contents = get_default_contents();
70 
71 public:
72 
75 
77  return m_contents->m_workspace;
78  }
79 
80  template<class RebindType>
82  return typename allocator_type::template rebind<RebindType>::other(get_allocator());
83  }
84 
86  cublasSetPointerMode(m_contents->m_cublas_handle, CUBLAS_POINTER_MODE_HOST);
87  }
89  cublasSetPointerMode(m_contents->m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
90  }
91 
92  cublasHandle_t get_cublas_handle() const {
93  return m_contents->m_cublas_handle;
94  }
95 
96  operator cudaStream_t() const {
97  return m_contents->m_stream_handle;
98  }
99 
100  void set_stream(Stream dev) {
101  m_contents = dev.m_contents;
102  }
103 
104  void record_event() {
105  BC_CUDA_ASSERT(cudaEventRecord(
106  m_contents->m_event_handle,
107  m_contents->m_stream_handle));
108  }
109 
110  void wait_event(Stream& stream) {
111  BC_CUDA_ASSERT(cudaStreamWaitEvent(
112  m_contents->m_stream_handle,
113  stream.m_contents->m_event_handle, 0));
114  }
115 
116  void wait_stream(Stream& stream) {
117  stream.record_event();
118  BC_CUDA_ASSERT(cudaStreamWaitEvent(
119  m_contents->m_stream_handle,
120  stream.m_contents->m_event_handle, 0));
121  }
122 
123  void wait_event(cudaEvent_t event) {
124  BC_CUDA_ASSERT(cudaStreamWaitEvent(
125  m_contents->m_stream_handle,
126  event, 0));
127  }
128 
129  bool is_default() {
130  return m_contents->m_stream_handle == 0;
131  }
132 
133  void create() {
134  m_contents =
135  contents_handle_t(new Device_Stream_Contents(true));
136  }
137 
138  void destroy() {
139  m_contents = get_default_contents();
140  }
141 
142  void sync()
143  {
144  if (!is_default()) {
146  cudaStreamSynchronize(m_contents->m_stream_handle));
147  }
148  }
149 
/**
 * Invokes `f` immediately on the calling thread; any value it returns is
 * discarded.
 */
template<class Function>
void enqueue(Function f) {
    (void)f();
}
154 
155  template<
156  class function,
157  class=std::enable_if_t<
158  std::is_void<
159  decltype(std::declval<function>()())>::value>>
160  void enqueue_callback(function func)
161  {
162  if (is_default()) {
163  func();
164  return;
165  }
166 
167  BC_CUDA_ASSERT(cudaEventRecord(
168  m_contents->m_event_handle,
169  m_contents->m_stream_handle));
170 
171  m_contents->m_host_stream.push(
172  [&, func]() {
173  cudaEventSynchronize(m_contents->m_event_handle);
174  func();
175  }
176  );
177  }
178 
179  template<
180  class function,
181  class=std::enable_if_t<
182  !std::is_void<
183  decltype(std::declval<function>()())>::value>,
184  int ADL=0>
185  auto enqueue_callback(function func)
186  {
187  std::promise<decltype(func())> promise;
188 
189  if (is_default()){
190  promise.set_value(func());
191  return promise.get_future();
192  }
193 
194  auto future = promise.get_future();
195  BC_CUDA_ASSERT(cudaEventRecord(
196  m_contents->m_event_handle,
197  m_contents->m_stream_handle));
198 
199  m_contents->m_host_stream.push(
201  [this, func](std::promise<decltype(func())> promise) {
202  cudaEventSynchronize(this->m_contents->m_event_handle);
203  promise.set_value(func());
204  }, std::move(promise)));
205 
206  return future;
207  }
208 
209  bool operator == (const Stream& dev) {
210  return m_contents == dev.m_contents;
211  }
212 
213  bool operator != (const Stream& dev) {
214  return m_contents != dev.m_contents;
215  }
216 };
217 
218 
219 }
220 }
221 
222 
223 #endif /* BC_CONTEXT_DEVICE_H_ */
224 #endif
allocator_type & get_allocator()
Definition: device.h:76
auto enqueue_callback(function func)
Definition: device.h:185
Bind< Function, Args &&... > bind(Function function, Args &&... args)
Definition: bind.h:105
void destroy()
Definition: device.h:138
void sync()
Definition: device.h:142
void create()
Definition: device.h:133
void enqueue(Function f)
Definition: device.h:151
void wait_event(cudaEvent_t event)
Definition: device.h:123
void wait_event(Stream &stream)
Definition: device.h:110
bool is_default()
Definition: device.h:129
void record_event()
Definition: device.h:104
Definition: common.h:32
cublasHandle_t get_cublas_handle() const
Definition: device.h:92
auto set_blas_pointer_mode_host()
Definition: device.h:85
auto get_allocator_rebound()
Definition: device.h:81
Definition: host_stream.h:56
#define BC_CUDA_ASSERT(...)
Definition: common.h:194
void wait_stream(Stream &stream)
Definition: device.h:116
An unsynced memory pool implemented as a stack.
Definition: stack_allocator.h:138
auto set_blas_pointer_mode_device()
Definition: device.h:88
void enqueue_callback(function func)
Definition: device.h:160
void set_stream(Stream dev)
Definition: device.h:100
auto operator==(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:38
Definition: device.h:27
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22