1#pragma once
2
3#include <disjoint_set.h>
4#include <ir_all_nodes.h>
5#include <kernel_ir.h>
6#include <lower_trivial_reductions.h>
7
8#include <deque>
9#include <unordered_map>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
// There are three modes of these iter domain mappings, each uniquely
// important in the lowering process.
18//
19// For EXACT/PERMISSIVE mode consider:
20//
21// consumer[i0, b1] = producer[i0]
22// consumer->merge(0) (consumer will now be [i0 * b1])
23// When producer is replayed as consumer (the direction we use for mapping)
24// with BestEffortReplay forward_bcast_mismatch = True the producer to
25// consumer map will have both a mapping of consumer(i0) to producer(i0) as
26// well as consumer(i0*b1) to producer(i0). This latter mapping is important
27// for loop nest mappings as the consumer will generate a loop based on i0*b1
28// and the producer may be computeAt inside this loop nest. However, for
29// indexing we do not want these two maps as producer may be indexed as i0*i1
30// depending on the loop nest structure and how it was built. Therefore we
31// really need to carry (at least) two sets of maps around for lowering.
32//
33// LOOP mode is important if we have something like:
34// consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt
35// = 1) which can easily happen when using shared memory. We want to make sure
36// that the iteration domain used for loop construction (concreteId) has the
37// proper parallelization strategy. In parallel mode we do typical iteration
38// domain mapping, however we remove from it any iteration domains outside the
// computeAt of producer when mapping. This guarantees we won't map
// IterDomains that could have different parallelization strategies. We also
41// propagate the parallel strategy in parallel mode so all mapped IDs that
42// must have the same parallel type, do.
43//
44// IdMappingMode::LOOP
45// Only maps leaf axes to left of compute at
46// Forward broadcast axes in replay
47// IdMappingMode::PERMISSIVE
48// Forward broadcast axes in replay
49// Map all iteration domains
50// Always contain root mappings (otherwise they could have been forwarded in
51// broadcast)
52// IdMappingMode::EXACT
53// Don't map any broadcast axes to non-broadcast axes
54// Do not forward through any broadcast IDs
55class TORCH_CUDA_CU_API IterDomainGraph {
56 public:
57 IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);
58
59 const DisjointSets<IterDomain*>& permissiveNodes() const {
60 return permissive_nodes_;
61 }
62 const DisjointSets<IterDomain*>& exactNodes() const {
63 return exact_nodes_;
64 }
65 const DisjointSets<IterDomain*>& loopNodes() const {
66 return loop_nodes_;
67 }
68
69 // Consumers and producers is not symmetric like the other sets
70 const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
71 consumers() const {
72 return consumers_;
73 }
74 const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
75 producers() const {
76 return producers_;
77 }
78
79 const DisjointSets<IterDomain*>& siblings() const {
80 return sibling_sets_;
81 }
82
83 const VectorOfUniqueEntries<IterDomain*>& allIds() const {
84 return all_ids_;
85 }
86
87 const std::unordered_set<IterDomain*>& viewRfactorIds() const {
88 return view_rfactor_ids_;
89 }
90
91 // Returns if first and second are expressions through which the provided
92 // id_map have matching inputs (if forward), or outputs (if not forward).
93 // Returning true means the expressions are "the same", in terms they modify
94 // matching original extents, by the same amount.
95 static bool exprsMap(
96 Expr* first,
97 Expr* second,
98 bool forward,
99 const DisjointSets<IterDomain*>& id_map);
100
101 bool hasSelfMapping() const {
102 return self_mapping_info_.has_value();
103 }
104
105 private:
106 void build(Fusion* fusion);
107
108 void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
109
110 // Checks if exprsMap then if forward will map outputs else inputs in exact
111 // and permissive map.
112 void mapThroughExpr(Expr* first, Expr* second, bool forward);
113
114 DisjointSets<IterDomain*> permissive_nodes_;
115 DisjointSets<IterDomain*> exact_nodes_;
116 DisjointSets<IterDomain*> loop_nodes_;
117
118 // Consumers and producers is not symmetric like the other sets
119 std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
120 consumers_;
121 std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
122 producers_;
123
124 DisjointSets<IterDomain*> sibling_sets_;
125
126 VectorOfUniqueEntries<IterDomain*> all_ids_;
127
128 std::unordered_set<IterDomain*> view_rfactor_ids_;
129
130 c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
131 self_mapping_info_ = c10::nullopt;
132};
133
// Forward declaration. NOTE(review): likely redundant since
// lower_trivial_reductions.h is included above (and ComputeAtMap holds a
// TrivialReductionInfo by value, which needs the full definition) — harmless.
class TrivialReductionInfo;

