#pragma once

#include <ir_all_nodes.h>
#include <parallel_type_bitmap.h>

#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class SyncMap {
 public:
  std::string toString() const;

  //! Validates that all tensors are consistently parallelized: when a
  //! producer axis is threaded, either with threadIdx or blockIdx,
  //! there must be a mapped consumer axis with the same ParallelType,
  //! with a few exceptions.
  //!
  //! This function assumes the Loop and Parallel ComputeAtMaps are
  //! already built, as they are used to validate consistency.
  //!
  //! Fills needs_raw_sync_ with any output TensorView that requires a
  //! RAW sync when placed in shared or global memory. The mapped value
  //! holds the parallel dimensions across which data is communicated.
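  //!
  //! Example (an illustrative schedule, not taken from this file): if a
  //! producer is written with an axis bound to ParallelType::TIDx and a
  //! mapped consumer axis reads across TIDx, the producer must reside
  //! in shared memory, and TIDx is recorded here so lowering can insert
  //! a __syncthreads() between the write and the read.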
  void build(Fusion* fusion);

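  //! Returns the parallel dimensions on which tv requires a RAW sync,
  //! or an empty bitmap if no sync is needed.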
  ParallelTypeBitmap needsRawSync(TensorView* tv) const {
    auto it = needs_raw_sync_.find(tv);
    if (it != needs_raw_sync_.end()) {
      return it->second;
    }
    return ParallelTypeBitmap();
  }

 private:
  std::unordered_map<TensorView*, ParallelTypeBitmap> needs_raw_sync_;
};
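
// Usage sketch (illustrative, not part of this header; assumes lowering
// has already built the Loop and Parallel ComputeAtMaps, and that
// `fusion` and `tv` are in scope):
//
//   SyncMap sync_map;
//   sync_map.build(fusion);
//   const ParallelTypeBitmap sync_bits = sync_map.needsRawSync(tv);
//   if (sync_bits.hasTID()) {
//     // Communication across threadIdx dims: tv needs shared memory
//     // and a block sync (__syncthreads()).
//   }
//   if (sync_bits.hasBID()) {
//     // Communication across blockIdx dims: tv needs global memory
//     // and a grid sync.
//   }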
41 | |
42 | } // namespace cuda |
43 | } // namespace fuser |
44 | } // namespace jit |
45 | } // namespace torch |
46 | |