1 | #include <transform_iter.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <ir_utils.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace fuser { |
9 | namespace cuda { |
10 | |
11 | // Transform dispatch |
12 | void ReplayTransformations::handle(Expr* e) { |
13 | switch (e->getExprType().value()) { |
14 | case (ExprType::Split): |
15 | case (ExprType::Merge): |
16 | case (ExprType::Swizzle2D): |
17 | break; |
18 | default: |
19 | TORCH_INTERNAL_ASSERT( |
20 | false, "Invalid expr type found in transform traversal." ); |
21 | } |
22 | IterVisitor::handle(e); |
23 | } |
24 | |
25 | // We're going to replay this split operation on the corresponding ID |
26 | void ReplayTransformations::handle(Split* s) { |
27 | // Grab our input to the split node |
28 | auto id_in = s->in(); |
29 | |
30 | // Make sure we have a corresponding entry in our map pointing to the ID we're |
31 | // going to replay the split on |
32 | auto it = id_map_.find(id_in); |
33 | if (it == id_map_.end()) { |
34 | if (error_on_failure_) { |
35 | TORCH_INTERNAL_ASSERT( |
36 | false, "Transform traversal failed, dependencies not met." ); |
37 | } else { |
38 | return; |
39 | } |
40 | } |
41 | |
42 | auto mapped = (*it).second; |
43 | // Make sure this ID is a leaf ID (meaning it has no uses we generated) |
44 | TORCH_INTERNAL_ASSERT( |
45 | leaf_ids_.find(mapped) != leaf_ids_.end(), |
46 | "Transform traversal failed, modified a node but it was not a leaf node." ); |
47 | |
48 | // Replay the split onto mapped |
49 | auto outs = IterDomain::split( |
50 | mapped, s->factor(), s->innerSplit(), s->startOffset(), s->stopOffset()); |
51 | // Remove mapped from the leaf IDs |
52 | leaf_ids_.erase(mapped); |
53 | |
54 | // Add outputs to leaf IDs |
55 | leaf_ids_[outs.first] = counter++; |
56 | leaf_ids_[outs.second] = counter++; |
57 | |
58 | // Update our ID map to include these outputs |
59 | id_map_[s->outer()] = outs.first; |
60 | id_map_[s->inner()] = outs.second; |
61 | } |
62 | |
// We're going to replay this merge operation on the corresponding IDs.
// Handles three cases: both inputs mapped (full replay), exactly one input
// mapped with the missing one being a broadcast (forward the mapped input to
// the output), or dependencies unmet (assert or silently skip).
void ReplayTransformations::handle(Merge* m) {
  // Grab the inputs to the merge node
  auto id_outer = m->outer();
  auto id_inner = m->inner();

  // Make sure we have a corresponding entry in our map pointing to the IDs
  // we're going to replay the merge on
  auto it_outer = id_map_.find(id_outer);
  auto it_inner = id_map_.find(id_inner);

  const bool outer_found = it_outer != id_map_.end();
  const bool outer_bcast = id_outer->isBroadcast();
  const bool inner_found = it_inner != id_map_.end();
  const bool inner_bcast = id_inner->isBroadcast();

  // If either are not found
  if (!outer_found || !inner_found) {
    // If both aren't found, it's a failure
    // If outer is found && inner is bcast it is not a failure
    // If inner is found && outer is bcast it is not a failure
    if (!(outer_found || inner_found) || (outer_found && !inner_bcast) ||
        (inner_found && !outer_bcast)) {
      if (error_on_failure_) {
        TORCH_INTERNAL_ASSERT(
            false, "Transform traversal failed, dependencies not met.");
      } else {
        return;
      }
    }
  }

  // If we merge a broadcast dim with a non-broadcast dim, just remap the output
  // to the non-broadcast dim.
  if (inner_found && !outer_found && outer_bcast) {
    id_map_[m->out()] = it_inner->second;
    return;
  }
  if (outer_found && !inner_found && inner_bcast) {
    id_map_[m->out()] = it_outer->second;
    return;
  }

  // At this point both inputs must be mapped: the one-sided cases above
  // either returned or asserted, so dereferencing both iterators is safe.
  // Grab the IDs we're going to replay this merge on
  const auto id_outer_mapped = it_outer->second;
  const auto id_inner_mapped = it_inner->second;

  // Make sure these IDs are leaf IDs (meaning they have no uses we generated)
  TORCH_INTERNAL_ASSERT(
      leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
          leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
      "Transform traversal failed, tried to replay with ",
      id_outer_mapped,
      " and ",
      id_inner_mapped,
      " however one or both are not leaf nodes.");

  // Replay the merge operation
  auto out = IterDomain::merge(id_outer_mapped, id_inner_mapped);

  // Remove inputs from the leaf IDs
  leaf_ids_.erase(id_outer_mapped);
  leaf_ids_.erase(id_inner_mapped);

  // Add the output to the leaf IDs
  leaf_ids_[out] = counter++;

  // Update our ID map with the replayed output
  id_map_[m->out()] = out;
}
133 | |
// Replay (or skip) a 2D swizzle on the mapped x/y IterDomains. When
// replay_swizzle_ is false the swizzle outputs are mapped straight to the
// unswizzled inputs, effectively treating the op as an identity for mapping
// purposes.
void ReplayTransformations::handle(Swizzle2D* swizzle_2d) {
  // Grab our input to the split node
  auto id_in_x = swizzle_2d->inX();
  auto id_in_y = swizzle_2d->inY();

  // Make sure we have a corresponding entry in our map pointing to the ID we're
  // going to replay the swizzle on
  auto it_x = id_map_.find(id_in_x);
  auto it_y = id_map_.find(id_in_y);

  if (it_x == id_map_.end() || it_y == id_map_.end()) {
    if (error_on_failure_) {
      TORCH_INTERNAL_ASSERT(
          false, "Transform traversal failed, dependencies not met.");
    } else {
      return;
    }
  }

  auto mapped_x = (*it_x).second;
  auto mapped_y = (*it_y).second;

  // Make sure this ID is a leaf ID (meaning it has no uses we generated)
  TORCH_INTERNAL_ASSERT(
      leaf_ids_.find(mapped_x) != leaf_ids_.end() &&
          leaf_ids_.find(mapped_y) != leaf_ids_.end(),
      "Transform traversal failed, modified a node but it was not a leaf node.");

  // Default: forward the inputs unchanged (the skip-swizzle case).
  auto outs = std::make_pair(mapped_x, mapped_y);

  if (replay_swizzle_) {
    // Replay the swizzle onto mapped
    outs = IterDomain::swizzle(swizzle_2d->swizzleType(), mapped_x, mapped_y);

    // Remove mapped from the leaf IDs
    leaf_ids_.erase(mapped_x);
    leaf_ids_.erase(mapped_y);
  }

  // Add outputs to leaf IDs. In the skip case this re-inserts the inputs,
  // which only refreshes their ordering counter.
  leaf_ids_[outs.first] = counter++;
  leaf_ids_[outs.second] = counter++;

  // Update our ID map to include these outputs
  id_map_[swizzle_2d->outX()] = outs.first;
  id_map_[swizzle_2d->outY()] = outs.second;
}
181 | |
182 | ReplayTransformations::ReplayTransformations( |
183 | const std::vector<IterDomain*>& _target_domain, |
184 | std::unordered_map<IterDomain*, IterDomain*> _id_map, |
185 | bool _error_on_failure, |
186 | bool replay_swizzle) |
187 | : target_domain_(_target_domain), |
188 | id_map_(std::move(_id_map)), |
189 | error_on_failure_(_error_on_failure), |
190 | replay_swizzle_(replay_swizzle) { |
191 | // Make sure id_map has all the inputs needed to replay target_domain |
192 | auto inps = IterVisitor::getInputsTo( |
193 | std::vector<Val*>(target_domain_.begin(), target_domain_.end())); |
194 | |
195 | if (error_on_failure_) |
196 | std::for_each(inps.begin(), inps.end(), [this](Val* val) { |
197 | TORCH_INTERNAL_ASSERT( |
198 | val->getValType().value() == ValType::IterDomain, |
199 | "Expected IterDomain only for Replay Transformations, but found " , |
200 | val); |
201 | IterDomain* id = val->as<IterDomain>(); |
202 | TORCH_INTERNAL_ASSERT( |
203 | id_map_.find(id) != id_map_.end(), |
204 | "Could not find required input: " , |
205 | id, |
206 | " in provided id_map." ); |
207 | }); |
208 | |
209 | // Set all the leaf nodes for tracking, all ids start as a leaf and will be |
210 | // updated based on the transformations |
211 | for (auto entry : id_map_) |
212 | leaf_ids_[entry.second] = counter++; |
213 | } |
214 | |
// Replays outputs that were generated from ids.first on ids.second.
// Traverses target_domain_'s expression history (dispatching to the handle()
// overloads above), validates every target output got a replayed leaf, and
// finally fills leaf_vec_ in deterministic creation order.
void ReplayTransformations::runReplay() {
  TORCH_INTERNAL_ASSERT(
      !ran_replay,
      "Cannot run replay twice without creating a new Replay Class.");
  ran_replay = true;
  // Nothing to replay if there's no target or no starting mapping.
  if (target_domain_.empty() || id_map_.empty())
    return;

  // Switch outDomain to a vector to start the traversal
  std::vector<Val*> traversal_vals(
      target_domain_.begin(), target_domain_.end());
  traverseTo(traversal_vals[0]->fusion(), traversal_vals);

  if (error_on_failure_)
    TORCH_INTERNAL_ASSERT(
        leaf_ids_.size() >= target_domain_.size(),
        "Transform traversal failed, did not find enough output IterDomains.");

  // Validate replay: every target output should be mapped, and its mapped ID
  // should still be a leaf (i.e. nothing consumed it during the replay).
  for (auto out : target_domain_) {
    auto it_replayed = id_map_.find(out);
    if (it_replayed == id_map_.end()) {
      if (error_on_failure_) {
        TORCH_INTERNAL_ASSERT(
            false,
            "Transform traversal failed, could not find expected output.");
      }
      continue;
    }

    auto id_replayed = (*it_replayed).second;
    auto it_leaf = leaf_ids_.find(id_replayed);
    TORCH_INTERNAL_ASSERT(
        it_leaf != leaf_ids_.end(),
        "Transform Traversal failed, expected a replayed dim for ",
        out,
        " but one was not created.");
  }

  // Populate leaf_vec_ in a deterministic manner. This is deterministic
  // because size_t in leaf_ids is filled based on operation order.
  std::set<std::pair<IterDomain*, size_t>, id_int_lt> ordered_set;
  for (auto entry : leaf_ids_)
    ordered_set.emplace(entry);

  leaf_vec_.clear();
  leaf_vec_.resize(ordered_set.size());
  std::transform(
      ordered_set.begin(),
      ordered_set.end(),
      leaf_vec_.begin(),
      [](std::pair<IterDomain*, size_t> entry) { return entry.first; });
}
269 | |
// Best-effort mapping from target_domain's IterDomains to replay_domain's.
// Walks target_domain's expression history, and for each target expression
// tries to find the matching replay expression (same inputs via the map, same
// type, same parameters). Matches propagate the target->replay mapping to the
// expressions' outputs; mismatches are skipped, except where rfactor
// consistency requires an exact match, in which case we assert.
BestEffortReplay::BestEffortReplay(
    const std::vector<IterDomain*>& replay_domain,
    const std::vector<IterDomain*>& target_domain,
    std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
    std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map,
    std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map,
    bool skip_swizzle)
    : target2replay_id_map_(std::move(target2replay_map)),
      replay_forward_id_map_(std::move(replay_forward_id_map)),
      target_forward_id_map_(std::move(target_forward_id_map)),
      skip_swizzle_(skip_swizzle) {
  // All initially-mapped replay IDs start out as leaves; counter gives a
  // deterministic ordering.
  for (auto entry : target2replay_id_map_) {
    leaf_ids_[entry.second] = counter++;
  }

  // Grab expr history of iter domains in target_domain
  std::vector<Expr*> target_exprs = StmtSort::getExprs(
      FusionGuard::getCurFusion(),
      std::vector<Val*>(target_domain.begin(), target_domain.end()));

  // If we check how an IterDomain was generated, it should only use an
  // IterDomain in an expression once. We pull a map from the input
  // IterDomains to the expression consuming them to generate the
  // replay_domain domain. This will be used to propagate the target_domain to
  // replay_domain map.

  // Map replay domain's IterDomains to the Exprs they're used in
  std::vector<Expr*> replay_exprs = StmtSort::getExprs(
      FusionGuard::getCurFusion(),
      std::vector<Val*>(replay_domain.begin(), replay_domain.end()));

  // Track which id's in replay have to be replayed to guarantee rfactor
  // transformations. The iteration domains in the rfactor axes don't have
  // to be used in a matching expression in target, so we want to exclude those.
  // Only the iteration domains [root_domains, rfactor) domains have to be used
  // in matching transformation to guarantee rfactor domain is consistent.
  // However, if any rfactor id was used to produce the rfactor domain, we need
  // transformations on them to match the target exactly.
  std::unordered_set<IterDomain*> replay_rfactor_ids;

  // Track which expressions iteration domains are used, they should only be
  // used in one expression.
  std::unordered_map<IterDomain*, Expr*> replay_id2expr_map;
  for (auto replay_expr : replay_exprs) {
    for (auto id : ir_utils::filterByType<IterDomain>(replay_expr->inputs())) {
      TORCH_INTERNAL_ASSERT(
          replay_id2expr_map.find(id) == replay_id2expr_map.end(),
          "Error trying to map rfactor root domain during replay.",
          " An IterDomain was found to be used in more than one expression.");

      replay_id2expr_map[id] = replay_expr;
    }

    // Only want to forward rfactor in map
    auto out_ids = ir_utils::filterByType<IterDomain>(replay_expr->outputs());
    if (std::any_of(out_ids.begin(), out_ids.end(), [](IterDomain* id) {
          return id->isRFactorProduct();
        })) {
      auto inp_ids = ir_utils::filterByType<IterDomain>(replay_expr->inputs());
      replay_rfactor_ids.insert(inp_ids.begin(), inp_ids.end());
    }
  }

  // Same input->consuming-expr map for the target side; insert().second
  // doubles as the uniqueness check.
  std::unordered_map<IterDomain*, Expr*> target_id2expr_map;
  for (auto target_expr : target_exprs) {
    for (auto id : ir_utils::filterByType<IterDomain>(target_expr->inputs())) {
      TORCH_INTERNAL_ASSERT(
          target_id2expr_map.insert({id, target_expr}).second,
          "BestEffortReplay : Unexpected multi-use of id",
          id);
    }
  }

  if (skip_swizzle_) {
    // Progress through all swizzle ops if we are skipping
    // swizzles on the mapping.
    skipSwizzles(target_id2expr_map, replay_id2expr_map);
  }

  // Shared error message for the rfactor-consistency asserts below.
  std::string err_str(
      "Error during replay, a transformation was called that conflicts with an rfactor call.");

  bool any_target_expr_contains_broadcast_id = false;

  // Iterate through target IterDomains' history and compare with what we
  // recorded from replay_domain
  for (auto target_expr : target_exprs) {
    auto target_inps_filtered =
        ir_utils::filterByType<IterDomain>(target_expr->inputs());

    // If any input argument in target expression is in the forward map then
    // forward the mapped IterDomains in replay and continue to the next
    // expression as target_expr cannot match a replay_expr
    if (std::any_of(
            target_inps_filtered.begin(),
            target_inps_filtered.end(),
            [&](IterDomain* target_inp) {
              return this->inTargetForwardMap(target_inp);
            })) {
      for (auto target_inp : target_inps_filtered) {
        if (inTargetForwardMap(target_inp)) {
          auto target2replay_it = target2replay_id_map_.find(target_inp);
          if (target2replay_it != target2replay_id_map_.end()) {
            // Replace target_inp entry in target2replay_id_map_ with forwarded
            // id
            target2replay_id_map_[getTargetForwardedId(target_inp)] =
                target2replay_it->second;
            target2replay_id_map_.erase(target_inp);
          }
        }
      }
      // Continue to next target_expr
      continue;
    }

    std::vector<IterDomain*> target_id_inps(
        target_inps_filtered.begin(), target_inps_filtered.end());

    // Track broadcast inputs; used below to relax rfactor-consistency checks
    // for broadcast IDs introduced after a view (see the comment further
    // down).
    bool target_expr_contains_broadcast_id = std::any_of(
        target_inps_filtered.begin(),
        target_inps_filtered.end(),
        [](IterDomain* id) { return id->isBroadcast(); });
    any_target_expr_contains_broadcast_id =
        any_target_expr_contains_broadcast_id ||
        target_expr_contains_broadcast_id;

    std::vector<IterDomain*> replay_inps =
        std::vector<IterDomain*>(target_id_inps.size(), nullptr);

    bool missing_replay_input = false;

    // Map target_expr inputs to replay domain directly
    for (const auto t_i : c10::irange(target_id_inps.size())) {
      // There might not be a mapping, that could be okay (depends on rfactor
      // checking).
      auto it = target2replay_id_map_.find(target_id_inps[t_i]);
      if (it != target2replay_id_map_.end()) {
        replay_inps[t_i] = getReplayForwardedId(it->second);
      } else {
        missing_replay_input = true;
      }
    }

    // Check if any of the associated replay id's are part of an rfactor domain
    bool replay_has_rfactor_inp = std::any_of(
        replay_inps.begin(),
        replay_inps.end(),
        [&replay_rfactor_ids](IterDomain* id) {
          return id == nullptr ? false
                               : id->isRFactorProduct() &&
                  (replay_rfactor_ids.find(id) != replay_rfactor_ids.end());
        });

    // If some replay id inputs are part of rfactor, make sure all target
    // expression inputs map to a replay input
    if (replay_has_rfactor_inp) {
      bool no_missing_exprs = std::none_of(
          replay_inps.begin(),
          replay_inps.end(),
          [&replay_id2expr_map](IterDomain* id) {
            if (id == nullptr) {
              return true;
            } else {
              return replay_id2expr_map.find(id) == replay_id2expr_map.end();
            }
          });
      // View operation creates a TensorView with rfactor. After view, broadcast
      // operation adds iterDomains for any size-1 dimensions. Therefore, the
      // target domain (broadcast) may contain broadcast ids that are not
      // present in the replay domain (view). In this case, we skip any target
      // expressions that contain broadcast ids.
      TORCH_INTERNAL_ASSERT(
          no_missing_exprs || any_target_expr_contains_broadcast_id, err_str);
    }

    // If any inputs are missing, continue as this expr doesn't match.
    if (missing_replay_input) {
      TORCH_INTERNAL_ASSERT(
          !replay_has_rfactor_inp || any_target_expr_contains_broadcast_id,
          err_str);
      continue;
    }

    // Find which replay_expr maps to the target_expr
    Expr* replay_expr = nullptr;
    // Check if all inputs have the same expression
    bool mismatched_replay_exprs = false;
    for (auto replay_inp : replay_inps) {
      auto it = replay_id2expr_map.find(replay_inp);
      if (it != replay_id2expr_map.end()) {
        if (replay_expr == nullptr) {
          replay_expr = it->second;
        } else {
          mismatched_replay_exprs =
              mismatched_replay_exprs || replay_expr != it->second;
        }
      } else {
        // If no expr is mapped then set mismatched epxrs to go to continue to
        // the next target expr
        mismatched_replay_exprs = true;
      }
    }

    // If expressions of mapped inputs don't match, then continue to next target
    // expr
    if (mismatched_replay_exprs || replay_expr == nullptr) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // Require the replay expr's inputs to match the mapped inputs exactly,
    // in both count and order.
    bool mismatched_inputs = replay_inps.size() != replay_expr->inputs().size();
    for (size_t i = 0; i < replay_inps.size() && !mismatched_inputs; i++) {
      mismatched_inputs =
          mismatched_inputs || replay_expr->inputs()[i] != replay_inps[i];
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatched input, continue
    if (mismatched_inputs) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's outputs, continue
    if (target_expr->outputs().size() != replay_expr->outputs().size()) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's expression type, continue
    if (replay_expr->getExprType().value() !=
        target_expr->getExprType().value()) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's split factor (if a split
    // expr), continue
    if (replay_expr->getExprType().value() == ExprType::Split) {
      auto r_split = replay_expr->as<Split>();
      auto t_split = target_expr->as<Split>();
      if (!r_split->factor()->sameAs(t_split->factor()) ||
          r_split->innerSplit() != t_split->innerSplit() ||
          !r_split->startOffset()->sameAs(t_split->startOffset()) ||
          !r_split->stopOffset()->sameAs(t_split->stopOffset())) {
        TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
        continue;
      }
    }

    // Need to match swizzle type and parameters if
    // not skipping swizzles in this mapping pass.
    if (!skip_swizzle_ && replay_expr->etype() == ExprType::Swizzle2D) {
      auto r_swizzle_2d = replay_expr->as<Swizzle2D>();
      auto t_swizzle_2d = target_expr->as<Swizzle2D>();
      if (!(r_swizzle_2d->swizzleType() == t_swizzle_2d->swizzleType())) {
        TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
        continue;
      }
    }

    // Matched: consume the replay inputs.
    // Take replay expr inputs out of map:
    for (const auto t_i : c10::irange(target_id_inps.size())) {
      auto t_inp = target_id_inps[t_i];
      auto r_orig_inp = target2replay_id_map_.at(t_inp);
      auto r_maybe_forwarded_inp = replay_inps[t_i];

      // Remove original target2replay_it->second if it's in leaf_ids
      if (leaf_ids_.find(r_orig_inp) != leaf_ids_.end()) {
        leaf_ids_.erase(r_orig_inp);
      }

      // Check if we used a forwarded id, if so add forwarded id's to tracking.
      if (r_orig_inp != r_maybe_forwarded_inp) {
        forwarded_ids_.emplace_back(r_orig_inp);
      }
    }

    // Add outputs to map.
    for (const auto i : c10::irange(target_expr->outputs().size())) {
      auto t_out = target_expr->output(i);
      auto r_out = replay_expr->output(i);
      if (t_out->getValType() == ValType::IterDomain &&
          r_out->getValType() == ValType::IterDomain) {
        target2replay_id_map_[t_out->as<IterDomain>()] =
            r_out->as<IterDomain>();
        leaf_ids_[r_out->as<IterDomain>()] = counter++;
      }
    }

    if (skip_swizzle_) {
      // Progress through all swizzle ops if we are skipping
      // swizzles on the mapping.
      skipSwizzles(target_id2expr_map, replay_id2expr_map);
    }
  }
}
570 | |
// Find the first position i where td1[i] is not the same as td2[i].
// "Same" means the DAG to generate td1[i] and td2[i] are
// equivalent.
int BestEffortReplay::findFirstMismatchedID(
    const TensorDomain* td1,
    const TensorDomain* td2) {
  std::unordered_map<IterDomain*, IterDomain*> id_map;
  auto rd1 = td1->getRootDomain();
  auto rd2 = td2->getRootDomain();
  // rd2_set tracks not-yet-claimed td2 roots so each is matched at most once.
  std::unordered_set<IterDomain*> rd2_set(
      td2->getRootDomain().begin(), td2->getRootDomain().end());

  // Find matching root IterDomains, we could make this O(nlog(n)) if we could
  // sort IterDomains.
  for (auto rd1i : rd1) {
    for (auto rd2i : rd2) {
      if (rd1i->sameAs(rd2i) && rd2_set.find(rd2i) != rd2_set.end()) {
        id_map[rd1i] = rd2i;
        rd2_set.erase(rd2i);
        break;
      }
    }
  }

  // Replay td1's history onto td2 through the matched roots, then walk the
  // leaf domains in lockstep looking for the first disagreement.
  BestEffortReplay ber(td2->domain(), td1->domain(), id_map);
  // NOTE(review): the loop bound is the max of both ranks, so td1->axis(i)
  // (or td2->axis(i)) may be queried past that domain's own nDims when the
  // ranks differ — presumably axis() handles/asserts out-of-range internally;
  // confirm.
  for (const auto i :
       c10::irange(std::max(td1->domain().size(), td2->domain().size()))) {
    if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) {
      return i;
    }
    // Order is important.
    auto td2_axis = ber.getReplay().at(td1->axis(i));
    if (td2->axis(i) != td2_axis) {
      return i;
    }
  }
  // All compared axes matched; the first mismatch is where the shorter
  // domain ends.
  return std::min(td1->nDims(), td2->nDims());
}
609 | |
610 | namespace { |
611 | |
612 | // Maps that track information relevant to best effort replay about broadcast |
613 | // axes in consumer that are not in producer |
614 | // |
615 | // For example if we have consumer: T0[i0, b1, b2, i3] and producer: |
616 | // T1[i0, i3] |
617 | // |
618 | // If consumer transformations are: |
619 | // -> T[i0, b1o, b1i, b2o, b2i, i3] |
620 | // -> T[i0*b1i, b1o, b2o, b2i, i3] |
621 | // -> T[i0*b1i*b2o, b1o, b2i, i3] |
622 | // -> T[i0*b1i*b2o*i3, b1o, b2i] |
623 | // |
624 | // forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o |
625 | // compliment_map would have the entry i0->b1i and i0*b1i->b2o |
626 | // |
// The first is to fast forward transformations in consumer involving broadcast
// axes not in producer. The compliment map is used later to compute what leaf
// nodes we may have after the forwarding process is finished. Leaf nodes are
// only important for replayCasP, so look there to see how this is done. Forward
// map is used for replayCasP and replayPasC.
struct ConsumerForwardingInfo {
 public:
  // Map IterDomain* axes that can safely be forwarded to their output.
  std::unordered_map<IterDomain*, IterDomain*> forwarding_map;

  // Given a forward id map id_input -> id_forwarded
  // Track the other inputs in the expr that id_input is an input to. These will
  // be used to adjust the replay's leaf tracking. Don't need to track one to
  // many as currently transformations on IterDomains can only have maximum 2
  // inputs, but maybe in the future we'll have more.
  std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;

  // Builds both maps by scanning consumer's transformation history for merges
  // that combine a producer-mapped ID with IDs derived purely from
  // consumer-only broadcast roots.
  ConsumerForwardingInfo(
      const TensorView* producer,
      const TensorView* consumer) {
    // Collect which root axes are in consumer that are not in producer because
    // of broadcasting
    std::unordered_set<IterDomain*> consumer_bcast_roots_not_in_producer;

    const auto c2p_root_map =
        PairwiseRootDomainMap(producer, consumer)
            .mapConsumerToProducer(consumer->domain(), producer->domain());

    for (auto consumer_root_id : consumer->getRootDomain()) {
      if (consumer_root_id->isBroadcast()) {
        if (c2p_root_map.find(consumer_root_id) == c2p_root_map.end()) {
          consumer_bcast_roots_not_in_producer.emplace(consumer_root_id);
        }
      }
    }

    // We have root axes in consumer that don't exist in producer, now forward
    // those to include all id's in consumer comprised of only axes not in
    // producer.
    auto consumer_bcast_ids_not_in_producer =
        consumer_bcast_roots_not_in_producer;

    std::vector<Expr*> consumer_history = StmtSort::getExprs(
        FusionGuard::getCurFusion(),
        std::vector<Val*>(
            consumer->domain()->domain().begin(),
            consumer->domain()->domain().end()));

    // Predicate: is this ID derived only from consumer-only broadcast roots?
    auto isIdOnlyInConsumer =
        [&consumer_bcast_ids_not_in_producer](IterDomain* input_id) {
          return consumer_bcast_ids_not_in_producer.find(input_id) !=
              consumer_bcast_ids_not_in_producer.end();
        };

    for (auto expr : consumer_history) {
      auto input_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
      // If expr inputs are all in consumer_bcast_ids_not_in_producer, than so
      // are all outputs
      if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
        // add all outputs to not being in producer
        for (auto output_ids :
             ir_utils::filterByType<IterDomain>(expr->outputs())) {
          consumer_bcast_ids_not_in_producer.emplace(output_ids);
        }
      } else if (
          expr->isA<Merge>() &&
          std::any_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
        auto merge_expr = expr->as<Merge>();
        // If
        // - one of the inputs is made of id's in consumer that don't map to
        // producer (bcast axes),
        // - && the other input maps to an id in both consumer and producer
        // - && this is a merge
        // for the sake of BestEffortReplay we can forward the input mapping
        // to both consumer and producer to the output of the expression
        std::vector<IterDomain*> forwarded_ids;
        std::vector<IterDomain*> compliment_ids;

        for (auto input_id : input_ids) {
          if (!isIdOnlyInConsumer(input_id)) {
            forwarded_ids.emplace_back(input_id);
            forwarding_map.emplace(std::make_pair(input_id, merge_expr->out()));
          } else {
            compliment_ids.push_back(input_id);
          }
        }

        // Set up compliment map
        for (auto forwarded_id : forwarded_ids) {
          compliment_map.emplace(std::make_pair(forwarded_id, compliment_ids));
        }
      }
    }
  }
};
722 | |
723 | // Maps that track information relevant to best effort replay about |
724 | // trivial-reduction axes in producer |
725 | // |
726 | // For example if we have producer: T0[i0, r1, r2, i3] and consumer: |
727 | // T1[i0, i3] |
728 | // |
729 | // If producer transformations are: |
730 | // -> T[i0, r1, r2, i3] |
731 | // -> T[i0*r1, r2, i3] |
732 | // -> T[i0*r1*r2, i3] |
733 | // |
// forwarding_map would forward i0->i0*r1 and i0*r1->i0*r1*r2
// compliment_map would have the entries i0->r1 and i0*r1->r2
736 | // |
737 | // These two maps are used similarly as ConsumerForwardingInfo. See |
738 | // its comments as well. |
struct ProducerForwardingInfo {
 public:
  // Map IterDomain* axes that can safely be forwarded to their output.
  std::unordered_map<IterDomain*, IterDomain*> forwarding_map;

  // Given a forward id map id_input -> id_forwarded
  // Track the other inputs in the expr that id_input is an input to. These will
  // be used to adjust the replay's leaf tracking. Don't need to track one to
  // many as currently transformations on IterDomains can only have maximum 2
  // inputs, but maybe in the future we'll have more.
  std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;

  // Builds both maps by scanning producer's transformation history for merges
  // of a trivial-reduction ID with a non-reduction ID; the non-reduction side
  // is forwarded to the merge output.
  ProducerForwardingInfo(const TensorView* producer) {
    std::vector<Expr*> producer_history = StmtSort::getExprs(
        FusionGuard::getCurFusion(),
        std::vector<Val*>(
            producer->domain()->domain().begin(),
            producer->domain()->domain().end()));

    for (auto merge : ir_utils::filterByType<Merge>(producer_history)) {
      auto inner = merge->inner();
      auto outer = merge->outer();
      // Exactly one side must be a trivial reduction; the other side (which
      // must not itself be a reduction) gets forwarded.
      if ((inner->isTrivialReduction() && !outer->isReduction()) ||
          (outer->isTrivialReduction() && !inner->isReduction())) {
        auto compliment_id = inner->isTrivialReduction() ? inner : outer;
        auto forwarded_id = inner->isTrivialReduction() ? outer : inner;
        forwarding_map.emplace(std::make_pair(forwarded_id, merge->out()));
        compliment_map.emplace(std::make_pair(
            forwarded_id, std::vector<IterDomain*>{compliment_id}));
      }
    }
  }
};
772 | |
773 | // Trace chain of swizzles until reaching |
774 | // an IterDomain that's either a leaf or |
775 | // not a producer of any swizzle. |
776 | IterDomain* getSwizzleFinalOutput( |
777 | IterDomain* id, |
778 | const std::unordered_map<IterDomain*, Expr*>& id2expr) { |
779 | bool is_swizzle_input = true; |
780 | |
781 | // Note: currently not supporting swizzling consumer of another |
782 | // swizzle id, so this should terminate in 1 iter, but eventually |
783 | // will try to support stacked swizzles so keeping this pass |
784 | // generic. |
785 | while (is_swizzle_input) { |
786 | auto expr_it = id2expr.find(id); |
787 | |
788 | // This means id is a leaf that doesn't |
789 | // have any consumers. Stop iteration in this case. |
790 | if (expr_it == id2expr.end()) { |
791 | is_swizzle_input = false; |
792 | break; |
793 | } |
794 | |
795 | if (expr_it->second->etype() == ExprType::Swizzle2D) { |
796 | // In the case of 2D swizzle ops, just forward |
797 | // inX to outX and inY to outY. |
798 | auto expr = expr_it->second->as<Swizzle2D>(); |
799 | if (id == expr->inX()) { |
800 | id = expr->outX(); |
801 | } else { |
802 | TORCH_INTERNAL_ASSERT( |
803 | id == expr->inY(), |
804 | "unknown input to swizzle op" , |
805 | id->toString(), |
806 | expr->toString()); |
807 | id = expr->outY(); |
808 | } |
809 | } else { |
810 | // Probably unreachable but if the expression |
811 | // is unknown type assume it is not a swizzle op. |
812 | is_swizzle_input = false; |
813 | } |
814 | } |
815 | |
816 | return id; |
817 | } |
818 | |
819 | bool isSwizzleInput( |
820 | IterDomain* input_id, |
821 | const std::unordered_map<IterDomain*, Expr*>& id2expr) { |
822 | auto user_expr_it = id2expr.find(input_id); |
823 | |
824 | if (user_expr_it == id2expr.end()) { |
825 | return false; |
826 | } |
827 | |
828 | return user_expr_it->second->etype() == ExprType::Swizzle2D; |
829 | } |
830 | |
831 | } // namespace |
832 | |
833 | void BestEffortReplay::addComplimentLeafIDs( |
834 | const std::unordered_map<IterDomain*, IterDomain*>& forwarding_map, |
835 | const std::unordered_map<IterDomain*, std::vector<IterDomain*>>& |
836 | compliment_map) { |
837 | // ID's could go through more than one forward iteration in the map before it |
838 | // terminates. Grab every id between the forwarded id, and what it was |
839 | // forwarded to |
840 | std::function<void(IterDomain*, std::vector<IterDomain*>&)> |
841 | collectForwardedIds = |
842 | [&forwarding_map, &collectForwardedIds]( |
843 | IterDomain* forward_id, |
844 | std::vector<IterDomain*>& forwarded_ids) -> void { |
845 | if (forwarding_map.find(forward_id) != forwarding_map.end()) { |
846 | forwarded_ids.emplace_back(forward_id); |
847 | collectForwardedIds(forwarding_map.at(forward_id), forwarded_ids); |
848 | } |
849 | }; |
850 | |
851 | std::vector<IterDomain*> expanded_forwarded_ids; |
852 | for (auto forwarded_id : forwarded_ids_) { |
853 | collectForwardedIds(forwarded_id, expanded_forwarded_ids); |
854 | } |
855 | |
856 | // Grab all compliments of forwarded ids. |
857 | std::vector<IterDomain*> compliments; |
858 | for (auto forwarded_id : expanded_forwarded_ids) { |
859 | auto compliment_map_it = compliment_map.find(forwarded_id); |
860 | TORCH_INTERNAL_ASSERT( |
861 | compliment_map_it != compliment_map.end(), |
862 | "Issue tracking forwarded broadcast merges in best effort replay." ); |
863 | compliments.insert( |
864 | compliments.end(), |
865 | compliment_map_it->second.begin(), |
866 | compliment_map_it->second.end()); |
867 | } |
868 | |
869 | // Grab all exprs used to make the forwarded compliments |
870 | auto compliment_exprs = StmtSort::getExprs( |
871 | FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); |
872 | |
873 | // Figure out if there are any leaves in compliment_exprs that aren't |
874 | // the forwarded id |
875 | std::unordered_map<IterDomain*, size_t> leaf_ids; |
876 | |
877 | for (auto expr : compliment_exprs) { |
878 | for (auto inp : ir_utils::filterByType<IterDomain>(expr->inputs())) { |
879 | leaf_ids.erase(inp); |
880 | } |
881 | for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) { |
882 | // If we used the comliment for forwarded don't add to leaf nodes. |
883 | if (std::find(compliments.begin(), compliments.end(), out) == |
884 | compliments.end()) { |
885 | leaf_ids.emplace(std::make_pair(out, counter++)); |
886 | } |
887 | } |
888 | } |
889 | |
890 | leaf_ids_.insert(leaf_ids.begin(), leaf_ids.end()); |
891 | } |
892 | |
893 | BestEffortReplay BestEffortReplay::replayCasP( |
894 | const TensorView* consumer, |
895 | const TensorView* producer, |
896 | int producer_compute_at_axis, |
897 | const RootDomainMap& root_map) { |
898 | if (producer_compute_at_axis < 0) |
899 | producer_compute_at_axis += (int)producer->nDims() + 1; |
900 | |
901 | TORCH_INTERNAL_ASSERT( |
902 | producer_compute_at_axis >= 0 && |
903 | (unsigned int)producer_compute_at_axis <= producer->nDims(), |
904 | "Invalid axis provided to BestEffortReplay::replayCasP." ); |
905 | |
906 | // producer ids we need to match in consumer |
907 | std::vector<IterDomain*> producer_CA_ids( |
908 | producer->domain()->domain().begin(), |
909 | producer->domain()->domain().begin() + producer_compute_at_axis); |
910 | producer_CA_ids = TensorDomain::noReductions(producer_CA_ids); |
911 | |
912 | // If producer has an rfactor root, that's what will match the consumer |
913 | std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain(); |
914 | |
915 | // Figure out all inputs required to generate the compute_at dimensions. We |
916 | // need all deps because inputs on producer may be in getRootDomain, but we |
917 | // may need in rFactorDomain |
918 | auto all_CA_id_deps = DependencyCheck::getAllValsBetween( |
919 | {producer_root.begin(), producer_root.end()}, |
920 | {producer_CA_ids.begin(), producer_CA_ids.end()}); |
921 | |
922 | // Figure out minimal set of root IDs needed to produce producer_CA_ids: |
923 | std::unordered_set<IterDomain*> producer_CA_root_ids; |
924 | for (IterDomain* id : producer_root) { |
925 | if (std::find(all_CA_id_deps.begin(), all_CA_id_deps.end(), id) != |
926 | all_CA_id_deps.end()) { |
927 | producer_CA_root_ids.emplace(id); |
928 | } |
929 | } |
930 | |
931 | const auto p2c_root_map = root_map.mapProducerToConsumer( |
932 | producer->domain(), consumer->domain(), producer_CA_root_ids); |
933 | |
934 | // See FusionAdvancedComputeAt7 for an example of the forwarding logic |
935 | ConsumerForwardingInfo consumer_forwarding_info(producer, consumer); |
936 | |
937 | ProducerForwardingInfo producer_forwarding_info(producer); |
938 | |
939 | auto consumer_replay = BestEffortReplay( |
940 | consumer->domain()->domain(), |
941 | producer_CA_ids, |
942 | p2c_root_map, |
943 | consumer_forwarding_info.forwarding_map, |
944 | producer_forwarding_info.forwarding_map); |
945 | |
946 | consumer_replay.addComplimentLeafIDs( |
947 | consumer_forwarding_info.forwarding_map, |
948 | consumer_forwarding_info.compliment_map); |
949 | |
950 | return consumer_replay; |
951 | } |
952 | |
953 | // Runs a best effort replay that ignores broadcast axes that appear in |
954 | // consumer that are not mapped to producer in root_map. |
955 | BestEffortReplay BestEffortReplay::replayPasC( |
956 | const TensorView* producer, |
957 | const TensorView* consumer, |
958 | int consumer_compute_at_axis, |
959 | const RootDomainMap& root_map) { |
960 | if (consumer_compute_at_axis < 0) |
961 | consumer_compute_at_axis += (int)consumer->nDims() + 1; |
962 | TORCH_INTERNAL_ASSERT( |
963 | consumer_compute_at_axis >= 0 && |
964 | (unsigned int)consumer_compute_at_axis <= consumer->nDims(), |
965 | "Invalid axis provided to BestEffortReplay::replayPasC." ); |
966 | |
967 | // consumer ids we need to match in producer |
968 | std::vector<IterDomain*> consumer_CA_ids( |
969 | consumer->domain()->domain().begin(), |
970 | consumer->domain()->domain().begin() + consumer_compute_at_axis); |
971 | |
972 | // Figure out all inputs required to generate the compute_at dimensions |
973 | auto consumer_CA_root_vals = IterVisitor::getInputsTo( |
974 | std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end())); |
975 | |
976 | std::unordered_set<IterDomain*> consumer_CA_root_ids; |
977 | for (auto val : consumer_CA_root_vals) { |
978 | if (val->getValType().value() == ValType::IterDomain) { |
979 | consumer_CA_root_ids.emplace(val->as<IterDomain>()); |
980 | } |
981 | } |
982 | |
983 | const auto c2p_root_map = root_map.mapConsumerToProducer( |
984 | consumer->domain(), producer->domain(), consumer_CA_root_ids); |
985 | |
986 | ConsumerForwardingInfo consumer_forwarding_info(producer, consumer); |
987 | |
988 | ProducerForwardingInfo producer_forwarding_info(producer); |
989 | |
990 | // Instead of replaying from the root, lets try to play forward the history |
991 | // of producer if they match ops on consumer. Enforce if we modify an |
992 | // rfactor axis that those ops must match. |
993 | auto producer_replay = BestEffortReplay( |
994 | producer->domain()->domain(), |
995 | consumer_CA_ids, |
996 | c2p_root_map, |
997 | producer_forwarding_info.forwarding_map, |
998 | consumer_forwarding_info.forwarding_map); |
999 | |
1000 | producer_replay.addComplimentLeafIDs( |
1001 | producer_forwarding_info.forwarding_map, |
1002 | producer_forwarding_info.compliment_map); |
1003 | |
1004 | return producer_replay; |
1005 | } |
1006 | |
1007 | void BestEffortReplay::skipSwizzles( |
1008 | const std::unordered_map<IterDomain*, Expr*>& target_id2expr, |
1009 | const std::unordered_map<IterDomain*, Expr*>& replay_id2expr) { |
1010 | // Update target2replay map |
1011 | bool updated = true; |
1012 | |
1013 | while (updated) { |
1014 | updated = false; |
1015 | for (auto it : target2replay_id_map_) { |
1016 | if (isSwizzleInput(it.first, target_id2expr) || |
1017 | isSwizzleInput(it.second, replay_id2expr)) { |
1018 | updated = true; |
1019 | auto new_target = getSwizzleFinalOutput(it.first, target_id2expr); |
1020 | auto new_replay = getSwizzleFinalOutput(it.second, replay_id2expr); |
1021 | |
1022 | // new_target and new_replay will now be the final output |
1023 | // skipping all swizzles in between. We'd need to |
1024 | // update the mapping and leaf ids to the final outputs. |
1025 | target2replay_id_map_.erase(it.first); |
1026 | TORCH_INTERNAL_ASSERT( |
1027 | target2replay_id_map_.insert(std::make_pair(new_target, new_replay)) |
1028 | .second, |
1029 | "Unexpected replay leaf" ); |
1030 | // Progress the leaf ids if the replay is updated |
1031 | if (it.second != new_replay && |
1032 | leaf_ids_.find(it.second) != leaf_ids_.end()) { |
1033 | leaf_ids_.erase(it.second); |
1034 | leaf_ids_[new_replay] = counter++; |
1035 | } |
1036 | break; |
1037 | } |
1038 | } |
1039 | } |
1040 | } |
1041 | |
1042 | DisjointSets<IterDomain*> BestEffortReplay::getDisjointSets() { |
1043 | DisjointSets<IterDomain*> result; |
1044 | const std::unordered_map<IterDomain*, IterDomain*>* maps[3] = { |
1045 | &target2replay_id_map_, &replay_forward_id_map_, &target_forward_id_map_}; |
1046 | for (auto map : maps) { |
1047 | for (auto entry : *map) { |
1048 | result.mapEntries(entry.first, entry.second); |
1049 | } |
1050 | } |
1051 | return result; |
1052 | } |
1053 | |
1054 | } // namespace cuda |
1055 | } // namespace fuser |
1056 | } // namespace jit |
1057 | } // namespace torch |
1058 | |