1 | #include <transform_replay.h> |
2 | |
3 | #include <arith.h> |
4 | #include <disjoint_set.h> |
5 | #include <fusion.h> |
6 | #include <instrumentation.h> |
7 | #include <ir_all_nodes.h> |
8 | #include <ir_builder.h> |
9 | #include <ir_iostream.h> |
10 | #include <ir_utils.h> |
11 | #include <maxinfo_propagator.h> |
12 | #include <root_domain_map.h> |
13 | #include <transform_iter.h> |
14 | |
15 | #include <deque> |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | namespace fuser { |
20 | namespace cuda { |
21 | |
22 | using id_map = std::unordered_map<IterDomain*, IterDomain*>; |
23 | |
24 | namespace { |
25 | |
26 | class ReplaySelf : public ReplayTransformations { |
27 | private: |
28 | // Took a good bit of this from ReplayTransformations::handle(Split...) |
29 | void handle(Split* s) override { |
30 | // Grab input to the split operation |
31 | auto id_in = s->in(); |
32 | |
33 | // Grab our mapping of that ID to the one we're replaying |
34 | auto it = id_map_.find(id_in); |
35 | |
36 | // Make sure it exists in the map |
37 | TORCH_INTERNAL_ASSERT( |
38 | it != id_map_.end(), |
39 | "Transform traversal failed, dependencies not met." ); |
40 | // Grab the ID we're going to replay on |
41 | auto mapped = it->second; |
42 | |
43 | // This ID should be a leaf ID (meaning it has no uses we generated) |
44 | TORCH_INTERNAL_ASSERT( |
45 | leaf_ids_.find(mapped) != leaf_ids_.end(), |
46 | "Transform traversal failed, modified a node but it was not a leaf node." ); |
47 | |
48 | // outer loop size |
49 | Val* remainder = ceilDiv( |
50 | Split::extent(mapped->extent(), s->startOffset(), s->stopOffset()), |
51 | s->factor()); |
52 | |
53 | // Manually replay the split, following the output of the operations. |
54 | // This is so rfactor ops are replayed correctly. |
55 | IterDomain* ido = |
56 | IterDomainBuilder(s->outer()) |
57 | .start(s->container()->zeroVal()) |
58 | .extent(s->innerSplit() ? remainder->as<Int>() : s->factor()) |
59 | .build(); |
60 | |
61 | // inner IterDomain |
62 | IterDomain* idi = |
63 | IterDomainBuilder(s->inner()) |
64 | .start(s->container()->zeroVal()) |
65 | .extent(s->innerSplit() ? s->factor() : remainder->as<Int>()) |
66 | .build(); |
67 | |
68 | // Generate the split node |
69 | IrBuilder::create<Split>( |
70 | s->container(), |
71 | ido, |
72 | idi, |
73 | mapped, |
74 | s->factor(), |
75 | s->innerSplit(), |
76 | s->startOffset(), |
77 | s->stopOffset()); |
78 | |
79 | // Remove mapped id from leaf IDs |
80 | leaf_ids_.erase(mapped); |
81 | |
82 | // Add outputs to leaf IDs |
83 | leaf_ids_[ido] = counter++; |
84 | leaf_ids_[idi] = counter++; |
85 | |
86 | // Update our ID map to include these outputs |
87 | id_map_[s->outer()] = ido; |
88 | id_map_[s->inner()] = idi; |
89 | } |
90 | |
91 | void handle(Merge* m) override { |
92 | auto id_outer = m->outer(); |
93 | auto id_inner = m->inner(); |
94 | |
95 | auto it_outer = id_map_.find(id_outer); |
96 | auto it_inner = id_map_.find(id_inner); |
97 | |
98 | TORCH_INTERNAL_ASSERT( |
99 | it_outer != id_map_.end() && it_inner != id_map_.end(), |
100 | "Transform traversal failed, dependencies not met." ); |
101 | |
102 | auto id_outer_mapped = it_outer->second; |
103 | auto id_inner_mapped = it_inner->second; |
104 | |
105 | TORCH_INTERNAL_ASSERT( |
106 | leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() && |
107 | leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(), |
108 | "Transform traversal failed, modified " , |
109 | id_outer_mapped, |
110 | " and " , |
111 | id_inner_mapped, |
112 | " however one or both are not leaf nodes." ); |
113 | |
114 | Val* merged_id_size = |
115 | mul(id_outer_mapped->extent(), id_inner_mapped->extent()); |
116 | |
117 | IterDomain* merged_id = IterDomainBuilder(m->out()) |
118 | .start(m->container()->zeroVal()) |
119 | .extent(merged_id_size->as<Int>()) |
120 | .build(); |
121 | |
122 | IrBuilder::create<Merge>( |
123 | m->container(), merged_id, id_outer_mapped, id_inner_mapped); |
124 | |
125 | // Remove inputs from the leaf IDs |
126 | leaf_ids_.erase(id_outer_mapped); |
127 | leaf_ids_.erase(id_inner_mapped); |
128 | |
129 | // Add the output to the leaf IDs |
130 | leaf_ids_[merged_id] = counter++; |
131 | |
132 | id_map_[m->out()] = merged_id; |
133 | } |
134 | |
135 | public: |
136 | ReplaySelf(const std::vector<IterDomain*>& _target_domain, id_map _id_map) |
137 | : ReplayTransformations(_target_domain, std::move(_id_map), false) {} |
138 | }; |
139 | |
140 | } // namespace |
141 | |
142 | // Self replay. |
143 | TensorDomain* TransformReplay::fullSelfReplay( |
144 | const TensorDomain* new_self_root, |
145 | const TensorDomain* self) { |
146 | FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay" ); |
147 | |
148 | TORCH_INTERNAL_ASSERT( |
149 | new_self_root->getRootDomain().size() == self->getRootDomain().size(), |
150 | "Invalid number of IterDomains provided." ); |
151 | |
152 | // Map for replay, should be pretty simple. |
153 | id_map axis_map; |
154 | { |
155 | size_t i = 0; |
156 | for (auto id : self->getRootDomain()) { |
157 | TORCH_INTERNAL_ASSERT( |
158 | new_self_root->getRootDomain()[i]->isReduction() == |
159 | id->isReduction() && |
160 | new_self_root->getRootDomain()[i]->isRFactorProduct() == |
161 | id->isRFactorProduct() && |
162 | new_self_root->getRootDomain()[i]->isBroadcast() == |
163 | id->isBroadcast(), |
164 | "Axes " , |
165 | id, |
166 | " and " , |
167 | new_self_root->getRootDomain()[i], |
168 | " do not match for self replay." ); |
169 | axis_map[id] = new_self_root->getRootDomain()[i]; |
170 | i++; |
171 | } |
172 | } |
173 | |
174 | // Replay producer dimensions. |
175 | ReplaySelf replay(self->domain(), axis_map); |
176 | std::vector<IterDomain*> new_domain(self->nDims(), nullptr); |
177 | |
178 | { |
179 | size_t i = 0; |
180 | for (auto id : self->domain()) { |
181 | auto it = replay.getReplay().find(id); |
182 | TORCH_INTERNAL_ASSERT( |
183 | it != replay.getReplay().end(), |
184 | "Error during replay, didn't replay an axis." ); |
185 | new_domain[i++] = it->second; |
186 | } |
187 | |
188 | if (self->hasRFactor()) { |
189 | std::vector<IterDomain*> new_rfactor_domain( |
190 | self->getMaybeRFactorDomain().size(), nullptr); |
191 | size_t i = 0; |
192 | for (auto id : self->getMaybeRFactorDomain()) { |
193 | auto it = replay.getReplay().find(id); |
194 | TORCH_INTERNAL_ASSERT( |
195 | it != replay.getReplay().end(), |
196 | "Error during replay, didn't replay an axis." ); |
197 | new_rfactor_domain[i++] = it->second; |
198 | } |
199 | return IrBuilder::create<TensorDomain>( |
200 | self->container(), |
201 | new_self_root->getRootDomain(), |
202 | new_rfactor_domain, |
203 | new_domain, |
204 | self->contiguity()); |
205 | } |
206 | } |
207 | |
208 | return IrBuilder::create<TensorDomain>( |
209 | self->container(), |
210 | new_self_root->getRootDomain(), |
211 | new_domain, |
212 | new_self_root->contiguity()); |
213 | } |
214 | |
// Producer could have rfactor axes which consumer may want replayed. We can
// "replay" them as long as it doesn't modify the root rfactor axes. What we
// really want to do is validate that, if we replayed these axes to the ones
// they mapped to in the consumer, the operations would all be the same. Then
// we want to start the replay of the producer from the rfactor root axes, not
// the root.
//
// Returns the replayed producer domain and the producer position that
// corresponds to consumer_compute_at_axis.
std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_compute_at_axis,
    const RootDomainMap& root_map,
    bool replay_swizzle) {
  FUSER_PERF_SCOPE("TransformReplay::replayPasC");

  // If this is a reduction operation, we may call transform_replay on the
  // tensor view. When this happens, just return the target view.
  if (producer == consumer)
    return {producer->domain(), producer->nDims()};

  // Normalize a negative axis to a position in [0, nDims].
  if (consumer_compute_at_axis < 0)
    consumer_compute_at_axis += (int)consumer->nDims() + 1;
  TORCH_INTERNAL_ASSERT(
      consumer_compute_at_axis >= 0 &&
          (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
      "Invalid axis in transform replayPasC.");

  // consumer ids we need to match in producer
  std::vector<IterDomain*> consumer_CA_ids(
      consumer->domain()->domain().begin(),
      consumer->domain()->domain().begin() + consumer_compute_at_axis);

  // Instead of replaying from the root, lets try to play forward the history of
  // producer if they match ops on consumer. Enforce if we modify an rfactor
  // axis that those ops must match.
  auto forward_replay = BestEffortReplay::replayPasC(
      producer, consumer, consumer_compute_at_axis, root_map);

  // Make a new map based on all the leaves resulting from best effort replay.
  // Only keep mappings whose target is still a dangling leaf; each leaf is
  // consumed at most once.
  id_map forwarded_replay_map;
  auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
  for (auto entry : forward_replay.getReplay()) {
    if (forward_dangling_leaves.find(entry.second) !=
        forward_dangling_leaves.end()) {
      forwarded_replay_map[entry.first] = entry.second;
      forward_dangling_leaves.erase(entry.second);
    }
  }

  // Replay producer dimensions.
  ReplayTransformations replay_PasC(
      consumer_CA_ids, forwarded_replay_map, false, replay_swizzle);

  auto leaf_ids(replay_PasC.getUnorderedLeafIDs());

  // Remove all ids that map to the compute at axis, we're going to replay the
  // rest, track all dims needed to match consumer CA dims
  std::vector<IterDomain*> needed_dims;
  for (auto c_id : consumer_CA_ids) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it == replay_PasC.getReplay().end()) {
      // Broadcast/gather/vector-component axes may legitimately have no
      // counterpart in producer; anything else missing is an error.
      TORCH_INTERNAL_ASSERT(
          c_id->isBroadcast() || c_id->isGather() || c_id->isVectorComponent(),
          "Could not find axis, ",
          c_id,
          ", requested in replay.");
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        leaf_ids.find(it->second) != leaf_ids.end(),
        "Replayed id to match consumer id ",
        c_id,
        " should be a leaf in replay map.");
    leaf_ids.erase(it->second);
    needed_dims.push_back(it->second);
  }

  // leaf_ids now contains all producer ID products that are not used to
  // satisfy the computeAt. Turn into a map so we can play forward these IDs in
  // producer (if possible):
  id_map producer_self_replay_map;
  for (auto entry : leaf_ids) {
    producer_self_replay_map[entry.first] = entry.first;
  }

  // Dangling leaves left over from the best-effort forward replay are also
  // mapped to themselves so they are not lost.
  for (auto entry : forward_dangling_leaves) {
    producer_self_replay_map[entry.first] = entry.first;
  }

  // Check which root domains were used to produce the leaf_ids. We may have
  // picked up extra roots in consumer because of broadcast forwarding.
  std::vector<Val*> unordered_non_root_leaf_vals;
  for (auto leaf_id : replay_PasC.getUnorderedLeafIDs()) {
    if (leaf_id.first->definition() == nullptr) {
      // A leaf with no definition is itself a root; skip it.
      continue;
    } else {
      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
    }
  }

  auto producer_root = producer->getMaybeRFactorDomain();

  // Figure out all id's that have been processed to generate the
  // unordered_non_root_leaf_vals. This needs to be done because we want to
  // match on producer's rfactor domain, not root domain.
  std::unordered_set<IterDomain*> all_processed_ids;
  {
    auto all_processed_vals_vec = DependencyCheck::getAllValsBetween(
        {producer_root.begin(), producer_root.end()},
        unordered_non_root_leaf_vals);
    auto all_processed_ids_vec =
        ir_utils::filterByType<IterDomain>(all_processed_vals_vec);
    all_processed_ids.insert(
        all_processed_ids_vec.begin(), all_processed_ids_vec.end());
  }

  // Any root domain that was not used to generate computeIDs we can also put
  // in the map to forward their transformations.
  for (auto producer_root_id : producer_root) {
    if (all_processed_ids.find(producer_root_id) == all_processed_ids.end() &&
        std::find(needed_dims.begin(), needed_dims.end(), producer_root_id) ==
            needed_dims.end()) {
      producer_self_replay_map[producer_root_id] = producer_root_id;
    }
  }

  // Play forward transformations all producer IDs we can
  auto producer_replayed_leaves = BestEffortReplay(
      producer->domain()->domain(),
      producer->domain()->domain(),
      producer_self_replay_map);

  /*
   * Accumulate axes in to the new domain in the following order, making sure
   * to avoid any duplicates:
   *
   * (1) replay_PasC.getReplay holds mappings from axes in consumer compute at
   * axes -> corresponding generated axes in producer
   *
   * (2) Any axes that were not added, that can be mapped directly from an ID
   * in consumer->domain(). These are axes that were "fully replayed" relative
   * to the consumer, even though it wasn't in the computeAt range.
   *
   * producer_replayed_leaves now contain ids that we tried to forward
   * back to what they were in producer. If they couldn't be forwarded they're
   * left in their "most forwarded" form which may be just a remainder of the
   * transformation required to generate the computeAt axes.
   *
   * (3) Axes in producer->domain() that are in producer_replayed_leaves
   *
   * (4) Axes not in producer->domain() that are in producer_replayed_leaves
   *
   */

  std::vector<IterDomain*> new_IDs;
  std::unordered_set<IterDomain*> used_IDs;
  // Add axes in (1)
  for (auto c_id : consumer_CA_ids) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it == replay_PasC.getReplay().end()) {
      TORCH_INTERNAL_ASSERT(
          c_id->isBroadcast() || c_id->isGather() || c_id->isVectorComponent(),
          "Could not find axis, ",
          c_id,
          ", requested in replay.");
      continue;
    }
    new_IDs.push_back(it->second);
    used_IDs.emplace(it->second);
  }

  // Everything accumulated so far matches the consumer computeAt axes; that
  // count is the producer's computeAt position.
  unsigned int producer_compute_at_axis = new_IDs.size();

  // Add axes in (2)
  for (auto c_id : consumer->domain()->domain()) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it != replay_PasC.getReplay().end()) {
      auto id = it->second;
      // If the leaf id from ReplayTransformations is used to move
      // forward in BestEffortReplay, it is not a final ID.
      if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) ==
          producer_replayed_leaves.getUnorderedLeafIDs().end()) {
        continue;
      }
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (3)
  for (auto id : producer->domain()->domain()) {
    if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
        producer_replayed_leaves.getUnorderedLeafIDs().end()) {
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (4)
  for (auto id : producer_replayed_leaves.getLeafIDs()) {
    if (used_IDs.find(id) == used_IDs.end()) {
      new_IDs.push_back(id);
    }
  }
  // Keep the producer's root/rfactor domains; only the leaf domain changes.
  TensorDomain* replayed = IrBuilder::create<TensorDomain>(
      producer->container(),
      producer->getRootDomain(),
      producer->getRFactorDomain(),
      new_IDs,
      producer->domain()->contiguity());

  return {replayed, producer_compute_at_axis};
}
430 | |
// Replay consumer as producer: match consumer's leaf domain against the first
// producer_compute_at_axis producer axes (reductions in the producer are
// dropped — they have no counterpart in the consumer).
//
// Returns the replayed consumer domain and the number of matched axes.
std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_compute_at_axis,
    const RootDomainMap& root_map) {
  FUSER_PERF_SCOPE("TransformReplay::replayCasP");

  // If this is a reduction operation, we may call transform_replay on the same
  // tensor view. When this happens, just return the target view.
  if (consumer == producer)
    return {consumer->domain(), consumer->nDims()};

  // Normalize a negative axis to a position in [0, nDims].
  if (producer_compute_at_axis < 0)
    producer_compute_at_axis += (int)producer->nDims() + 1;

  TORCH_INTERNAL_ASSERT(
      producer_compute_at_axis >= 0 &&
          (unsigned int)producer_compute_at_axis <= producer->nDims(),
      "Invalid axis in transform replayCasP.");

  // producer ids we need to match in consumer
  std::vector<IterDomain*> producer_CA_ids(
      producer->domain()->domain().begin(),
      producer->domain()->domain().begin() + producer_compute_at_axis);
  // Reduction axes in producer have no counterpart in consumer; drop them.
  producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);

  // Instead of replaying from the root, lets try to forward the history of
  // consumer if they match ops on producer. Enforce if we modify an rfactor
  // axis that those ops match.
  BestEffortReplay forward_replay = BestEffortReplay::replayCasP(
      consumer, producer, producer_compute_at_axis, root_map);

  // Track dangling leaves which can be produced in
  // BestEffortReplay::replayCasP these don't have any equivalent in producer
  // so they're not in the map. We will simply map them to themselves so we
  // don't lose them.
  id_map forwarded_replay_map;
  auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
  for (auto entry : forward_replay.getReplay()) {
    // Only keep mappings whose target is still a dangling leaf; each leaf is
    // consumed at most once.
    if (forward_dangling_leaves.find(entry.second) !=
        forward_dangling_leaves.end()) {
      forwarded_replay_map[entry.first] = entry.second;
      forward_dangling_leaves.erase(entry.second);
    }
  }

  // Replay producer dimensions.
  ReplayTransformations replay_CasP(
      producer_CA_ids, forwarded_replay_map, false);

  auto leaf_ids(replay_CasP.getUnorderedLeafIDs());

  // Remove all ids that map to the compute at axis, we're going to replay the
  // rest, track all dims that are needed to match producer CA dims
  std::vector<IterDomain*> needed_dims;
  for (auto p_id : producer_CA_ids) {
    // Unlike replayPasC, every producer CA axis must have been replayed
    // (reductions were already filtered out above).
    auto it = replay_CasP.getReplay().find(p_id);
    TORCH_INTERNAL_ASSERT(
        it != replay_CasP.getReplay().end(),
        "Could not find axis, ",
        p_id,
        ", requested in replay.");
    TORCH_INTERNAL_ASSERT(
        leaf_ids.find(it->second) != leaf_ids.end(),
        "Replayed id to match producer id ",
        p_id,
        " should be a leaf in replay map.");
    leaf_ids.erase(it->second);
    needed_dims.push_back(it->second);
  }

  // leaf_ids now contains all consumer ID products that are not used to
  // satisfy the computeAt. Turn into a map so we can play forward these IDs in
  // consumer (if possible):
  id_map consumer_self_replay_map;
  for (auto entry : leaf_ids) {
    consumer_self_replay_map[entry.first] = entry.first;
  }

  // Dangling leaves left over from the best-effort forward replay are also
  // mapped to themselves so they are not lost.
  for (auto entry : forward_dangling_leaves) {
    consumer_self_replay_map[entry.first] = entry.first;
  }

  // Check which root domains were used to produce the leaf_ids. We may have
  // picked up extra roots in consumer because of broadcast forwarding.
  std::vector<Val*> unordered_non_root_leaf_vals;
  for (auto leaf_id : replay_CasP.getUnorderedLeafIDs()) {
    if (leaf_id.first->definition() == nullptr) {
      // A leaf with no definition is itself a root; skip it.
      continue;
    } else {
      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
    }
  }

  auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);

  std::vector<IterDomain*> consumer_root = consumer->getRootDomain();

  // Any root domain that was not used to generate computeIDs we can also put
  // in the map to forward their transformations.
  for (auto consumer_root_id : consumer_root) {
    if (std::find(
            processed_roots.begin(), processed_roots.end(), consumer_root_id) ==
            processed_roots.end() &&
        // Don't re-add roots that may have directly mapped in the replay
        std::find(needed_dims.begin(), needed_dims.end(), consumer_root_id) ==
            needed_dims.end()) {
      consumer_self_replay_map[consumer_root_id] = consumer_root_id;
    }
  }

  // Play forward transformations all consumer IDs we can
  auto consumer_replayed_leaves = BestEffortReplay(
      consumer->domain()->domain(),
      consumer->domain()->domain(),
      consumer_self_replay_map);

  /*
   * Accumulate axes in to the new domain in the following order, making sure
   * to avoid any duplicates:
   *
   * (1) replay_CasP.getReplay holds mappings from axes in producer compute at
   * axes -> corresponding generated axes in consumer
   *
   * (2) Any axes that were not added, that can be mapped directly from an ID
   * in producer->domain(). These are axes that were "fully replayed" relative
   * to the producer, even though it wasn't in the computeAt range.
   *
   * consumer_replayed_leaves now contain ids that we tried to forward
   * back to what they were in consumer. If they couldn't be forwarded they're
   * left in their "most forwarded" form which may be just a remainder of the
   * transformation required to generate the computeAt axes.
   *
   * (3) Axes in consumer->domain() that are in consumer_replayed_leaves
   *
   * (4) Axes not in consumer->domain() that are in consumer_replayed_leaves
   *
   * TODO: Should (2) and (3) be swapped?
   */

  std::vector<IterDomain*> new_IDs;
  std::unordered_set<IterDomain*> used_IDs;
  // Add axes in (1)
  for (auto p_id : producer_CA_ids) {
    auto it = replay_CasP.getReplay().find(p_id);
    TORCH_INTERNAL_ASSERT(
        it != replay_CasP.getReplay().end(),
        "Could not find axis, ",
        p_id,
        ", requested in replay.");
    new_IDs.push_back(it->second);
    used_IDs.emplace(it->second);
  }

  // Add axes in (2)
  for (auto p_id : producer->domain()->domain()) {
    auto it = replay_CasP.getReplay().find(p_id);
    if (it != replay_CasP.getReplay().end()) {
      auto id = it->second;
      // If the leaf id from ReplayTransformations is used to move
      // forward in BestEffortReplay, it is not a final ID.
      if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) ==
          consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
        continue;
      }
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (3)
  for (auto id : consumer->domain()->domain()) {
    if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
        consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (4)
  for (auto id : consumer_replayed_leaves.getLeafIDs())
    if (used_IDs.find(id) == used_IDs.end())
      new_IDs.push_back(id);

  // Keep the consumer's root/rfactor domains; only the leaf domain changes.
  TensorDomain* replayed = IrBuilder::create<TensorDomain>(
      consumer->container(),
      consumer->getRootDomain(),
      consumer->getRFactorDomain(),
      new_IDs,
      consumer->domain()->contiguity());

  // The matched position is the count of (non-reduction) producer CA axes.
  return {replayed, producer_CA_ids.size()};
}
628 | |
629 | // replay Producer as Consumer |
630 | std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC( |
631 | const TensorView* producer, |
632 | const TensorView* consumer, |
633 | int compute_at_axis, |
634 | bool replay_swizzle) { |
635 | // Use the pairwise root map as a default mapper |
636 | PairwiseRootDomainMap root_map(producer, consumer); |
637 | return replayPasC( |
638 | producer, consumer, compute_at_axis, root_map, replay_swizzle); |
639 | } |
640 | |
641 | std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP( |
642 | const TensorView* consumer, |
643 | const TensorView* producer, |
644 | int compute_at_axis) { |
645 | // Use the pairwise root map as a default mapper |
646 | PairwiseRootDomainMap root_map(producer, consumer); |
647 | return replayCasP(consumer, producer, compute_at_axis, root_map); |
648 | } |
649 | |
// In a PasC replay, we want the producer to exactly match the consumer:
// all the beginning axes in the producer should be mapped to the consumer in
// the same order. Reductions in the producer needs to be in the back of the
// producer.
//
// Returns the producer leaf position already matching consumer_pos without
// doing any replay, or -1 if no such position exists.
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_pos) {
  FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC");

  const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  id_map c2p_root_map = pairwise_map.mapConsumerToProducer(
      consumer->domain(), producer->domain());

  // Consumer leaf IterDomains; those derived from a root mapped to producer
  // are "unskippable" — they must line up with a producer axis.
  const auto consumer_domain = consumer->domain()->domain();

  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  // Everything reachable from a mapped consumer root up to the leaf domain.
  auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
      mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

  std::unordered_set<Val*> unskippable_consumer_ids(
      unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

  // Producer leaf IterDomains, walked in lockstep with consumer's below.
  const auto producer_domain = producer->domain()->domain();

  auto it_consumer = consumer_domain.begin();
  auto it_producer = producer_domain.begin();

  // Permissive ID mapping from a full best-effort PasC replay.
  auto disjoint_sets =
      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
          .getDisjointSets();

  // Walk both leaf domains left to right. Skippable consumer axes advance the
  // consumer cursor only; every unskippable consumer axis must map to the
  // producer axis at the current producer cursor.
  int mismatched_consumer_pos = 0;
  int mismatched_producer_pos = 0;
  while (it_consumer != consumer_domain.end()) {
    // Reached the requested consumer position — report the producer position.
    // Checked before advancing so consumer_pos == 0 works.
    if (consumer_pos == mismatched_consumer_pos) {
      return mismatched_producer_pos;
    }

    auto consumer_id = *it_consumer;
    if (unskippable_consumer_ids.count(consumer_id) == 0) {
      // Axis has no producer counterpart; skip it on the consumer side.
      ++it_consumer;
      ++mismatched_consumer_pos;
      continue;
    }

    if (it_producer == producer_domain.end()) {
      // Producer axes exhausted before reaching consumer_pos.
      return -1;
    }

    auto producer_id = *it_producer;
    if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
      ++mismatched_consumer_pos;
      ++mismatched_producer_pos;
      ++it_consumer;
      ++it_producer;
    } else {
      // Axes are not mapped in order — no match possible.
      return -1;
    }
  }
  // consumer_pos may equal the full leaf-domain size; check once more after
  // the loop.
  if (consumer_pos == mismatched_consumer_pos) {
    return mismatched_producer_pos;
  }
  return -1;
}
721 | |
// We want to ignore reductions in the producer in a CasP replay.
//
// Returns the consumer leaf position already matching producer_pos without
// doing any replay, or -1 if no such position exists.
int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_pos) {
  FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP");

  const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  id_map p2c_root_map = pairwise_map.mapProducerToConsumer(
      producer->domain(), consumer->domain());

  // Producer leaf IterDomains that are not reduction — reductions are
  // skippable since they have no counterpart in the consumer.
  const auto producer_domain = producer->domain()->domain();
  auto unskippable_producer_ids_vec =
      TensorDomain::noReductions(producer_domain);
  std::unordered_set<IterDomain*> unskippable_producer_ids(
      unskippable_producer_ids_vec.begin(), unskippable_producer_ids_vec.end());

  // Consumer leaf IterDomains; those derived from a root mapped from producer
  // are "unskippable" — they must line up with a producer axis.
  const auto consumer_domain = consumer->domain()->domain();

  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : p2c_root_map) {
    mapped_consumer_roots.emplace(entry.second);
  }

  // Everything reachable from a mapped consumer root up to the leaf domain.
  auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
      mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

  std::unordered_set<Val*> unskippable_consumer_ids(
      unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

  auto it_producer = producer_domain.begin();
  auto it_consumer = consumer_domain.begin();

  // Permissive ID mapping from a full best-effort PasC replay.
  auto disjoint_sets =
      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
          .getDisjointSets();

  // Walk both leaf domains left to right, skipping skippable axes on either
  // side; every remaining producer axis must map to the consumer axis at the
  // current consumer cursor.
  int mismatched_producer_pos = 0;
  int mismatched_consumer_pos = 0;
  while (it_producer != producer_domain.end()) {
    // Reached the requested producer position — report the consumer position.
    // Checked before advancing so producer_pos == 0 works.
    if (producer_pos == mismatched_producer_pos) {
      return mismatched_consumer_pos;
    }

    auto producer_id = *it_producer;
    if (unskippable_producer_ids.count(producer_id) == 0) {
      // Reduction axis; skip it on the producer side.
      ++it_producer;
      ++mismatched_producer_pos;
      continue;
    }

    if (it_consumer == consumer_domain.end()) {
      // Consumer axes exhausted before reaching producer_pos.
      return -1;
    }

    auto consumer_id = *it_consumer;
    if (unskippable_consumer_ids.count(consumer_id) == 0) {
      // Axis has no producer counterpart; skip it on the consumer side.
      ++it_consumer;
      ++mismatched_consumer_pos;
      continue;
    }

    if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
      ++mismatched_producer_pos;
      ++mismatched_consumer_pos;
      ++it_producer;
      ++it_consumer;
    } else {
      // Axes are not mapped in order — no match possible.
      return -1;
    }
  }
  // producer_pos may equal the full leaf-domain size; check once more after
  // the loop.
  if (producer_pos == mismatched_producer_pos) {
    return mismatched_consumer_pos;
  }
  return -1;
}
800 | |
801 | bool TransformReplay::fullSelfMatching( |
802 | const TensorView* replay, |
803 | const TensorView* target) { |
804 | auto replay_root = replay->getRootDomain(); |
805 | auto replay_dom = replay->domain()->domain(); |
806 | auto target_root = target->getRootDomain(); |
807 | auto target_dom = target->domain()->domain(); |
808 | std::unordered_map<IterDomain*, IterDomain*> target2replay_map; |
809 | if (replay_root.size() != target_root.size()) { |
810 | return false; |
811 | } |
812 | target2replay_map.reserve(replay_root.size()); |
813 | std::transform( |
814 | target_root.begin(), |
815 | target_root.end(), |
816 | replay_root.begin(), |
817 | std::inserter(target2replay_map, target2replay_map.begin()), |
818 | [](auto a, auto b) { return std::make_pair(a, b); }); |
819 | BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); |
820 | auto r = replay_.getReplay(); |
821 | for (int64_t i = 0; i < (int64_t)replay_dom.size(); i++) { |
822 | auto target_id = target_dom[i]; |
823 | auto replay_it = r.find(target_id); |
824 | if (replay_it == r.end() || replay_it->second != replay_dom[i]) { |
825 | return false; |
826 | } |
827 | } |
828 | return true; |
829 | } |
830 | |
831 | namespace { |
832 | |
833 | // Make sure if tv is set to new_td it doesn't violate set compute at and max |
834 | // produce at positions. |
835 | bool validateDomain(TensorView* tv, TensorDomain* new_td) { |
836 | auto first_mismatch = |
837 | BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); |
838 | return first_mismatch >= (int)tv->getMaxProducerPosition() && |
839 | first_mismatch >= (int)tv->getComputeAtPosition(); |
840 | } |
841 | |
842 | } // namespace |
843 | |
844 | void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) { |
845 | int pos = replayed_pos_.at(from); |
846 | // Note: [Using multiple TransformPropagators] |
847 | // There are cases that we use multiple TransformPropagators along different |
848 | // spanning trees with different references in the same fusion. Some of these |
849 | // spanning trees could overlap. In cases when there are overlapping nodes, |
850 | // TransformPropagator needs to respect the replay of others, because the |
851 | // current TransformPropagator might not contain the most amount of |
852 | // information on how to do the correct transformation. The logic below tells |
853 | // TransformPropagator to skip the replay when not necessary. |
854 | int new_pos = |
855 | TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); |
856 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
857 | if (debug) { |
858 | std::cout << "TransformPropagator::propagateC2P" << std::endl; |
859 | std::cout << " from: " << from << " @ " << pos << std::endl; |
860 | std::cout << " to: " << to << std::endl; |
861 | } |
862 | if (new_pos < 0) { |
863 | auto replay = TransformReplay::replayPasC(to, from, pos); |
864 | TORCH_INTERNAL_ASSERT( |
865 | validateDomain(to, replay.first), |
866 | "Tried to set the domain of " , |
867 | to, |
868 | " to " , |
869 | replay.first, |
870 | " but that would invalidate previously compute at position or max producer position." ); |
871 | to->setDomain(replay.first); |
872 | new_pos = replay.second; |
873 | if (debug) { |
874 | std::cout << " replayed: " << to << " @ " << new_pos << std::endl; |
875 | } |
876 | } else if (debug) { |
877 | std::cout << " replay skipped. result position: " << new_pos << std::endl; |
878 | } |
879 | replayed_pos_[to] = new_pos; |
880 | } |
881 | |
882 | void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) { |
883 | int pos = replayed_pos_.at(from); |
884 | // See note [Using multiple TransformPropagators] |
885 | int new_pos = |
886 | TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); |
887 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
888 | if (debug) { |
889 | std::cout << "TransformPropagator::propagateP2C" << std::endl; |
890 | std::cout << " from: " << from << " @ " << pos << std::endl; |
891 | std::cout << " to: " << to << std::endl; |
892 | } |
893 | if (new_pos < 0) { |
894 | auto replay = TransformReplay::replayCasP(to, from, pos); |
895 | TORCH_INTERNAL_ASSERT( |
896 | validateDomain(to, replay.first), |
897 | "Tried to set the domain of " , |
898 | to, |
899 | " to " , |
900 | replay.first, |
901 | " but that would invalidate previously compute at position or max producer position." ); |
902 | to->setDomain(replay.first); |
903 | new_pos = replay.second; |
904 | if (debug) { |
905 | std::cout << " replayed: " << to << " @ " << new_pos << std::endl; |
906 | } |
907 | } else if (debug) { |
908 | std::cout << " replay skipped. result position: " << new_pos << std::endl; |
909 | } |
910 | replayed_pos_[to] = new_pos; |
911 | } |
912 | |
913 | void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) { |
914 | int pos = replayed_pos_.at(from); |
915 | // See note [Using multiple TransformPropagators] |
916 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
917 | if (debug) { |
918 | std::cout << "TransformPropagator::propagateSibling" << std::endl; |
919 | std::cout << " from: " << from << " @ " << pos << std::endl; |
920 | std::cout << " to: " << to << std::endl; |
921 | } |
922 | if (!TransformReplay::fullSelfMatching(to, from)) { |
923 | auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); |
924 | TORCH_INTERNAL_ASSERT( |
925 | validateDomain(to, replay), |
926 | "Tried to set the domain of " , |
927 | to, |
928 | " to " , |
929 | replay, |
930 | " but that would invalidate previously compute at position or max producer position." ); |
931 | to->setDomain(replay); |
932 | if (debug) { |
933 | std::cout << " replayed: " << to << " @ " << pos << std::endl; |
934 | } |
935 | } else if (debug) { |
936 | std::cout << " replay skipped. result position: " << pos << std::endl; |
937 | } |
938 | replayed_pos_[to] = pos; |
939 | } |
940 | |
941 | TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { |
942 | if (pos < 0) { |
943 | pos += int64_t(from->nDims()) + 1; |
944 | } |
945 | TORCH_CHECK( |
946 | pos >= 0 && pos <= (int64_t)from->nDims(), |
947 | "TransformPropagator called on an pos outside valid range." ); |
948 | replayed_pos_[from] = pos; |
949 | } |
950 | |
951 | void MostInlinedTransformPropagator::propagateC2P( |
952 | TensorView* from, |
953 | TensorView* to) { |
954 | int pos = from->nDims(); |
955 | // See note [Using multiple TransformPropagators] |
956 | int new_pos = |
957 | TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); |
958 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
959 | if (debug) { |
960 | std::cout << "MostInlinedTransformPropagator::propagateC2P" << std::endl; |
961 | std::cout << " from: " << from << std::endl; |
962 | std::cout << " to: " << to << std::endl; |
963 | } |
964 | if (new_pos < 0) { |
965 | auto replay = TransformReplay::replayPasC(to, from, pos); |
966 | TORCH_INTERNAL_ASSERT( |
967 | validateDomain(to, replay.first), |
968 | "Tried to set the domain of " , |
969 | to, |
970 | " to " , |
971 | replay.first, |
972 | " but that would invalidate previously compute at position or max producer position." ); |
973 | to->setDomain(replay.first); |
974 | if (debug) { |
975 | std::cout << " replayed: " << to << std::endl; |
976 | } |
977 | } else if (debug) { |
978 | std::cout << " replay skipped" << std::endl; |
979 | } |
980 | } |
981 | |
982 | void MostInlinedTransformPropagator::propagateP2C( |
983 | TensorView* from, |
984 | TensorView* to) { |
985 | int pos = from->nDims(); |
986 | // See note [Using multiple TransformPropagators] |
987 | int new_pos = |
988 | TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); |
989 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
990 | if (debug) { |
991 | std::cout << "MostInlinedTransformPropagator::propagateP2C" << std::endl; |
992 | std::cout << " from: " << from << std::endl; |
993 | std::cout << " to: " << to << std::endl; |
994 | } |
995 | if (new_pos < 0) { |
996 | auto replay = TransformReplay::replayCasP(to, from, pos); |
997 | TORCH_INTERNAL_ASSERT( |
998 | validateDomain(to, replay.first), |
999 | "Tried to set the domain of " , |
1000 | to, |
1001 | " to " , |
1002 | replay.first, |
1003 | " but that would invalidate previously compute at position or max producer position." ); |
1004 | to->setDomain(replay.first); |
1005 | if (debug) { |
1006 | std::cout << " replayed: " << to << std::endl; |
1007 | } |
1008 | } else if (debug) { |
1009 | std::cout << " replay skipped" << std::endl; |
1010 | } |
1011 | } |
1012 | |
1013 | void MostInlinedTransformPropagator::propagateSibling( |
1014 | TensorView* from, |
1015 | TensorView* to) { |
1016 | // See note [Using multiple TransformPropagators] |
1017 | bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); |
1018 | if (debug) { |
1019 | std::cout << "MostInlinedTransformPropagator::propagateSibling" |
1020 | << std::endl; |
1021 | std::cout << " from: " << from << std::endl; |
1022 | std::cout << " to: " << to << std::endl; |
1023 | } |
1024 | if (!TransformReplay::fullSelfMatching(to, from)) { |
1025 | auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); |
1026 | TORCH_INTERNAL_ASSERT( |
1027 | validateDomain(to, replay), |
1028 | "Tried to set the domain of " , |
1029 | to, |
1030 | " to " , |
1031 | replay, |
1032 | " but that would invalidate previously compute at position or max producer position." ); |
1033 | to->setDomain(replay); |
1034 | if (debug) { |
1035 | std::cout << " replayed: " << to << std::endl; |
1036 | } |
1037 | } else if (debug) { |
1038 | std::cout << " replay skipped" << std::endl; |
1039 | } |
1040 | } |
1041 | |
1042 | } // namespace cuda |
1043 | } // namespace fuser |
1044 | } // namespace jit |
1045 | } // namespace torch |
1046 | |