// Per double-buffer-loop-stage index variables.
using DoubleBufferIndices = std::unordered_map<DoubleBufferLoopStage, Int*>;
137
138class TORCH_CUDA_CU_API ComputeAtMap {
139 public:
140 ComputeAtMap() = delete;
141 ComputeAtMap(const ComputeAtMap&) = delete;
142 ComputeAtMap& operator=(const ComputeAtMap&) = delete;
143 ComputeAtMap(ComputeAtMap&&) = default;
144 ComputeAtMap& operator=(ComputeAtMap&&) = default;
145 ComputeAtMap(Fusion* fusion);
146
147 //! Run through disjoint sets in the LOOP map, make sure there's only one
148 //! non-serial parallel type in each disjoint set, set the parallel type of
149 //! all IterDomains in the disjoint set to that PType.
150 void validateAndPropagatePType();
151
152 //! Run through disjoint sets in the LOOP map and allocate the index
153 //! variable for the associated for loop that will be generated
154 //! for each disjoint sets in the loop map. This pre-allocation makes
155 //! 2 key assumptions about computeAt map that would very likely be
156 //! long term invariant:
157 //! 1. All kir::forloop created in the lowering pass should belong
158 //! to one of the disjoint sets in loop map.
159 //! 2. The lowering pass will *never* create a loop nest with 2
160 //! different nesting levels mapped together, i.e. the case below
161 //! never occurs:
162 //! for i in IterDomain1
163 //! for j in IterDomain2
164 //! ...
165 //! With loop_map.areMapped(IterDomain1, IterDomain2) == true.
166 //! Under this condition, we can pre-allocate all required index
167 //! variable integers before creating any kir::forloop, and this
168 //! would help optimizing the generated integer math for indexing.
169 void allocateIndexVariables();
170
171 //! Returns if id0 and id1 are mapped to eachother with provided IdMappingMode
172 bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const;
173
174 //! Returns an iter domain that is the maximum expanded size of all iter
175 //! domains the one provided maps to. Useful for opening loops to the correct
176 //! iteration size. Not guarenteed to return the same ID every call, but is
177 //! guarenteed to return iter domains in the same disjoint set.
178 IterDomain* getConcreteMappedID(IterDomain* id, IdMappingMode mode) const;
179
180 // Prints mapping information, forwards to an internal IterDomainGraph
181 std::string toString() const;
182
183 // Returns if the provided ID is a view like rfactor id
184 bool isViewRfactor(IterDomain* ref_id) const;
185
186 // Returns all rfactor domains in rfactor_concrete_count_reset_domains_ that
187 // are in the disjoint set of the provided IterDomain. This will be every view
188 // like rfactor ID the provided ID "depends" on in the map.
189 std::vector<IterDomain*> getViewRfactorDomainsOfIdGroup(
190 IterDomain* ref_id,
191 IdMappingMode mode) const;
192
193 const IterDomainGraph& idGraph() const {
194 return id_graph_;
195 }
196
197 //! Get the ID sets for a provided IdMappingMode
198 const DisjointSets<IterDomain*>& getIdSets(IdMappingMode mode) const;
199
200 // Returns if the ID actually has a disjoint set meaning it has been processed
201 // in the creation of the compute at map.
202 bool idExistsInMap(IterDomain* id) const;
203
204 //! Returns the pre-allocated index variable integer used in
205 //! the kir::ForLoop corresponding to the given IterDomain.
206 //! this interface is only valid if the ID has a loop mapping,
207 //! ca_map will throw exceptions if given iterdomain doesn't
208 //! have a loop map entry.
209 Val* getIndexVariable(
210 IterDomain* id,
211 DoubleBufferLoopStage double_buffer_loop_stage =
212 DoubleBufferLoopStage::NotApplicable) const;
213
214 private:
215 // Build id_graph_
216 void build(Fusion* fusion);
217
218 // Build concrete_id_cache_
219 // Build a single entry in concrete_cache_id_
220 IterDomain* computeConcreteId(IterDomain* id, IdMappingMode mode);
221 void buildConcreteIds();
222
223 // Produce the disjoint set containing provided id with mapping mode.
224 const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& disjointSetOf(
225 IterDomain* id,
226 IdMappingMode mode) const;
227
228 // Should be built once and never modified again.
229 IterDomainGraph id_graph_;
230 TrivialReductionInfo trivial_reduction_info_;
231
232 // Prevent needing to recompute concrete_id's in compute at map.
233 // VectorOfUniqueEntries is unique across mapping modes, so don't need to use
234 // mapping mode directly in this cache. const
235 // VectorOfUniqueEntries<IterDomain*>& is what's returned by
236 // ComputeAtMap::disjointSetOf which can be used directly.
237 std::unordered_map<
238 std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
239 IterDomain*>
240 concrete_id_cache_;
241
242 //! Allocated Loop index variable through the CA map.
243 //! only valid for disjoint sets on the loop ca map.
244 std::unordered_map<const VectorOfUniqueEntries<IterDomain*>*, Val*>
245 loop_index_variable_map_;
246
247 //! Allocated loop indices for double buffer loop.
248 //! only valid for disjoint sets on the loop ca map
249 //! that have double buffer-ed iterdomains.
250 using DoubleBufferIndicesPtr = std::unique_ptr<DoubleBufferIndices>;
251 std::unordered_map<
252 const VectorOfUniqueEntries<IterDomain*>*,
253 DoubleBufferIndicesPtr>
254 double_buffered_loop_index_variable_map_;
255
256 // Shortcut to access the fusion this computeAt map was
257 // built from.
258 Fusion* fusion_;
259};
260
261} // namespace cuda
262} // namespace fuser
263} // namespace jit
264} // namespace torch
265