1 | #pragma once |
---|---|
2 | |
3 | #include <mutex> |
4 | #include <unordered_map> |
5 | #include <thread> |
6 | |
7 | #include "taichi/program/kernel_profiler.h" |
8 | #include "taichi/rhi/cuda/cuda_driver.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | // Note: |
13 | // It would be ideal to create a CUDA context per Taichi program, yet CUDA |
14 | // context creation takes time. Therefore we use a shared context to accelerate |
15 | // cases such as unit testing where many Taichi programs are created/destroyed. |
16 | |
17 | class CUDADriver; |
18 | |
19 | class CUDAContext { |
20 | private: |
21 | void *device_; |
22 | void *context_; |
23 | int dev_count_; |
24 | int compute_capability_; |
25 | std::string mcpu_; |
26 | std::mutex lock_; |
27 | KernelProfilerBase *profiler_; |
28 | CUDADriver &driver_; |
29 | bool debug_; |
30 | |
31 | public: |
32 | CUDAContext(); |
33 | |
34 | std::size_t get_total_memory(); |
35 | std::size_t get_free_memory(); |
36 | std::string get_device_name(); |
37 | |
38 | bool detected() const { |
39 | return dev_count_ != 0; |
40 | } |
41 | |
42 | void launch(void *func, |
43 | const std::string &task_name, |
44 | std::vector<void *> arg_pointers, |
45 | std::vector<int> arg_sizes, |
46 | unsigned grid_dim, |
47 | unsigned block_dim, |
48 | std::size_t dynamic_shared_mem_bytes); |
49 | |
50 | void set_profiler(KernelProfilerBase *profiler) { |
51 | profiler_ = profiler; |
52 | } |
53 | |
54 | void set_debug(bool debug) { |
55 | debug_ = debug; |
56 | } |
57 | |
58 | std::string get_mcpu() const { |
59 | return mcpu_; |
60 | } |
61 | |
62 | void *get_context() { |
63 | return context_; |
64 | } |
65 | |
66 | void make_current() { |
67 | driver_.context_set_current(context_); |
68 | } |
69 | |
70 | int get_compute_capability() const { |
71 | return compute_capability_; |
72 | } |
73 | |
74 | ~CUDAContext(); |
75 | |
76 | class ContextGuard { |
77 | private: |
78 | void *old_ctx_; |
79 | void *new_ctx_; |
80 | |
81 | public: |
82 | explicit ContextGuard(CUDAContext *new_ctx) |
83 | : old_ctx_(nullptr), new_ctx_(new_ctx->context_) { |
84 | CUDADriver::get_instance().context_get_current(&old_ctx_); |
85 | if (old_ctx_ != new_ctx_) |
86 | new_ctx->make_current(); |
87 | } |
88 | |
89 | ~ContextGuard() { |
90 | if (old_ctx_ != new_ctx_) { |
91 | CUDADriver::get_instance().context_set_current(old_ctx_); |
92 | } |
93 | } |
94 | }; |
95 | |
96 | ContextGuard get_guard() { |
97 | return ContextGuard(this); |
98 | } |
99 | |
100 | std::unique_lock<std::mutex> get_lock_guard() { |
101 | return std::unique_lock<std::mutex>(lock_); |
102 | } |
103 | |
104 | static CUDAContext &get_instance(); |
105 | }; |
106 | |
107 | } // namespace taichi::lang |
108 |