1 | #pragma once |
2 | |
#include "taichi/rhi/device.h"
#include <assert.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <functional>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <new>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
9 | |
10 | namespace taichi::lang { |
11 | |
// Constructs within `rhi_impl` are for implementing RHI.
// No public-facing API should use anything within the `rhi_impl` namespace.
14 | namespace rhi_impl { |
15 | |
// Variadic no-op: accepts any number of arguments (by value) and does
// nothing with them. RHI_DEBUG_SNPRINTF is mapped onto this function when
// RHI_DEBUG is not defined, so such call sites still compile.
template <typename... Ts>
void disabled_function(Ts...) {
}
19 | |
// RHI_LOG_ERROR(msg): report an RHI error message.
//
// When both SPDLOG_H and TI_WARN are already defined at this point of the
// translation unit, the message is routed through TI_WARN; otherwise it is
// written to std::cerr followed by std::endl.
//
// NOTE: the std::cerr fallback expands to a complete statement (it carries
// its own trailing `;`), whereas the TI_WARN branch adds no `;` of its own.
// The `;` is kept so existing call sites that omit it continue to compile.
#if defined(SPDLOG_H) && defined(TI_WARN)
#define RHI_LOG_ERROR(msg) TI_WARN("RHI Error : {}", msg)
#else
// `msg` is parenthesized so that arguments containing operators with lower
// precedence than `<<` (e.g. `cond ? "a" : "b"`) are streamed as one value.
#define RHI_LOG_ERROR(msg) std::cerr << "RHI Error: " << (msg) << std::endl;
#endif
25 | |
26 | #define RHI_DEBUG |
27 | #define RHI_USE_TI_LOGGING |
28 | |
29 | #ifdef RHI_DEBUG |
30 | #define RHI_DEBUG_SNPRINTF std::snprintf |
31 | #ifdef RHI_USE_TI_LOGGING |
32 | #include "taichi/common/logging.h" |
33 | #define RHI_LOG_DEBUG(msg) TI_TRACE("RHI Debug : {}", msg) |
34 | #else |
35 | #define RHI_LOG_DEBUG(msg) std::cout << "RHI Debug: " << msg << std::endl; |
36 | #endif |
37 | #else |
38 | #define RHI_DEBUG_SNPRINTF taichi::lang::rhi_impl::disabled_function |
39 | #define RHI_LOG_DEBUG(msg) |
40 | #endif |
41 | |
// RHI_ASSERT(cond): forwards to the standard `assert`. The macro carries its
// own trailing `;`, so it always expands to a complete statement.
#define RHI_ASSERT(cond) assert(cond);

// RHI_THROW_UNLESS(cond, exception): throws `exception` when `cond` is false.
//
// The explicit `else (void)0;` closes the `if` inside the macro, so an `else`
// written after the macro at the call site can no longer attach to the
// macro's own `if` (dangling-else hazard). The expansion still ends in a
// complete statement, so call sites both with and without a trailing `;`
// keep compiling.
#define RHI_THROW_UNLESS(cond, exception) \
  if (!(cond))                            \
    throw(exception);                     \
  else                                    \
    (void)0;
46 | |
// Saturating addition for unsigned integers: returns `a + b`, or the largest
// value representable in `T` when the sum would wrap around.
template <typename T>
constexpr auto saturate_uadd(T a, T b) {
  static_assert(std::is_unsigned<T>::value);
  // Unsigned addition wraps on overflow; a wrapped sum is smaller than `a`.
  const T sum = static_cast<T>(a + b);
  return sum < a ? std::numeric_limits<T>::max() : sum;
}
56 | |
// Saturating subtraction for unsigned integers: returns `x - y`, or zero
// when `y` is larger than `x`.
template <typename T>
constexpr auto saturate_usub(T x, T y) {
  static_assert(std::is_unsigned<T>::value);
  // Unsigned subtraction wraps on underflow; a wrapped difference is larger
  // than `x`, so only a difference not exceeding `x` is kept.
  const T diff = static_cast<T>(x - y);
  return diff <= x ? diff : static_cast<T>(0);
}
65 | |
66 | // Wrapped return-code & object tuple for simplicity |
67 | // Easier to read then std::pair |
68 | // NOTE: If an internal function can fail, wrap return object with this! |
69 | template <typename T> |
70 | struct RhiReturn { |
71 | [[nodiscard]] RhiResult result; |
72 | [[nodiscard]] T object; |
73 | |
74 | RhiReturn(RhiResult &result, T &object) : result(result), object(object) { |
75 | } |
76 | |
77 | RhiReturn(const RhiResult &result, const T &object) |
78 | : result(result), object(object) { |
79 | } |
80 | |
81 | RhiReturn(RhiResult &&result, T &&object) |
82 | : result(result), object(std::move(object)) { |
83 | } |
84 | |
85 | RhiReturn &operator=(const RhiReturn &other) = default; |
86 | }; |
87 | |
// Bi-directional map, useful for mapping between RHI enums and backend enums.
//
// Both directions are populated once from an initializer list of
// (RhiType, BackendType) pairs. If a key occurs more than once in either
// direction, the first occurrence in the list is the one that is kept.
template <typename RhiType, typename BackendType>
struct BidirMap {
  std::unordered_map<RhiType, BackendType> rhi2backend;
  std::unordered_map<BackendType, RhiType> backend2rhi;

  BidirMap(std::initializer_list<std::pair<RhiType, BackendType>> init_list) {
    rhi2backend.reserve(init_list.size());
    backend2rhi.reserve(init_list.size());
    for (const auto &entry : init_list) {
      rhi2backend.emplace(entry.first, entry.second);
      backend2rhi.emplace(entry.second, entry.first);
    }
  }

  // Lookups take their key by const reference so that const values and
  // temporaries can be queried as well as mutable lvalues.

  // Returns true if `v` has a backend counterpart.
  bool exists(const RhiType &v) const {
    return rhi2backend.find(v) != rhi2backend.cend();
  }

  // Returns the backend value mapped to `v`; throws std::out_of_range if
  // there is none.
  BackendType at(const RhiType &v) const {
    return rhi2backend.at(v);
  }

  // Returns true if `v` has an RHI counterpart.
  bool exists(const BackendType &v) const {
    return backend2rhi.find(v) != backend2rhi.cend();
  }

  // Returns the RHI value mapped to `v`; throws std::out_of_range if there
  // is none.
  RhiType at(const BackendType &v) const {
    return backend2rhi.at(v);
  }
};
117 | |
// A synchronized list of objects that is pointer stable & reuses objects.
//
// Objects are constructed in place inside nodes of a std::forward_list, so
// their addresses never change. Released slots are remembered in a free list
// and handed out again by later `acquire` calls. Every public operation takes
// the internal mutex.
template <class T>
class SyncedPtrStableObjectList {
  // Raw storage for exactly one T. `alignas(T)` guarantees that an object
  // placement-constructed in `bytes` satisfies T's alignment requirement; a
  // plain byte array only has alignment 1.
  struct alignas(T) storage_block {
    unsigned char bytes[sizeof(T)];

    void *data() noexcept {
      return bytes;
    }
  };

 public:
  SyncedPtrStableObjectList() = default;
  // Not copyable: the list owns live objects and a mutex.
  SyncedPtrStableObjectList(const SyncedPtrStableObjectList &) = delete;
  SyncedPtrStableObjectList &operator=(const SyncedPtrStableObjectList &) =
      delete;

  // Constructs a T from `args` in a free slot (reusing a released one when
  // available) and returns a reference to it.
  template <typename... Params>
  T &acquire(Params &&...args) {
    std::lock_guard<std::mutex> guard(lock_);

    void *storage = nullptr;
    if (free_nodes_.empty()) {
      storage = objects_.emplace_front().data();
    } else {
      storage = free_nodes_.back();
      free_nodes_.pop_back();
    }
    try {
      return *::new (storage) T(std::forward<Params>(args)...);
    } catch (...) {
      // Construction failed: the slot holds no object, so record it as free.
      // Otherwise `clear()` would later run a destructor on it.
      free_nodes_.push_back(storage);
      throw;
    }
  }

  // Destroys the object at `ptr` and makes its slot available for reuse.
  // `ptr` must have been obtained from `acquire` on this list; a null
  // pointer is ignored.
  void release(T *ptr) {
    if (ptr == nullptr) {
      return;
    }
    std::lock_guard<std::mutex> guard(lock_);

    // Record the slot first so that a failed allocation in push_back leaves
    // the object alive and the list consistent.
    free_nodes_.push_back(ptr);
    ptr->~T();
  }

  // Destroys every live object and frees all storage.
  void clear() {
    std::lock_guard<std::mutex> guard(lock_);

    if constexpr (!std::is_trivially_destructible_v<T>) {
      // Transfer to a set for quick look-up of released slots.
      const std::unordered_set<void *> free_nodes_set(free_nodes_.begin(),
                                                      free_nodes_.end());
      for (auto &block : objects_) {
        void *raw = block.data();
        // A slot that is not in the free list holds a live object.
        if (free_nodes_set.find(raw) == free_nodes_set.end()) {
          std::launder(static_cast<T *>(raw))->~T();
        }
      }
    }
    free_nodes_.clear();
    // Clear the storage
    objects_.clear();
  }

  ~SyncedPtrStableObjectList() {
    clear();
  }

 private:
  std::mutex lock_;
  std::forward_list<storage_block> objects_;
  std::vector<void *> free_nodes_;
};
173 | |
// Folds the std::hash of `v` into `seed`, updating `seed` in place.
template <class T>
inline void hash_combine(std::size_t &seed, const T &v) {
  const auto value_hash = std::hash<T>{}(v);
  // Mix the value hash with a constant and two shifted copies of the current
  // seed before folding it in.
  const auto mixed = value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed ^= mixed;
}
180 | |
181 | } // namespace rhi_impl |
182 | } // namespace taichi::lang |
183 | |