1 | #pragma once |
2 | |
3 | #include <ir_all_nodes.h> |
4 | |
5 | #include <functional> |
6 | #include <unordered_map> |
7 | #include <vector> |
8 | |
9 | // Hoisting common index subexpressions |
10 | // |
11 | // Class CommonIndexMap is updated during the lowering as new indices |
12 | // are inserted. An index is uniquely identified with CommonIndexKey, |
13 | // which consists of the concrete ID of the indexed/predicated domain, |
14 | // the for-loops used in the index, and the index vals of the use |
15 | // for-loops. |
16 | // |
17 | // Once all indices are inserted to CommonIndexMap, allocations of the |
// hoisted indices are inserted by allocateCommonIndices. Note
19 | // that this assumes that the CUDA code generator does not inline a |
20 | // scalar Val with allocation (PR #1434). |
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | namespace fuser { |
25 | namespace cuda { |
26 | |
27 | //! Class to represent unique indexed domains for index |
28 | //! hoisting. Uniquenesss is determined with the indexed domain |
29 | //! itself, the for-loops and their index values. |
30 | class TORCH_CUDA_CU_API CommonIndexKey { |
31 | friend struct CommonIndexKeyHash; |
32 | |
33 | public: |
34 | //! \param consumer_indexed_id Indexed consumer domain |
35 | //! \param consumer_td TensorDomain of consumer_indexed_id |
36 | //! \param ref_td Reference domain at the time of indexing |
37 | //! \param ref_index_map Index map of the reference domain |
38 | //! \param loops Loop structure where this id is indexed |
39 | CommonIndexKey( |
40 | IterDomain* consumer_indexed_id, |
41 | TensorDomain* consumer_td, |
42 | TensorDomain* ref_td, |
43 | const std::unordered_map<IterDomain*, Val*>& ref_index_map, |
44 | const std::vector<kir::ForLoop*>& loops); |
45 | |
46 | //! \param consumer_indexed_id Indexed consumer domain |
47 | //! \param consumer_td TensorDomain of consumer_indexed_id |
48 | //! \param loop_domains Resolved vector of iterdomain corresponding to loops |
49 | //! \param loop_index_map Index mapping generated from the loop nest. |
50 | //! \param loops Loop structure where this id is indexed |
51 | //! Duplicate of above, but without a reference domain. TODO: Remove other |
52 | //! implementation. |
53 | CommonIndexKey( |
54 | IterDomain* consumer_indexed_id, |
55 | TensorDomain* consumer_td, |
56 | const std::vector<IterDomain*>& loop_domains, |
57 | const std::unordered_map<IterDomain*, Val*>& loop_index_map, |
58 | const std::vector<kir::ForLoop*>& loops); |
59 | |
60 | const IterDomain* concreteIndexedId() const { |
61 | return concrete_indexed_id_; |
62 | } |
63 | |
64 | const std::vector<kir::ForLoop*>& usedLoops() const { |
65 | return used_loops_; |
66 | } |
67 | |
68 | const std::vector<Val*>& loopIndexVals() const { |
69 | return loop_index_vals_; |
70 | } |
71 | |
72 | bool operator==(const CommonIndexKey& other) const; |
73 | |
74 | std::string toString() const; |
75 | |
76 | private: |
77 | //! Concrete domain of indexed domain |
78 | IterDomain* concrete_indexed_id_ = nullptr; |
79 | //! Loops used for the index |
80 | std::vector<kir::ForLoop*> used_loops_; |
81 | //! Loop index vals for the used loops |
82 | std::vector<Val*> loop_index_vals_; |
83 | }; |
84 | |
85 | struct CommonIndexKeyHash { |
86 | std::size_t operator()(const CommonIndexKey& key) const { |
87 | auto h = std::hash<const IterDomain*>{}(key.concrete_indexed_id_); |
88 | // NOTE: do not use other fields as the pointers can be different |
89 | // even when two keys can share the same index |
90 | return h; |
91 | } |
92 | }; |
93 | |
94 | //! Map to hold hoisted common indices |
95 | class TORCH_CUDA_CU_API CommonIndexMap { |
96 | public: |
97 | //! Register an indexd consumer domain to hoist |
98 | //! |
99 | //! Returns a corresponding hoisted index and a flag indicating if a |
100 | //! new index is inserted. |
101 | //! |
102 | //! Consumer domains are used even for producer indexing since |
103 | //! producer domains in producer indexing are temporary replay |
104 | //! domains. |
105 | std::pair<Val*, bool> insert( |
106 | IterDomain* indexed_consumer_id, |
107 | TensorDomain* consumer_td, |
108 | TensorDomain* ref_td, |
109 | const std::unordered_map<IterDomain*, Val*>& ref_index_map, |
110 | const std::vector<kir::ForLoop*>& loops, |
111 | Val* index); |
112 | |
113 | //! Duplicate of above, but without a reference domain. TODO: Remove other |
114 | //! implementation. |
115 | std::pair<Val*, bool> insert( |
116 | IterDomain* indexed_consumer_id, |
117 | TensorDomain* consumer_td, |
118 | const std::vector<IterDomain*>& loop_domains, |
119 | const std::unordered_map<IterDomain*, Val*>& loop_index_map, |
120 | const std::vector<kir::ForLoop*>& loops, |
121 | Val* index); |
122 | |
123 | const auto& commonIndexMap() const { |
124 | return common_index_map_; |
125 | } |
126 | |
127 | const auto& useCounts() const { |
128 | return use_counts_; |
129 | } |
130 | |
131 | private: |
132 | //! Utility method to insert a key into common index |
133 | //! map. Returns a pair of an IR node and a boolean value. |
134 | //! The IR node will be the previously inserted index if |
135 | //! the key found a match, or will be the original index |
136 | //! if this is new key and the key will be stored. |
137 | //! The boolean value will be true if the key is stored, |
138 | //! i.e. first time it is inserted. |
139 | std::pair<Val*, bool> tryInsertNewIndex(CommonIndexKey key, Val* index); |
140 | |
141 | private: |
142 | //! Map to hold hoisted common indices |
143 | std::unordered_map<CommonIndexKey, Val*, CommonIndexKeyHash> |
144 | common_index_map_; |
145 | std::unordered_map<CommonIndexKey, int, CommonIndexKeyHash> use_counts_; |
146 | }; |
147 | |
//! Insert allocations of hoisted indices. Must be called after
//! collecting all common indices.
//!
//! \param exprs Expressions to process
//! \return A vector of expressions. NOTE(review): presumably the input
//! expressions with the allocations inserted -- confirm against the
//! definition, which is not in this header.
std::vector<Expr*> allocateCommonIndices(const std::vector<Expr*>& exprs);
151 | |
152 | } // namespace cuda |
153 | } // namespace fuser |
154 | } // namespace jit |
155 | } // namespace torch |
156 | |