1 | #pragma once |
2 | |
#include <c10/macros/Export.h>

#include <disjoint_set.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <iter_visitor.h>
#include <root_domain_map.h>

#include <algorithm>
#include <iterator>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | namespace { |
19 | |
20 | // Enable pair<IterDomain*, size_t> in a set, size_t must be unique in set |
21 | struct id_int_lt { |
22 | bool operator()( |
23 | const std::pair<IterDomain*, size_t>& first, |
24 | const std::pair<IterDomain*, size_t>& second) const { |
25 | return first.second < second.second; |
26 | } |
27 | }; |
28 | |
29 | } // namespace |
30 | |
31 | // Uses the history of _target_domain, and replays that history using the |
32 | // provided map. |
33 | // |
34 | // target_domain contains the history we want replayed. |
35 | // |
36 | // id_map maps IterDomains in that history to the IterDomains we want it |
37 | // replayed on. |
38 | // |
39 | // error_on_failure = true will cause the replay to error if we can't replay any |
40 | // operation in target_domain's history due to missing IDs in the id_map. |
41 | // |
42 | // If error_on_failure = false, replay will replay everything it can, and ignore |
43 | // operations it can't. |
class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor {
 protected:
  // The domain whose transformation history is replayed (see class comment).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<IterDomain*>& target_domain_;
  // Maps IterDomains in target_domain_'s history to the IterDomains the
  // history is replayed on. Extended with replayed outputs as the replay
  // progresses; after runReplay() this is what getReplay() returns.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<IterDomain*, IterDomain*> id_map_;
  // Terminating (leaf) IDs produced by the replay, each mapped to the order
  // in which it was recorded (stamped from `counter`).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<IterDomain*, size_t> leaf_ids_;
  // Leaf IDs in deterministic order; populated by runReplay().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<IterDomain*> leaf_vec_;
  // Monotonic counter used to stamp leaf_ids_ entries with creation order.
  // NOTE(review): no trailing underscore, unlike sibling members — kept
  // as-is since out-of-line definitions in the .cpp reference it.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t counter = 0;
  // When true, replay errors out if an operation in target_domain_'s history
  // cannot be replayed due to missing IDs in id_map_ (see class comment).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool error_on_failure_ = true;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool ran_replay = false; // Mark if replay has been run
  // When true, swizzle ops in the history are replayed; when false, swizzle
  // inputs are forwarded and the op is skipped (see constructor comment).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool replay_swizzle_ = false;
  using IterVisitor::handle;

  // Transform dispatch
  void handle(Expr* e) override;

  // We're going to replay this split operation on the corresponding ID
  void handle(Split* s) override;

  // We're going to replay this merge operation on the corresponding IDs
  void handle(Merge* m) override;

  // We're going to replay this swizzle operation on the corresponding IDs
  // if replaying swizzle is enabled.
  void handle(Swizzle2D* m) override;

 public:
  ReplayTransformations(
      const std::vector<IterDomain*>& _target_domain,
      std::unordered_map<IterDomain*, IterDomain*> _id_map,
      bool _error_on_failure = true,

      // Indicates if we want to replay swizzle ops on the replayed
      // tensor.
      // The swizzle op will be replayed if true,
      // The swizzle inputs will be directly forwarded, and therefore skipping
      // the swizzle op if false.
      // Currently this option should always be off, but
      // later we may have cases in scheduling large fusions where
      // this functionality could be useful.
      bool replay_swizzle = false);

  // Replays outputs that were generated from ids.first on ids.second
  void runReplay();

  // Returns map from provided target domain to their corresponding IDs.
  // Lazily runs the replay on first access.
  const std::unordered_map<IterDomain*, IterDomain*>& getReplay() {
    if (!ran_replay)
      runReplay();
    return id_map_;
  }

  // Returns leaf_ids_; the size_t marks the order in which they were put into
  // the map. This is part of the structure because it's used to generate the
  // order from 'getLeafIDs'. Lazily runs the replay on first access.
  const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() {
    if (!ran_replay)
      runReplay();
    return leaf_ids_;
  }

