1#pragma once
2
#include "taichi/rhi/device.h"
#include "taichi/common/logging.h"

#include <assert.h>
#include <array>
#include <cstdint>
#include <cstdio>
#include <forward_list>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <mutex>
#include <new>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
9
10namespace taichi::lang {
11
// Constructs within `rhi_impl` are for implementing the RHI.
// No public-facing API should use anything within the `rhi_impl` namespace.
14namespace rhi_impl {
15
// No-op stand-in for a function whose feature is compiled out (for example,
// RHI_DEBUG_SNPRINTF maps to this when RHI_DEBUG is not defined). Accepts any
// number of arguments by value and ignores them.
template <typename... Ts>
void disabled_function(Ts... /* ignored */) {
}
19
// Error logging. When Taichi's spdlog-based logger is visible (SPDLOG_H and
// TI_WARN are defined) errors are routed through TI_WARN; otherwise they are
// written to std::cerr. The TI_WARN form carries no semicolon of its own, so
// the std::cerr fallback is wrapped in do { } while (0): call sites supply the
// trailing semicolon in both configurations, and the macro stays a single
// statement inside an unbraced if/else.
#if defined(SPDLOG_H) && defined(TI_WARN)
#define RHI_LOG_ERROR(msg) TI_WARN("RHI Error : {}", msg)
#else
#define RHI_LOG_ERROR(msg)                          \
  do {                                              \
    std::cerr << "RHI Error: " << msg << std::endl; \
  } while (0)
#endif

// Debug switches. Both are hard-coded on here, so debug formatting and
// Taichi-logger-based debug output are always compiled in.
#define RHI_DEBUG
#define RHI_USE_TI_LOGGING

#ifdef RHI_DEBUG
#define RHI_DEBUG_SNPRINTF std::snprintf
#ifdef RHI_USE_TI_LOGGING
// TI_TRACE is expected to come from "taichi/common/logging.h", which is
// included at the top of this file. It must not be #included here: this
// point is inside namespace taichi::lang::rhi_impl, and the header's
// declarations would be nested in that namespace.
#define RHI_LOG_DEBUG(msg) TI_TRACE("RHI Debug : {}", msg)
#else
#define RHI_LOG_DEBUG(msg)                          \
  do {                                              \
    std::cout << "RHI Debug: " << msg << std::endl; \
  } while (0)
#endif
#else
// Debug helpers compile to no-ops.
#define RHI_DEBUG_SNPRINTF taichi::lang::rhi_impl::disabled_function
#define RHI_LOG_DEBUG(msg)
#endif

// Debug assertion via <assert.h> (compiled out under NDEBUG). The expansion
// carries its own semicolon.
#define RHI_ASSERT(cond) assert(cond);

// Throws `exception` when `cond` is false. The do { } while (0) wrapper keeps
// a following `else` from silently binding to the macro's internal `if`. The
// trailing semicolon is kept so that any call sites written without one still
// compile.
#define RHI_THROW_UNLESS(cond, exception) \
  do {                                    \
    if (!(cond))                          \
      throw(exception);                   \
  } while (0);
46
// Adds two unsigned integers, clamping the result to
// std::numeric_limits<T>::max() instead of wrapping around on overflow.
template <typename T>
constexpr auto saturate_uadd(T a, T b) {
  static_assert(std::is_unsigned<T>::value);
  // Unsigned arithmetic wraps modulo 2^N, so a sum that overflowed is always
  // smaller than either operand.
  const T wrapped = static_cast<T>(a + b);
  if (wrapped >= a) {
    return wrapped;
  }
  return std::numeric_limits<T>::max();
}
56
// Computes x - y for unsigned integers, clamping to 0 instead of wrapping
// around when y > x.
template <typename T>
constexpr auto saturate_usub(T x, T y) {
  static_assert(std::is_unsigned<T>::value);
  if (y > x) {
    // The subtraction would underflow.
    return static_cast<T>(0);
  }
  return static_cast<T>(x - y);
}
65
66// Wrapped return-code & object tuple for simplicity
67// Easier to read then std::pair
68// NOTE: If an internal function can fail, wrap return object with this!
69template <typename T>
70struct RhiReturn {
71 [[nodiscard]] RhiResult result;
72 [[nodiscard]] T object;
73
74 RhiReturn(RhiResult &result, T &object) : result(result), object(object) {
75 }
76
77 RhiReturn(const RhiResult &result, const T &object)
78 : result(result), object(object) {
79 }
80
81 RhiReturn(RhiResult &&result, T &&object)
82 : result(result), object(std::move(object)) {
83 }
84
85 RhiReturn &operator=(const RhiReturn &other) = default;
86};
87
// Bi-directional map, useful for mapping between RHI enums and backend enums.
// Every pair passed to the constructor is recorded in both directions. If a
// value appears more than once on the same side, the first pair wins (later
// duplicates do not overwrite existing keys).
// RhiType and BackendType must be different types; otherwise the exists() /
// at() overloads below would have identical signatures.
template <typename RhiType, typename BackendType>
struct BidirMap {
  std::unordered_map<RhiType, BackendType> rhi2backend;
  std::unordered_map<BackendType, RhiType> backend2rhi;

  BidirMap(std::initializer_list<std::pair<RhiType, BackendType>> init_list) {
    rhi2backend.reserve(init_list.size());
    backend2rhi.reserve(init_list.size());
    for (const auto &[rhi_value, backend_value] : init_list) {
      rhi2backend.try_emplace(rhi_value, backend_value);
      backend2rhi.try_emplace(backend_value, rhi_value);
    }
  }

  // Lookups take const references so that constants and temporaries (such as
  // enum literals) can be passed directly.

  // Returns true if `v` has a backend counterpart.
  bool exists(const RhiType &v) const {
    return rhi2backend.find(v) != rhi2backend.cend();
  }

  // Returns the backend counterpart of `v`; throws std::out_of_range if absent.
  BackendType at(const RhiType &v) const {
    return rhi2backend.at(v);
  }

  // Returns true if `v` has an RHI counterpart.
  bool exists(const BackendType &v) const {
    return backend2rhi.find(v) != backend2rhi.cend();
  }

  // Returns the RHI counterpart of `v`; throws std::out_of_range if absent.
  RhiType at(const BackendType &v) const {
    return backend2rhi.at(v);
  }
};
117
// A synchronized list of objects that is pointer stable & reuses objects.
//
// acquire() constructs a T inside a storage block owned by a
// std::forward_list, so the object's address does not change until it is
// release()d or the list is clear()ed/destroyed. Released storage is kept on a
// free list and reused by later acquire() calls. Every public member function
// holds an internal mutex while it runs, including while T's constructor or
// destructor executes, so those must not call back into the same list.
template <class T>
class SyncedPtrStableObjectList {
  // Uninitialized storage for exactly one T. `alignas(T)` is required: a plain
  // byte array only guarantees 1-byte alignment, which is not enough for
  // placement-new of an arbitrary T.
  struct alignas(T) storage_block {
    unsigned char bytes[sizeof(T)];
  };

 public:
  SyncedPtrStableObjectList() = default;
  SyncedPtrStableObjectList(const SyncedPtrStableObjectList &) = delete;
  SyncedPtrStableObjectList &operator=(const SyncedPtrStableObjectList &) =
      delete;

  // Constructs a T from `args` in reused or newly allocated storage and
  // returns a reference to it. If T's constructor throws, the storage is
  // handed back (to the free list, or deallocated if it was new) and the
  // exception propagates.
  template <typename... Params>
  T &acquire(Params &&...args) {
    std::lock_guard<std::mutex> _(lock_);

    if (free_nodes_.empty()) {
      storage_block &block = objects_.emplace_front();
      try {
        return *new (block.bytes) T(std::forward<Params>(args)...);
      } catch (...) {
        // The new block is still at the front (the lock is held). Drop it so
        // clear() never mistakes the unconstructed storage for a live object.
        objects_.pop_front();
        throw;
      }
    }

    void *storage = free_nodes_.back();
    free_nodes_.pop_back();
    try {
      return *new (storage) T(std::forward<Params>(args)...);
    } catch (...) {
      // pop_back() did not shrink the capacity, so this cannot reallocate.
      free_nodes_.push_back(storage);
      throw;
    }
  }

  // Destroys the object at `ptr` and makes its storage available for reuse.
  // `ptr` must come from acquire() on this list and must not have been
  // released already. Releasing nullptr is a no-op.
  void release(T *ptr) {
    if (ptr == nullptr) {
      return;
    }
    std::lock_guard<std::mutex> _(lock_);
    // Record the slot first: if push_back throws, the object is still live
    // and the bookkeeping stays consistent.
    free_nodes_.push_back(ptr);
    ptr->~T();
  }

  // Destroys every live object and deallocates all storage. References
  // previously returned by acquire() become dangling.
  void clear() {
    std::lock_guard<std::mutex> _(lock_);

    // Transfer the free list to a set for O(1) membership tests.
    const std::unordered_set<void *> free_set(free_nodes_.begin(),
                                              free_nodes_.end());
    free_nodes_.clear();
    for (storage_block &block : objects_) {
      void *storage = block.bytes;
      // Storage that is not on the free list holds a live object.
      if (free_set.find(storage) == free_set.end()) {
        std::launder(static_cast<T *>(storage))->~T();
      }
    }
    objects_.clear();
  }

  ~SyncedPtrStableObjectList() {
    clear();
  }

 private:
  std::mutex lock_;
  std::forward_list<storage_block> objects_;
  std::vector<void *> free_nodes_;
};
173
// Folds the std::hash of `v` into `seed` using the classic boost::hash_combine
// mixing formula; call repeatedly to combine several values into one hash.
template <class T>
inline void hash_combine(std::size_t &seed, const T &v) {
  const std::size_t value_hash = std::hash<T>{}(v);
  const std::size_t mixed =
      value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed ^= mixed;
}
180
181} // namespace rhi_impl
182} // namespace taichi::lang
183