9 #ifndef BC_CONTEXT_DEVICE_H_ 10 #define BC_CONTEXT_DEVICE_H_ 16 #include <cuda_runtime.h> 17 #include <cublas_v2.h> 26 template<
class SystemTag>
32 struct Device_Stream_Contents {
35 cublasHandle_t m_cublas_handle;
36 cudaStream_t m_stream_handle=
nullptr;
37 cudaEvent_t m_event_handle=
nullptr;
41 Device_Stream_Contents(
bool init_stream=
true) {
44 BC_CUDA_ASSERT((cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE)));
52 ~Device_Stream_Contents() {
61 using contents_handle_t = std::shared_ptr<Device_Stream_Contents>;
63 static contents_handle_t get_default_contents() {
64 thread_local contents_handle_t default_contents =
65 contents_handle_t(
new Device_Stream_Contents(
false));
66 return default_contents;
69 contents_handle_t m_contents = get_default_contents();
77 return m_contents->m_workspace;
80 template<
class RebindType>
82 return typename allocator_type::template rebind<RebindType>::other(get_allocator());
86 cublasSetPointerMode(m_contents->m_cublas_handle, CUBLAS_POINTER_MODE_HOST);
89 cublasSetPointerMode(m_contents->m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
93 return m_contents->m_cublas_handle;
96 operator cudaStream_t()
const {
97 return m_contents->m_stream_handle;
101 m_contents = dev.m_contents;
106 m_contents->m_event_handle,
107 m_contents->m_stream_handle));
112 m_contents->m_stream_handle,
113 stream.m_contents->m_event_handle, 0));
117 stream.record_event();
119 m_contents->m_stream_handle,
120 stream.m_contents->m_event_handle, 0));
125 m_contents->m_stream_handle,
130 return m_contents->m_stream_handle == 0;
135 contents_handle_t(
new Device_Stream_Contents(
true));
139 m_contents = get_default_contents();
146 cudaStreamSynchronize(m_contents->m_stream_handle));
150 template<
class Function>
157 class=std::enable_if_t<
159 decltype(std::declval<function>()())>::value>>
168 m_contents->m_event_handle,
169 m_contents->m_stream_handle));
171 m_contents->m_host_stream.push(
173 cudaEventSynchronize(m_contents->m_event_handle);
181 class=std::enable_if_t<
183 decltype(std::declval<function>()())>::value>,
187 std::promise<decltype(func())> promise;
190 promise.set_value(func());
191 return promise.get_future();
194 auto future = promise.get_future();
196 m_contents->m_event_handle,
197 m_contents->m_stream_handle));
199 m_contents->m_host_stream.push(
201 [
this, func](std::promise<decltype(func())> promise) {
202 cudaEventSynchronize(this->m_contents->m_event_handle);
203 promise.set_value(func());
204 }, std::move(promise)));
210 return m_contents == dev.m_contents;
214 return m_contents != dev.m_contents;
allocator_type & get_allocator()
Definition: device.h:76
auto enqueue_callback(function func)
Definition: device.h:185
Bind< Function, Args &&... > bind(Function function, Args &&... args)
Definition: bind.h:105
void destroy()
Definition: device.h:138
void sync()
Definition: device.h:142
void create()
Definition: device.h:133
void enqueue(Function f)
Definition: device.h:151
void wait_event(cudaEvent_t event)
Definition: device.h:123
void wait_event(Stream &stream)
Definition: device.h:110
bool is_default()
Definition: device.h:129
void record_event()
Definition: device.h:104
cublasHandle_t get_cublas_handle() const
Definition: device.h:92
auto set_blas_pointer_mode_host()
Definition: device.h:85
auto get_allocator_rebound()
Definition: device.h:81
Definition: host_stream.h:56
#define BC_CUDA_ASSERT(...)
Definition: common.h:194
void wait_stream(Stream &stream)
Definition: device.h:116
An unsynced memory pool implemented as a stack.
Definition: stack_allocator.h:138
auto set_blas_pointer_mode_device()
Definition: device.h:88
void enqueue_callback(function func)
Definition: device.h:160
void set_stream(Stream dev)
Definition: device.h:100
auto operator==(const Expression_Base< Xpr > &param) const
Definition: expression_operations.h:38
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22