  // Returns all terminating IDs that resulted from the replay. Leaf IDs are
  // run-to-run deterministic, but otherwise in no specific order. Lazily runs
  // the replay on first access.
  const std::vector<IterDomain*>& getLeafIDs() {
    if (!ran_replay)
      runReplay();
    return leaf_vec_;
  }
};
120 | |
121 | /* |
122 | * Motivation: |
123 | * |
124 | * Consider the following program: |
125 | * |
126 | * T1[I0, R1] = T0[I0, I1] |
127 | * T2[I0] = T1[I0, R1i] |
128 | * |
129 | * T1->split(1, factor) |
130 | * T1->rFactor(2) |
131 | * |
132 | * T4[I0, R1orf, I1irf] = T0[I0, I1] |
133 | * T1[I0, R1i] = T4[I0, R1orf, I1irf] |
134 | * T2[I0] = T1[I0, R1i] |
135 | * |
136 | * There's an issue when we call replayCasP on |
137 | * T4[I0, R1o, I1i] = T0[I0, I1] |
138 | * |
139 | * This would try to replay T4 as T0, and it could include the rfactor domains. |
140 | * For example we compute T0 inline with T4. The way computeAt is setup this |
 * would call replayPasC(T0, T4, -1) then replayCasP(T4, T0, -1)
142 | * |
143 | * We might assume that the only way we will hit this is if we call |
144 | * T4->computeAt(T0...) so it might be safe to assume that the right |
145 | * transformations would be replayed. However, we want to preserve the rfactor |
146 | * domain, so since it would replay T4 at root, it would produce iterdomains |
 * that wouldn't correspond to those in rfactor. Also, I don't know if this
148 | * assumption is correct. |
149 | * |
150 | * Therefore, we will assume it is not correct, and we will validate here that |
151 | * if we replay a domain that it would transform it in a way consistent with |
152 | * any defined RFactor domains, then we will update the replay map so that |
153 | * RFactor roots are mapped to intermediate IterDomains in the target and start |
154 | * replay from there. |
155 | * |
156 | * |
157 | * SHORT DESCRIPTION: |
158 | * |
159 | * This class will validate/do the above. It will also run through |
160 | * transformations in target according to replay_map. If equal transformations |
161 | * already exist in replay_domain history, we will not redo those |
162 | * transformations, but instead update replay_map to reflect forwarding the |
163 | * existing transformations. This later part is the "best effort" replay. Though |
164 | * we include rfactor replay and validation here. |
165 | * |
 * Given an Expr in target_domain, check if its inputs are in replay_map. If so,
 * check if the mapped domains in replay_map are recorded to be transformed by
 * an equivalent operation in replay_domain's history. If so, "forward" the
 * operation: update replay_map so the outputs of target_domain's expression
 * map to the outputs of the equivalent expression in replay_domain's history.
171 | * |
172 | * replay_map maps root IDs in the history of target_domain to root IDs in the |
173 | * history replay_domain |
174 | */ |
175 | |
176 | class TORCH_CUDA_CU_API BestEffortReplay { |
177 | private: |
178 | std::unordered_map<IterDomain*, IterDomain*> target2replay_id_map_; |
179 | std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map_; |
180 | std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map_; |
181 | std::unordered_map<IterDomain*, size_t> leaf_ids_; |
182 | std::vector<IterDomain*> forwarded_ids_; |
183 | |
184 | // Need to track which id's have been forwarded. Later need to make sure leaf |
185 | // nodes to produce compliment axes are properly tracked. i.e. |
186 | // T[i0, b1, b2, i3] |
187 | // -> T[i0, b1o, b1i, b2o, b2i, i3] |
188 | // -> T[i0*b1i*b2o, b1o, b2i, i3] |
189 | // -> T[i0*b1i*b2o*i3, b1o, b2i] |
190 | // If we forwarded i0 -> i0*b1i*b2o*i3, we need to know that b1o and b2i |
191 | // are leaf nodes even though their split wasn't part of targets replay. |
192 | |
193 | // Counter to make sure best effort replay leaf_ids can be grabbed |
194 | // deterministicly |
195 | size_t counter = 0; |
196 | |
197 | // Determine if current replay will ignore swizzle ops. |
198 | // When not skipping swizzles, swizzle ops will have to be matched |
199 | // same way as split and merge to progress forward on the mapping. |
200 | // |
201 | // When skipping swizzles, mismatched swizzle ops will not stop matching |
202 | // further down the tensor domains but only the swizzle outputs will be on |
203 | // the target to replay map, since we only generate one-to-one maps in |
204 | // BestEffortReplay and the swizzle outputs is just picked as a convention |
205 | // for simpler and uniform mapping behavior. The swizzle op inputs will be |
206 | // added by the disjoint set passes when building the iterdomain graph. |
207 | // |
208 | // Example: |
209 | // Target: |
210 | // I0o, I0i = split I0 |
211 | // Ix0o, Ix0i = swizzle I0o, I0i |
212 | // I02 = merge Ix0o, Ix0i |
213 | // Replay: |
214 | // I1o, I1i = split I1 |
215 | // I12 = merge I1o, I1i |
216 | // |
217 | // BestEffortReplay **no** skip swizzle gives: |
218 | // { |
219 | // I0->I1, |
220 | // I0o->I1o, |
221 | // I0i->I1i, |
222 | // } |
223 | // |
224 | // BestEffortReplay skip swizzle gives: |
225 | // { |
226 | // I0->I1, |
227 | // Ix0o->I1o, |
228 | // Ix0i->I1i, |
229 | // I02->I12 |
230 | // } |
231 | // |
232 | bool skip_swizzle_ = true; |
233 | |
234 | bool inReplayForwardMap(IterDomain* id) const { |
235 | return replay_forward_id_map_.find(id) != replay_forward_id_map_.end(); |
236 | } |
237 | |
238 | bool inTargetForwardMap(IterDomain* id) const { |
239 | return target_forward_id_map_.find(id) != target_forward_id_map_.end(); |
240 | } |
241 | |
242 | IterDomain* getReplayForwardedId(IterDomain* id) const { |
243 | auto forwarded_id_it = replay_forward_id_map_.find(id); |
244 | if (forwarded_id_it == replay_forward_id_map_.end()) { |
245 | return id; |
246 | } else { |
247 | return getReplayForwardedId(forwarded_id_it->second); |
248 | } |
249 | } |
250 | |
251 | IterDomain* getTargetForwardedId(IterDomain* id) const { |
252 | auto forwarded_id_it = target_forward_id_map_.find(id); |
253 | if (forwarded_id_it == target_forward_id_map_.end()) { |
254 | return id; |
255 | } else { |
256 | return getTargetForwardedId(forwarded_id_it->second); |
257 | } |
258 | } |
259 | |
260 | //! Adds complimenting IDs of forwarded IDs to the leaf map |
261 | void addComplimentLeafIDs( |
262 | const std::unordered_map<IterDomain*, IterDomain*>& forwarding_map, |
263 | const std::unordered_map<IterDomain*, std::vector<IterDomain*>>& |
264 | compliment_map); |
265 | |
266 | // Skip swizzle step to make sure both target and |
267 | // replay swizzles are skipped while the mapping |
268 | // makes progress. This makes sure that, for example |
269 | // different tensors can still be inlined despite |
270 | // different local swizzle patterns. |
271 | void skipSwizzles( |
272 | const std::unordered_map<IterDomain*, Expr*>& target_id2expr, |
273 | const std::unordered_map<IterDomain*, Expr*>& replay_id2expr); |
274 | |
275 | public: |
276 | BestEffortReplay( |
277 | const std::vector<IterDomain*>& replay_domain, |
278 | const std::vector<IterDomain*>& target_domain, |
279 | std::unordered_map<IterDomain*, IterDomain*> target2replay_map, |
280 | std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map = {}, |
281 | std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map = {}, |
282 | bool skip_swizzle = true); |
283 | |
284 | // Return iter domain map from target_domain IDs to their "replayed" |
285 | // replay_domain IDs. If not in map, was not replayed. |
286 | const std::unordered_map<IterDomain*, IterDomain*>& getReplay() const { |
287 | return target2replay_id_map_; |
288 | } |
289 | |
290 | // ids in replay that did not have matching transforms in target_domain |
291 | const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() { |
292 | return leaf_ids_; |
293 | } |
294 | |
295 | // Returned ordered set of IDs in getUnorderedLeafIDs |
296 | std::vector<IterDomain*> getLeafIDs() { |
297 | std::set<std::pair<IterDomain*, size_t>, id_int_lt> ordered_set; |
298 | for (auto entry : leaf_ids_) |
299 | ordered_set.emplace(entry); |
300 | |
301 | std::vector<IterDomain*> leaf_vec_; |
302 | leaf_vec_.resize(ordered_set.size()); |
303 | std::transform( |
304 | ordered_set.begin(), |
305 | ordered_set.end(), |
306 | leaf_vec_.begin(), |
307 | [](std::pair<IterDomain*, size_t> entry) { return entry.first; }); |
308 | return leaf_vec_; |
309 | } |
310 | |
311 | DisjointSets<IterDomain*> getDisjointSets(); |
312 | |
313 | // Runs a best effort replay that ignores broadcast axes that appear in |
314 | // consumer that are not mapped to producer in root_map. |
315 | static BestEffortReplay replayCasP( |
316 | const TensorView* consumer, |
317 | const TensorView* producer, |
318 | int producer_compute_at_axis, |
319 | const RootDomainMap& root_map); |
320 | |
321 | // Runs a best effort replay that ignores broadcast axes that appear in |
322 | // consumer that are not mapped to producer in root_map. |
323 | static BestEffortReplay replayPasC( |
324 | const TensorView* producer, |
325 | const TensorView* consumer, |
326 | int consumer_compute_at_axis, |
327 | const RootDomainMap& root_map); |
328 | |
329 | // Find the first position i where td1[i] is not the same as td2[i]. "Same" |
330 | // means the DAG and input IDs to generate td1[i] and td2[i] are the same. |
331 | // td1 and td2 are assumed to have some matching iter domains, as this is a |
332 | // strict same-ness check. |
333 | static int findFirstMismatchedID( |
334 | const TensorDomain* td1, |
335 | const TensorDomain* td2); |
336 | }; |
337 | |
338 | } // namespace cuda |
339 | } // namespace fuser |
340 | } // namespace jit |
341 | } // namespace torch |
342 | |