1#pragma once
2
#include <cstddef>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "taichi/program/kernel_profiler.h"
#include "taichi/rhi/cuda/cuda_driver.h"
9
10namespace taichi::lang {
11
12// Note:
13// It would be ideal to create a CUDA context per Taichi program, yet CUDA
14// context creation takes time. Therefore we use a shared context to accelerate
15// cases such as unit testing where many Taichi programs are created/destroyed.
16
17class CUDADriver;
18
19class CUDAContext {
20 private:
21 void *device_;
22 void *context_;
23 int dev_count_;
24 int compute_capability_;
25 std::string mcpu_;
26 std::mutex lock_;
27 KernelProfilerBase *profiler_;
28 CUDADriver &driver_;
29 bool debug_;
30
31 public:
32 CUDAContext();
33
34 std::size_t get_total_memory();
35 std::size_t get_free_memory();
36 std::string get_device_name();
37
38 bool detected() const {
39 return dev_count_ != 0;
40 }
41
42 void launch(void *func,
43 const std::string &task_name,
44 std::vector<void *> arg_pointers,
45 std::vector<int> arg_sizes,
46 unsigned grid_dim,
47 unsigned block_dim,
48 std::size_t dynamic_shared_mem_bytes);
49
50 void set_profiler(KernelProfilerBase *profiler) {
51 profiler_ = profiler;
52 }
53
54 void set_debug(bool debug) {
55 debug_ = debug;
56 }
57
58 std::string get_mcpu() const {
59 return mcpu_;
60 }
61
62 void *get_context() {
63 return context_;
64 }
65
66 void make_current() {
67 driver_.context_set_current(context_);
68 }
69
70 int get_compute_capability() const {
71 return compute_capability_;
72 }
73
74 ~CUDAContext();
75
76 class ContextGuard {
77 private:
78 void *old_ctx_;
79 void *new_ctx_;
80
81 public:
82 explicit ContextGuard(CUDAContext *new_ctx)
83 : old_ctx_(nullptr), new_ctx_(new_ctx->context_) {
84 CUDADriver::get_instance().context_get_current(&old_ctx_);
85 if (old_ctx_ != new_ctx_)
86 new_ctx->make_current();
87 }
88
89 ~ContextGuard() {
90 if (old_ctx_ != new_ctx_) {
91 CUDADriver::get_instance().context_set_current(old_ctx_);
92 }
93 }
94 };
95
96 ContextGuard get_guard() {
97 return ContextGuard(this);
98 }
99
100 std::unique_lock<std::mutex> get_lock_guard() {
101 return std::unique_lock<std::mutex>(lock_);
102 }
103
104 static CUDAContext &get_instance();
105};
106
107} // namespace taichi::lang
108