1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <type.h> |
5 | |
6 | #include <array> |
7 | #include <bitset> |
8 | #include <map> |
9 | #include <unordered_map> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | constexpr int getParallelTypeBitMapOffset(ParallelType pt) { |
17 | switch (pt) { |
18 | case ParallelType::BIDx: |
19 | return 0; |
20 | case ParallelType::BIDy: |
21 | return 1; |
22 | case ParallelType::BIDz: |
23 | return 2; |
24 | case ParallelType::TIDx: |
25 | return 3; |
26 | case ParallelType::TIDy: |
27 | return 4; |
28 | case ParallelType::TIDz: |
29 | return 5; |
30 | default: |
31 | return -1; |
32 | } |
33 | } |
34 | |
35 | //! Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz. |
36 | class ParallelTypeBitmap { |
37 | public: |
38 | static constexpr int kNumParallelTypes = 6; |
39 | |
40 | //! Iterator for ParallelTypeBitmap. Picks only set types. |
41 | class Iterator { |
42 | public: |
43 | static Iterator begin(const ParallelTypeBitmap& map); |
44 | |
45 | static Iterator end(const ParallelTypeBitmap& map); |
46 | |
47 | bool operator==(const Iterator& other) const; |
48 | |
49 | bool operator!=(const Iterator& other) const; |
50 | |
51 | Iterator& operator++(); |
52 | |
53 | Iterator operator++(int); |
54 | |
55 | ParallelType operator*() const; |
56 | |
57 | private: |
58 | Iterator(const ParallelTypeBitmap& map, int offset); |
59 | |
60 | void skipToSetType(); |
61 | |
62 | private: |
63 | const ParallelTypeBitmap& map_; |
64 | int offset_ = 0; |
65 | |
66 | static constexpr int kOffsetEnd = kNumParallelTypes; |
67 | }; |
68 | |
69 | ParallelTypeBitmap() = default; |
70 | |
71 | explicit ParallelTypeBitmap(ParallelType pt) { |
72 | set(pt); |
73 | } |
74 | |
75 | //! Return true if pt is included |
76 | bool get(ParallelType pt) const { |
77 | auto offset = getParallelTypeBitMapOffset(pt); |
78 | TORCH_INTERNAL_ASSERT( |
79 | offset != -1, "Could not recognize parallel type: " , pt); |
80 | return bitset_[offset]; |
81 | } |
82 | |
83 | //! Set the flag of pt |
84 | bool set(ParallelType pt, bool new_val = true) { |
85 | auto offset = getParallelTypeBitMapOffset(pt); |
86 | TORCH_INTERNAL_ASSERT( |
87 | offset != -1, "Could not recognize parallel type: " , pt); |
88 | bool old_val = bitset_[offset]; |
89 | bitset_[offset] = new_val; |
90 | return old_val; |
91 | } |
92 | |
93 | //! Clear the flag of pt |
94 | bool clear(ParallelType pt) { |
95 | return set(pt, false); |
96 | } |
97 | |
98 | //! Assign logical AND with other |
99 | ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other) { |
100 | bitset_ &= other.bitset_; |
101 | return *this; |
102 | } |
103 | |
104 | //! Assign logical OR with other |
105 | ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other) { |
106 | bitset_ |= other.bitset_; |
107 | return *this; |
108 | } |
109 | |
110 | //! Assign logical NOR with other |
111 | ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other) { |
112 | bitset_ ^= other.bitset_; |
113 | return *this; |
114 | } |
115 | |
116 | //! Return logical compliment |
117 | ParallelTypeBitmap operator~() const { |
118 | return ParallelTypeBitmap(~bitset_); |
119 | } |
120 | |
121 | //! Return true if none of the mapppings is true |
122 | bool none() const { |
123 | return bitset_.none(); |
124 | } |
125 | |
126 | //! Return true if any of the mapppings is true |
127 | bool any() const { |
128 | return bitset_.any(); |
129 | } |
130 | |
131 | //! Return true if all of the mapppings is true |
132 | bool all() const { |
133 | return bitset_.all(); |
134 | } |
135 | |
136 | //! Return true if the parallel type corresponding to a position |
137 | //! defined in offset_to_pt_ is true |
138 | bool operator[](size_t pos) const { |
139 | TORCH_INTERNAL_ASSERT( |
140 | pos < kNumParallelTypes, "Invalid index to ParallelTypeBitset: " , pos); |
141 | return bitset_[pos]; |
142 | } |
143 | |
144 | //! Return true if TIDx/y/z is included |
145 | bool hasTID() const { |
146 | return (bitset_ & kTIDBits).any(); |
147 | } |
148 | |
149 | //! Return true if BIDx/y/z is included |
150 | bool hasBID() const { |
151 | return (bitset_ & kBIDBits).any(); |
152 | } |
153 | |
154 | //! Set all of the TID flags |
155 | void setAllTID() { |
156 | *this |= ParallelTypeBitmap(kTIDBits); |
157 | } |
158 | |
159 | //! Set all of the BID flags |
160 | void setAllBID() { |
161 | *this |= ParallelTypeBitmap(kBIDBits); |
162 | } |
163 | |
164 | //! Clear all of the TID flags |
165 | void clearAllTID() { |
166 | auto tid_bits = ParallelTypeBitmap(kTIDBits); |
167 | auto not_tid_bits = ~tid_bits; |
168 | *this &= not_tid_bits; |
169 | } |
170 | |
171 | //! Clear all of the BID flags |
172 | void clearAllBID() { |
173 | auto bid_bits = ParallelTypeBitmap(kBIDBits); |
174 | auto not_bid_bits = ~bid_bits; |
175 | *this &= not_bid_bits; |
176 | } |
177 | |
178 | //! Get an iterator to traverse set types |
179 | Iterator begin() const { |
180 | return Iterator::begin(*this); |
181 | } |
182 | |
183 | //! Get an end iterator to traverse set types |
184 | Iterator end() const { |
185 | return Iterator::end(*this); |
186 | } |
187 | |
188 | bool operator==(const ParallelTypeBitmap& other) const { |
189 | return bitset_ == other.bitset_; |
190 | } |
191 | |
192 | std::string toString() const; |
193 | |
194 | private: |
195 | explicit constexpr ParallelTypeBitmap( |
196 | const std::bitset<kNumParallelTypes>& bs) |
197 | : bitset_(bs) {} |
198 | |
199 | private: |
200 | std::bitset<kNumParallelTypes> bitset_; |
201 | |
202 | static constexpr std::bitset<ParallelTypeBitmap::kNumParallelTypes> kTIDBits{ |
203 | (1u << getParallelTypeBitMapOffset(ParallelType::TIDx)) | |
204 | (1u << getParallelTypeBitMapOffset(ParallelType::TIDy)) | |
205 | (1u << getParallelTypeBitMapOffset(ParallelType::TIDz))}; |
206 | |
207 | static constexpr std::bitset<ParallelTypeBitmap::kNumParallelTypes> kBIDBits{ |
208 | (1u << getParallelTypeBitMapOffset(ParallelType::BIDx)) | |
209 | (1u << getParallelTypeBitMapOffset(ParallelType::BIDy)) | |
210 | (1u << getParallelTypeBitMapOffset(ParallelType::BIDz))}; |
211 | }; |
212 | |
213 | inline ParallelTypeBitmap operator&( |
214 | const ParallelTypeBitmap& lhs, |
215 | const ParallelTypeBitmap& rhs) { |
216 | auto x = lhs; |
217 | x &= rhs; |
218 | return x; |
219 | } |
220 | |
221 | inline ParallelTypeBitmap operator|( |
222 | const ParallelTypeBitmap& lhs, |
223 | const ParallelTypeBitmap& rhs) { |
224 | auto x = lhs; |
225 | x |= rhs; |
226 | return x; |
227 | } |
228 | |
229 | inline ParallelTypeBitmap operator^( |
230 | const ParallelTypeBitmap& lhs, |
231 | const ParallelTypeBitmap& rhs) { |
232 | auto x = lhs; |
233 | x ^= rhs; |
234 | return x; |
235 | } |
236 | |
237 | inline bool ParallelTypeBitmap::Iterator::operator==( |
238 | const ParallelTypeBitmap::Iterator& other) const { |
239 | return offset_ == other.offset_ && map_ == other.map_; |
240 | } |
241 | |
242 | inline bool ParallelTypeBitmap::Iterator::operator!=( |
243 | const ParallelTypeBitmap::Iterator& other) const { |
244 | return !(*this == other); |
245 | } |
246 | |
247 | inline ParallelTypeBitmap::Iterator& ParallelTypeBitmap::Iterator:: |
248 | operator++() { |
249 | ++offset_; |
250 | skipToSetType(); |
251 | return *this; |
252 | } |
253 | |
254 | inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::operator++( |
255 | int) { |
256 | const auto before_increment = *this; |
257 | ++offset_; |
258 | skipToSetType(); |
259 | return before_increment; |
260 | } |
261 | |
262 | inline ParallelType ParallelTypeBitmap::Iterator::operator*() const { |
263 | return kParallelTypeThreads[offset_]; |
264 | } |
265 | |
266 | inline ParallelTypeBitmap::Iterator::Iterator( |
267 | const ParallelTypeBitmap& map, |
268 | int offset) |
269 | : map_(map), offset_(offset) { |
270 | skipToSetType(); |
271 | } |
272 | |
273 | inline void ParallelTypeBitmap::Iterator::skipToSetType() { |
274 | while (offset_ < kOffsetEnd && !map_[offset_]) { |
275 | ++offset_; |
276 | } |
277 | } |
278 | |
279 | inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::begin( |
280 | const ParallelTypeBitmap& map) { |
281 | return Iterator(map, 0); |
282 | } |
283 | |
284 | inline ParallelTypeBitmap::Iterator ParallelTypeBitmap::Iterator::end( |
285 | const ParallelTypeBitmap& map) { |
286 | return Iterator(map, kOffsetEnd); |
287 | } |
288 | |
289 | //! Map from ParallelType to template type T |
290 | template <typename T> |
291 | class ParallelTypeMap { |
292 | public: |
293 | ParallelTypeMap() = default; |
294 | |
295 | ParallelTypeMap(const T& init) { |
296 | std::fill(map_.begin(), map_.end(), init); |
297 | } |
298 | |
299 | T& operator[](ParallelType pt) { |
300 | return map_[getParallelTypeBitMapOffset(pt)]; |
301 | } |
302 | |
303 | const T& operator[](ParallelType pt) const { |
304 | return map_[getParallelTypeBitMapOffset(pt)]; |
305 | } |
306 | |
307 | T& at(ParallelType pt) { |
308 | return map_.at(getParallelTypeBitMapOffset(pt)); |
309 | } |
310 | |
311 | const T& at(ParallelType pt) const { |
312 | return map_.at(getParallelTypeBitMapOffset(pt)); |
313 | } |
314 | |
315 | auto begin() { |
316 | return map_.begin(); |
317 | } |
318 | |
319 | auto begin() const { |
320 | return map_.begin(); |
321 | } |
322 | |
323 | auto end() { |
324 | return map_.begin(); |
325 | } |
326 | |
327 | auto end() const { |
328 | return map_.begin(); |
329 | } |
330 | |
331 | private: |
332 | std::array<T, ParallelTypeBitmap::kNumParallelTypes> map_; |
333 | }; |
334 | |
335 | } // namespace cuda |
336 | } // namespace fuser |
337 | } // namespace jit |
338 | } // namespace torch |
339 | |