1#pragma once
2
3#include <c10/macros/Export.h>
4#include <type.h>
5
6#include <array>
7#include <bitset>
8#include <map>
9#include <unordered_map>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16constexpr int getParallelTypeBitMapOffset(ParallelType pt) {
17 switch (pt) {
18 case ParallelType::BIDx:
19 return 0;
20 case ParallelType::BIDy:
21 return 1;
22 case ParallelType::BIDz:
23 return 2;
24 case ParallelType::TIDx:
25 return 3;
26 case ParallelType::TIDy:
27 return 4;
28 case ParallelType::TIDz:
29 return 5;
30 default:
31 return -1;
32 }
33}
34
35//! Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz.
36class ParallelTypeBitmap {
37 public:
38 static constexpr int kNumParallelTypes = 6;
39
40 //! Iterator for ParallelTypeBitmap. Picks only set types.
41 class Iterator {
42 public:
43 static Iterator begin(const ParallelTypeBitmap& map);
44
45 static Iterator end(const ParallelTypeBitmap& map);
46
47 bool operator==(const Iterator& other) const;
48
49 bool operator!=(const Iterator& other) const;
50
51 Iterator& operator++();
52
53 Iterator operator++(int);
54
55 ParallelType operator*() const;
56
57 private:
58 Iterator(const ParallelTypeBitmap& map, int offset);
59
60 void skipToSetType();
61
62 private:
63 const ParallelTypeBitmap& map_;
64 int offset_ = 0;
65
66 static constexpr int kOffsetEnd = kNumParallelTypes;
67 };
68
69 ParallelTypeBitmap() = default;
70
71 explicit ParallelTypeBitmap(ParallelType pt) {
72 set(pt);
73 }
74
75 //! Return true if pt is included
76 bool get(ParallelType pt) const {
77 auto offset = getParallelTypeBitMapOffset(pt);
78 TORCH_INTERNAL_ASSERT(
79 offset != -1, "Could not recognize parallel type: ", pt);
80 return bitset_[offset];
81 }
82
83 //! Set the flag of pt
84 bool set(ParallelType pt, bool new_val = true) {
85 auto offset = getParallelTypeBitMapOffset(pt);
86 TORCH_INTERNAL_ASSERT(
87 offset != -1, "Could not recognize parallel type: ", pt);
88 bool old_val = bitset_[offset];
89 bitset_[offset] = new_val;
90 return old_val;
91 }
92
93 //! Clear the flag of pt
94 bool clear(ParallelType pt) {
95 return set(pt, false);
96 }
97
98 //! Assign logical AND with other
99 ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other) {
100 bitset_ &= other.bitset_;
101 return *this;
102 }
103
104 //! Assign logical OR with other
105 ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other) {
106 bitset_ |= other.bitset_;
107 return *this;
108 }
109
110 //! Assign logical NOR with other
111 ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other) {
112 bitset_ ^= other.bitset_;
113 return *this;
114 }
115
116 //! Return logical compliment
117 ParallelTypeBitmap operator~() const {
118 return ParallelTypeBitmap(~bitset_);
119 }
120
121 //! Return true if none of the mapppings is true
122 bool none() const {
123 return bitset_.none();
124 }
125
126 //! Return true if any of the mapppings is true
127 bool any() const {
128 return bitset_.any();
129 }
130
131 //! Return true if all of the mapppings is true
132 bool all() const {
133 return bitset_.all();
134 }
135
136 //! Return true if the parallel type corresponding to a position
137 //! defined in offset_to_pt_ is true
138 bool operator[](size_t pos) const {
139 TORCH_INTERNAL_ASSERT(
140 pos < kNumParallelTypes, "Invalid index to ParallelTypeBitset: ", pos);
141 return bitset_[pos];
142 }
143
144 //! Return true if TIDx/y/z is included
145 bool hasTID() const {
146 return (bitset_ & kTIDBits).any();
147 }
148
149 //! Return true if BIDx/y/z is included
150 bool hasBID() const {
151 return (bitset_ & kBIDBits).any();
152 }
153
154 //! Set all of the TID flags
155 void setAllTID() {
156 *this |= ParallelTypeBitmap(kTIDBits);
157 }
158
159 //! Set all of the BID flags
160 void setAllBID() {
161 *this |= ParallelTypeBitmap(kBIDBits);
162 }
163
164 //! Clear all of the TID flags
165 void clearAllTID() {
166 auto tid_bits = ParallelTypeBitmap(kTIDBits);
167 auto not_tid_bits = ~tid_bits;
168 *this &= not_tid_bits;
169 }
170
171 //! Clear all of the BID flags
172 void clearAllBID() {
173 auto bid_bits = ParallelTypeBitmap(kBIDBits);
174 auto not_bid_bits = ~bid_bits;
175 *this &= not_bid_bits;
176 }
177
178 //! Get an iterator to traverse set types
179 Iterator begin() const {
180 return Iterator::begin(*this);
181 }
182
183 //! Get an end iterator to traverse set types
184 Iterator end() const {
185 return Iterator::end(*this);
186 }
187
188 bool operator==(const ParallelTypeBitmap& other) const {
189 return bitset_ == other.bitset_;
190 }
191
192 std::string toString() const;
193
194 private:
195 explicit constexpr ParallelTypeBitmap(
196 const std::bitset<kNumParallelTypes>& bs)
197 : bitset_(bs) {}
198
199 private:
200 std::bitset<kNumParallelTypes> bitset_;
201
202 static constexpr std::bitset<ParallelTypeBitmap::kNumParallelTypes> kTIDBits{
203 (1u << getParallelTypeBitMapOffset(ParallelType::TIDx)) |
204 (1u << getParallelTypeBitMapOffset(ParallelType::TIDy)) |
205 (1u << getParallelTypeBitMapOffset(ParallelType::TIDz))};
206
207 static constexpr std::bitset<ParallelTypeBitmap::kNumParallelTypes> kBIDBits{
208 (1u << getParallelTypeBitMapOffset(ParallelType::BIDx)) |
209 (1u << getParallelTypeBitMapOffset(ParallelType::BIDy)) |
210 (1u << getParallelTypeBitMapOffset(ParallelType::BIDz))};
211};
212
213inline ParallelTypeBitmap operator&(
214 const ParallelTypeBitmap& lhs,
215 const ParallelTypeBitmap& rhs) {
216 auto x = lhs;
217 x &= rhs;
218 return x;
219}
220
221inline ParallelTypeBitmap operator|(
222 const ParallelTypeBitmap& lhs,
223 const ParallelTypeBitmap& rhs) {
224 auto x = lhs;
225 x |= rhs;
226 return x;
227}
228
229inline ParallelTypeBitmap operator^(
230 const ParallelTypeBitmap& lhs,
231 const ParallelTypeBitmap& rhs) {
232 auto x = lhs;
233 x ^= rhs;
234 return x;
235}
236
237inline bool ParallelTypeBitmap::Iterator::operator==(
238 const ParallelTypeBitmap::Iterator& other) const {
239 return offset_ == other.offset_ && map_ == other.map_;
240}
241
242inline bool ParallelTypeBitmap::Iterator::operator!=(
243 const ParallelTypeBitmap::Iterator& other) const {
244 return !(*this == other);
245}
246
247inline ParallelTypeBitmap::Iterator& ParallelTypeBitmap::Iterator::
248operator++() {
249 ++offset_;
250 skipToSetType();
251 return *this;
252}
253
254inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::operator++(
255 int) {
256 const auto before_increment = *this;
257 ++offset_;
258 skipToSetType();
259 return before_increment;
260}
261
262inline ParallelType ParallelTypeBitmap::Iterator::operator*() const {
263 return kParallelTypeThreads[offset_];
264}
265
266inline ParallelTypeBitmap::Iterator::Iterator(
267 const ParallelTypeBitmap& map,
268 int offset)
269 : map_(map), offset_(offset) {
270 skipToSetType();
271}
272
273inline void ParallelTypeBitmap::Iterator::skipToSetType() {
274 while (offset_ < kOffsetEnd && !map_[offset_]) {
275 ++offset_;
276 }
277}
278
279inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::begin(
280 const ParallelTypeBitmap& map) {
281 return Iterator(map, 0);
282}
283
284inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::end(
285 const ParallelTypeBitmap& map) {
286 return Iterator(map, kOffsetEnd);
287}
288
289//! Map from ParallelType to template type T
290template <typename T>
291class ParallelTypeMap {
292 public:
293 ParallelTypeMap() = default;
294
295 ParallelTypeMap(const T& init) {
296 std::fill(map_.begin(), map_.end(), init);
297 }
298
299 T& operator[](ParallelType pt) {
300 return map_[getParallelTypeBitMapOffset(pt)];
301 }
302
303 const T& operator[](ParallelType pt) const {
304 return map_[getParallelTypeBitMapOffset(pt)];
305 }
306
307 T& at(ParallelType pt) {
308 return map_.at(getParallelTypeBitMapOffset(pt));
309 }
310
311 const T& at(ParallelType pt) const {
312 return map_.at(getParallelTypeBitMapOffset(pt));
313 }
314
315 auto begin() {
316 return map_.begin();
317 }
318
319 auto begin() const {
320 return map_.begin();
321 }
322
323 auto end() {
324 return map_.begin();
325 }
326
327 auto end() const {
328 return map_.begin();
329 }
330
331 private:
332 std::array<T, ParallelTypeBitmap::kNumParallelTypes> map_;
333};
334
335} // namespace cuda
336} // namespace fuser
337} // namespace jit
338} // namespace torch
339