1 | #pragma once |
2 | |
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <cstdint>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | //! Maps TID/BID to its dimension. It is by default blockDim/gridDim, |
15 | //! but if use of a ParallelType is mapped to a unique constant |
16 | //! extent, the constant value is used instead since presumably it's |
17 | //! more efficient. |
18 | class TORCH_CUDA_CU_API ParallelDimensionMap { |
19 | public: |
20 | void build(Fusion* fusion); |
21 | |
22 | //! Returns the dimension of a ParallelType. nullptr is returned if |
23 | //! a ParallelType is unused. |
24 | Val* get(ParallelType pt) const; |
25 | |
26 | //! True if the dimension of a ParallelType is known to be exact |
27 | bool isExact(ParallelType pt) const; |
28 | |
29 | std::string toString() const; |
30 | |
31 | //! Symbolically analyze if two extent vals are equal |
32 | static bool equalDim(Val* dim1, Val* dim2); |
33 | |
34 | private: |
35 | //! Register the extent of an IterDomain if its constant |
36 | void registerConstantExtent(IterDomain* id); |
37 | |
38 | void handleParallelDomain(IterDomain* id); |
39 | |
40 | void populateDimensionMapWithSingleCASet( |
41 | ParallelType pt, |
42 | const std::unordered_set<IterDomain*>& dom_set); |
43 | |
44 | void populateDimensionMapWithMultipleCASet( |
45 | ParallelType pt, |
46 | const std::unordered_set<IterDomain*>& dom_set); |
47 | |
48 | //! TIDx may need to be marked as non-exact as it may be padded to a |
49 | //! multiple of the warp size. |
50 | void adjustMappingsForWarpPadding(); |
51 | |
52 | static IterDomain* getCAMappedConcreteDomain(IterDomain* id); |
53 | |
54 | private: |
55 | //! Maps from parallel types to dimensions, which are constant if |
56 | //! a unique value is found. |
57 | std::unordered_map<ParallelType, Val*, TypeHash> dim_map_; |
58 | //! Set of parallel types whose dimensions are identified to be |
59 | //! exactly the same as extents of mapped domains. |
60 | std::unordered_set<ParallelType, TypeHash> exact_types_; |
61 | |
62 | // Below are temporary maps to build the ParallelType-to-dimension |
63 | // map. Only used during build(). |
64 | |
65 | //! Map from a parallel type to a set of concrete domains where the |
66 | //! parallel type is used. |
67 | std::unordered_map<ParallelType, std::unordered_set<IterDomain*>, TypeHash> |
68 | concrete_dom_map_; |
69 | //! Keep track of constant extents found for a CA domain set |
70 | //! represented by the concrete domain. |
71 | std::unordered_map<IterDomain*, std::unordered_set<int64_t>> |
72 | constant_extent_map_; |
73 | }; |
74 | |
75 | } // namespace cuda |
76 | } // namespace fuser |
77 | } // namespace jit |
78 | } // namespace torch |
79 | |