1 | #include <arith.h> |
2 | #include <index_compute.h> |
3 | #include <instrumentation.h> |
4 | #include <ir_iostream.h> |
5 | #include <ir_utils.h> |
6 | #include <kernel_expr_evaluator.h> |
7 | #include <kernel_ir.h> |
8 | #include <lower2device.h> |
9 | #include <lower_index_compute.h> |
10 | #include <lower_shift.h> |
11 | #include <lower_utils.h> |
12 | |
13 | #include <functional> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
// Wrap `expr` in boundary predicates required by shift/gather halo.
//
// The output expression is guarded by a "shift" predicate; when the
// predicate fails, an else-branch conditionally zero-fills the output
// (see the sketch below). Returns the expression (possibly with a
// predicate attached instead, for block-sync exprs).
Expr* ShiftPredicateInserter::insert(
    Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    Bool* thread_pred,
    bool within_unswitch) {
  const auto gpu_lower = GpuLower::current();

  TensorView* out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  // Fast exit: the output's definition needs no shift predicate.
  const bool needs_shift_predicate =
      gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
  if (!needs_shift_predicate) {
    return expr;
  }

  // The conditional branches to create:
  //
  // if (shift_pred) {
  //   consumer = producer;
  // } else {
  //   if (padding_pred) {
  //     consumer = 0;
  //   }
  // }

  kir::Predicate* thread_pred_expr = nullptr;
  if (within_unswitch) {
    thread_pred_expr = IrBuilder::create<kir::Predicate>(thread_pred);
  }

  // Within unswitch only the thread predicate is used; otherwise a
  // Shift predicate is derived from the expression itself.
  kir::Predicate* shift_pred = within_unswitch
      ? thread_pred_expr
      : IrBuilder::create<kir::Predicate>(
            PredicateType::Shift, expr, thread_pred);

  // If the expr involves a thread-block barrier, set the predicate of
  // the expr with shift_pred. Since the expr is not shift, the
  // padding is safe to omit.
  if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
    return expr->withPredicate(shift_pred);
  }

  auto shift_ite = IrBuilder::create<kir::IfThenElse>(shift_pred);

  auto& scope = loops.back()->body();

  // Insert the if statement
  scope.insert_before(expr, shift_ite);

  // Remove the expr from the list
  scope.erase(expr);

  // Place the expr inside the if statement
  shift_ite->thenBody().push_back(expr);

  // No padding condition is required if this is within unswitch.
  if (within_unswitch) {
    return expr;
  }

  // Padding by zero
  kir::Predicate* padding_pred = IrBuilder::create<kir::Predicate>(
      PredicateType::Padding, expr, thread_pred);
  auto bounds_ite = IrBuilder::create<kir::IfThenElse>(padding_pred);
  const int pad_value = 0;
  auto pad_expr = IrBuilder::create<UnaryOp>(
      UnaryOpType::Set, out_tv, IrBuilder::create<Int>(pad_value));
  bounds_ite->thenBody().push_back(pad_expr);
  // Insert the else block
  shift_ite->elseBody().push_back(bounds_ite);

  return expr;
}
94 | |
95 | int AxisHaloInfo::width() const { |
96 | return width(0) + width(1); |
97 | } |
98 | |
99 | int AxisHaloInfo::width(int pos) const { |
100 | TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); |
101 | return widths_[pos]; |
102 | } |
103 | |
104 | void AxisHaloInfo::setWidth(int pos, int width) { |
105 | TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); |
106 | widths_[pos] = width; |
107 | } |
108 | |
109 | void AxisHaloInfo::merge(int pos, int other) { |
110 | auto new_width = std::max(width(pos), other); |
111 | setWidth(pos, new_width); |
112 | } |
113 | |
114 | void AxisHaloInfo::merge(const AxisHaloInfo& other) { |
115 | for (const auto i : c10::irange(widths_.size())) { |
116 | merge(i, other.width(i)); |
117 | } |
118 | } |
119 | |
120 | bool AxisHaloInfo::hasHalo() const { |
121 | return std::any_of( |
122 | widths_.begin(), widths_.end(), [](auto w) { return w != 0; }); |
123 | } |
124 | |
125 | std::string AxisHaloInfo::toString() const { |
126 | std::stringstream ss; |
127 | ss << "<" << width(0) << ", " << width(1) << ">" ; |
128 | return ss.str(); |
129 | } |
130 | |
131 | bool HaloInfo::hasRootAxisInfo(IterDomain* id) const { |
132 | return root_axis_map_.find(id) != root_axis_map_.end(); |
133 | } |
134 | |
// Look up the halo info recorded for a root axis. Querying an axis
// that was never registered via setRootAxisInfo is an internal error.
const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
  // TODO: Enable this check, was failing in many tests
  // TORCH_INTERNAL_ASSERT(
  //     id->definition() == nullptr || id->isRFactorProduct(),
  //     "Invalid IterDomain: ",
  //     id);
  auto it = root_axis_map_.find(id);
  TORCH_INTERNAL_ASSERT(
      it != root_axis_map_.end(),
      "Halo root axis info not found for ",
      id->toString());
  return it->second;
}
148 | |
149 | void HaloInfo::setRootAxisInfo( |
150 | IterDomain* id, |
151 | const AxisHaloInfo& root_axis_info) { |
152 | root_axis_map_[id] = root_axis_info; |
153 | |
154 | initializeFromRootAxisInfo(id); |
155 | return; |
156 | } |
157 | |
158 | HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map) |
159 | // Make a copy of the permissive map for extent comparators |
160 | : permissive_map_(ca_map->idGraph().permissiveNodes()) { |
161 | const auto vals = fusion->usedMathVals(); |
162 | auto tvs = ir_utils::filterByType<TensorView>(vals); |
163 | |
164 | // Initialize all root axis info |
165 | for (auto tv : tvs) { |
166 | for (auto root_axis : tv->getRootDomain()) { |
167 | setRootAxisInfo(root_axis, AxisHaloInfo()); |
168 | } |
169 | // Just adds a placeholder to make it not fail. Reduction and |
170 | // rfactor support is not yet in place. |
171 | if (tv->hasRFactor()) { |
172 | for (auto rf_root_axis : tv->getRFactorDomain()) { |
173 | setRootAxisInfo(rf_root_axis, AxisHaloInfo()); |
174 | } |
175 | } |
176 | } |
177 | |
178 | // Propagate backward halo information of root axes from fusion |
179 | // outputs to inputs |
180 | auto exprs = fusion->exprs(); |
181 | for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) { |
182 | auto expr = *it; |
183 | if (!expr->outputs()[0]->isA<TensorView>()) { |
184 | continue; |
185 | } |
186 | |
187 | propagateRootAxisInfo(expr); |
188 | } |
189 | |
190 | // Propagates halo information from root axes down to leaf axes |
191 | for (auto tv : tvs) { |
192 | build(tv->domain()); |
193 | } |
194 | |
195 | if (isDebugDumpEnabled(DebugDumpOption::Halo)) { |
196 | std::cout << toString() << std::endl; |
197 | } |
198 | |
199 | // Note that validation requires consumer halo info |
200 | for (auto tv : tvs) { |
201 | validate(tv, ca_map); |
202 | } |
203 | } |
204 | |
205 | void HaloInfo::propagateRootAxisInfo(Expr* expr) { |
206 | for (auto output : expr->outputs()) { |
207 | auto out_tv = dynamic_cast<TensorView*>(output); |
208 | if (out_tv == nullptr) { |
209 | continue; |
210 | } |
211 | for (auto input : expr->inputs()) { |
212 | auto in_tv = dynamic_cast<TensorView*>(input); |
213 | if (in_tv == nullptr) { |
214 | continue; |
215 | } |
216 | propagateRootAxisInfo(in_tv, out_tv, expr); |
217 | } |
218 | } |
219 | } |
220 | |
221 | void HaloInfo::propagateRootAxisInfo( |
222 | TensorView* producer, |
223 | TensorView* consumer, |
224 | Expr* expr) { |
225 | // Do not add halo to input tensors |
226 | if (producer->isFusionInput()) { |
227 | return; |
228 | } |
229 | |
230 | auto c2p = PairwiseRootDomainMap(producer, consumer) |
231 | .mapConsumerToProducer(consumer->domain(), producer->domain()); |
232 | |
233 | const auto& c_root = consumer->getRootDomain(); |
234 | |
235 | for (const auto i : c10::irange(c_root.size())) { |
236 | auto c_id = c_root[i]; |
237 | auto it = c2p.find(c_id); |
238 | if (it == c2p.end()) { |
239 | // nothing to propagate |
240 | continue; |
241 | } |
242 | |
243 | // propagate root-axis halo info from c_id to p_id |
244 | |
245 | auto p_id = it->second; |
246 | |
247 | AxisHaloInfo p_info; |
248 | if (hasRootAxisInfo(p_id)) { |
249 | p_info = getRootAxisInfo(p_id); |
250 | } |
251 | const auto c_info = getRootAxisInfo(c_id); |
252 | |
253 | // If the root axes are broadcast, no halo should be associated |
254 | // with them. |
255 | if (c_id->isBroadcast()) { |
256 | TORCH_INTERNAL_ASSERT(!c_info.hasHalo()); |
257 | p_info.merge(c_info); |
258 | setRootAxisInfo(p_id, p_info); |
259 | continue; |
260 | } else if (p_id->isRFactorProduct()) { |
261 | TORCH_INTERNAL_ASSERT( |
262 | !c_info.hasHalo(), |
263 | "Propagating halo info to a rfactor producer domain not yet supported." ); |
264 | continue; |
265 | } |
266 | |
267 | // If the defining expression is shift, adjust the producer halo |
268 | // width based on the shift offset. If the shift offset is |
269 | // positive, create halo at offset zero of the producer axis so |
270 | // that the consumer can safely access the producer. If the offset |
271 | // is negative, halo is created at the other end of the axis. |
272 | // If the expr is not shift, just merge the consumer halo info |
273 | // to the producer halo info so that the producer halo can be the |
274 | // maximum of all its consumers. |
275 | if (auto shift_op = dynamic_cast<ShiftOp*>(expr)) { |
276 | const auto offset = shift_op->offset(i); |
277 | if (offset == 0) { |
278 | p_info.merge(c_info); |
279 | } else { |
280 | int pos = (offset > 0) ? 0 : 1; |
281 | p_info.merge(pos, c_info.width(pos) + std::abs(offset)); |
282 | } |
283 | } else if (auto gather_op = dynamic_cast<GatherOp*>(expr)) { |
284 | const auto window_dim = gather_op->windowShape()[i]; |
285 | if (window_dim == 1) { |
286 | p_info.merge(c_info); |
287 | continue; |
288 | } |
289 | const auto pad_dim0 = gather_op->padWidth()[i][0]; |
290 | p_info.merge(0, c_info.width(0) + pad_dim0); |
291 | // The right-side halo is propagated as: |
292 | // consumer_right_halo + (window_dim - 1 - left_padding) |
293 | p_info.merge(1, c_info.width(1) + window_dim - 1 - pad_dim0); |
294 | } else { |
295 | p_info.merge(c_info); |
296 | } |
297 | setRootAxisInfo(p_id, p_info); |
298 | } |
299 | } |
300 | |
301 | void HaloInfo::insertToInheritanceMap( |
302 | TensorDomain* td, |
303 | IterDomain* parent, |
304 | IterDomain* child) { |
305 | // Check each root domain to see if its set includes the parent. If |
306 | // so, adds the child to the same set. |
307 | bool inserted = false; |
308 | for (auto root_axis : td->getRootDomain()) { |
309 | auto it = inheritance_map_.find(root_axis); |
310 | if (it == inheritance_map_.end()) { |
311 | continue; |
312 | } |
313 | auto& id_set = it->second; |
314 | if (id_set.find(parent) != id_set.end()) { |
315 | id_set.insert(child); |
316 | inserted = true; |
317 | } |
318 | } |
319 | // No matching set found. This should not happen. |
320 | TORCH_INTERNAL_ASSERT(inserted); |
321 | } |
322 | |
323 | void HaloInfo::initializeFromRootAxisInfo(IterDomain* id) { |
324 | TORCH_INTERNAL_ASSERT(hasRootAxisInfo(id)); |
325 | |
326 | const auto& halo_info = getRootAxisInfo(id); |
327 | auto halo_width = halo_info.width(); |
328 | |
329 | if (!halo_info.hasHalo()) { |
330 | setHaloWidth(id, 0); |
331 | return; |
332 | } |
333 | |
334 | auto expanded_extent = |
335 | IrBuilder::addExpr(id->extent(), IrBuilder::create<Int>(halo_width)); |
336 | extent_map_[id] = expanded_extent; |
337 | halo_width_map_[id] = halo_width; |
338 | |
339 | inheritance_map_[id] = {id}; |
340 | } |
341 | |
342 | void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) { |
343 | halo_width_map_[id] = halo_width; |
344 | } |
345 | |
346 | // Propagate extent information from root axes to descendants |
// Propagate halo from a tensor's (maybe-rfactor) root domain down to
// its leaf domain by walking the Split/Merge/Swizzle2D expressions
// between them. Populates extent_map_ and halo_width_map_ for the
// intermediate and leaf IterDomains.
void HaloInfo::build(TensorDomain* td) {
  auto exprs = DependencyCheck::getAllExprsBetween(
      {td->getMaybeRFactorDomain().begin(), td->getMaybeRFactorDomain().end()},
      {td->domain().begin(), td->domain().end()});

  // Track IDs that are generated by merging halo-extended IDs
  std::unordered_set<IterDomain*> merged_shifted_ids;

  // Propagate halo information by traversing IterDomain
  // expressions. We populate extent_map_ and
  // halo_width_map_.
  // - extent_map_ maps to Expr* representing the
  // extent of each axis including its halo. If no mapping exists for
  // a particular axis in extent_map_, it means the axis does not have
  // halo.
  // - halo_width_map_ just maps to the integer size of the halo,
  // which is used for extent comparison (e.g., extentLessEqual).
  //
  // - When expr is split: if the halo width of the input axis is
  // zero, both the split outputs get zero halo in halo_width_map_. No
  // mapping is added for extent_map_. Otherwise, the halo is
  // propagated only to the inner output, so the inner output gets the
  // same halo width and its mapping is created in extent_map_.
  //
  // One major assumption here is that splitting an axis that is
  // an output of merging halo-extended axes is not allowed. This is
  // because it is unclear how to split the halo part of the merged
  // axis. This is unlikely to be a real limitation in practice.
  //
  // - When expr is merge: if either of the inputs has halo, a mapping
  // for the output is created in extent_map_. No mapping is created
  // for halo_width_map_ (see the comment on HaloInfo::halo_width_map_
  // in lower_shift.h). If both of them don't have halo, just adds a
  // new mapping of the output to zero in halo_width_map_. Also adds
  // it to a set (merged_shifted_ids) to track which axes are merge
  // outputs of halo-extended axes.

  for (auto expr : exprs) {
    if (auto split = dynamic_cast<Split*>(expr)) {
      // Merge-then-split of halo-extended IDs is not allowed
      TORCH_INTERNAL_ASSERT(
          merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(),
          "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed");

      auto in_id = split->in();

      // If no halo info is found, nothing needs to be done. This ID
      // must be an ancestor of a domain set by setRootAxisInfo.
      if (!hasHaloWidth(in_id)) {
        continue;
      }

      const auto halo_width = getHaloWidth(in_id);

      if (halo_width == 0) {
        setHaloWidth(split->outer(), 0);
        setHaloWidth(split->inner(), 0);
        continue;
      }

      // propagate to inner domain
      auto out_id = split->inner();

      auto expanded_extent =
          SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width);
      extent_map_.insert({out_id, expanded_extent});

      // The outer output never carries halo; only the inner does.
      setHaloWidth(split->outer(), 0);
      setHaloWidth(split->inner(), halo_width);

      insertToInheritanceMap(td, in_id, split->inner());
    } else if (auto merge = dynamic_cast<Merge*>(expr)) {
      // If either of the two inputs has halo extension, propagate it
      // to the merged output ID
      auto inner_extent = getExtent(merge->inner());
      auto outer_extent = getExtent(merge->outer());
      if (inner_extent != nullptr || outer_extent != nullptr) {
        // Fall back to the plain extent for the halo-free input; the
        // halo-extended input also records inheritance for the output.
        if (inner_extent == nullptr) {
          inner_extent = merge->inner()->extent();
        } else {
          insertToInheritanceMap(td, merge->inner(), merge->out());
        }
        if (outer_extent == nullptr) {
          outer_extent = merge->outer()->extent();
        } else {
          insertToInheritanceMap(td, merge->outer(), merge->out());
        }
        auto expanded_extent =
            SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent);
        extent_map_.insert({merge->out(), expanded_extent});
        // Splitting the output of this merge is not allowed, so
        // remember it
        merged_shifted_ids.insert(merge->out());
        // Note that halo_width_map_ is not updated
      } else {
        setHaloWidth(merge->out(), 0);
      }
    } else if (auto swizzle = dynamic_cast<Swizzle2D*>(expr)) {
      // Assume no halo on swizzled domain for now.
      TORCH_INTERNAL_ASSERT(
          getExtent(swizzle->inX()) == nullptr,
          "Halo is not supported with swizzle. Halo-extended ID: ",
          swizzle->inX()->toString(),
          " used in ",
          swizzle->toString());
      TORCH_INTERNAL_ASSERT(
          getExtent(swizzle->inY()) == nullptr,
          "Halo is not supported with swizzle. Halo-extended ID: ",
          swizzle->inY()->toString(),
          " used in ",
          swizzle->toString());
      // Swizzle outputs are halo-free by the checks above.
      for (auto id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
        setHaloWidth(id, 0);
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr);
    }
  }
}
466 | |
467 | //! Restriction 1: When allocation is outside of a shifted |
468 | //! axis, the shifted axis must be guaranteed to have a smaller extent |
469 | //! than the concrete axis. For now, shifted axes always mean expanded |
470 | //! allocations when the axis is located inside the allocation |
471 | //! point. This restriction is validated at the allocation lowering |
472 | //! pass. |
473 | //! |
474 | //! Restriction 2: If an expanded axis is parallelized, its memory |
475 | //! must be accessible by all other threads. More specifically: |
476 | //! - TIDx: It must be on shared memory. May want to consider |
477 | //! utilizing the shuffle instructions as well. |
478 | //! - BIDx: Not supported. If on global memory, Cooperative Launch |
479 | //! may be used to support it, however, it's unclear in what |
480 | //! situations block-level parallelization should be used. |
481 | //! |
482 | //! Other types of parallelization should be supported except for |
483 | //! vectorization. Vectorization should be eventually supported but |
484 | //! needs further work. |
// Validate the parallelization/memory-type restrictions described in
// the comment block above for every leaf axis of `tv`.
void HaloInfo::validate(
    TensorView* tv,
    std::shared_ptr<const ComputeAtMap> ca_map) const {
  const auto mem_type = tv->getMemoryType();

  for (auto axis : tv->domain()->domain()) {
    auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP);

    // The extent is assumed to be the same
    TORCH_INTERNAL_ASSERT(
        extentEqual(axis, concrete_id),
        "Axis does not have the same exact size with its concrete ID due to halo extension.",
        " Tensor: T",
        tv->name(),
        ", Axis: ",
        axis,
        ", concrete ID: ",
        concrete_id);

    auto halo_extent = getExtent(axis);

    // If no halo extent is associated with this axis, it means the
    // axis is not extended.
    if (halo_extent == nullptr) {
      continue;
    }

    // Enforce restrictions on parallelization and memory type
    const auto ptype = concrete_id->getParallelType();

    if (ptype == ParallelType::Serial) {
      continue;
    }

    // Only threading parallelism is considered for now
    TORCH_CHECK(
        isParallelTypeThread(ptype), "Unsupported parallel type: ", ptype);

    // Shared memory is needed if any use is a shift/gather, or if a
    // consumer sees a different (halo-extended) extent for this axis.
    bool shared_mem_needed = false;
    for (auto use : tv->uses()) {
      if (!ir_utils::isTvOp(use)) {
        continue;
      }
      if (use->isA<ShiftOp>() || use->isA<GatherOp>()) {
        shared_mem_needed = true;
        break;
      }
      auto consumer = use->outputs()[0]->as<TensorView>();
      // Find the corresponding axis in the consumer
      auto it = std::find_if(
          consumer->domain()->domain().begin(),
          consumer->domain()->domain().end(),
          [&](IterDomain* consumer_axis) {
            return ca_map->areMapped(
                axis, consumer_axis, IdMappingMode::PERMISSIVE);
          });
      if (it == consumer->domain()->domain().end()) {
        continue;
      }
      if (!extentEqual(axis, *it)) {
        shared_mem_needed = true;
        break;
      }
    }

    if (!shared_mem_needed) {
      continue;
    }

    if (isParallelTypeThreadDim(ptype)) {
      // If all the consumers have the same extent and none of the
      // expressions is shift, any memory should be fine. Otherwise, it
      // must be accessible by all threads involved in the
      // parallelization.
      TORCH_CHECK(
          mem_type == MemoryType::Shared,
          "TV",
          tv->name(),
          " must be allocated on shared memory as its halo-extended axis is parallelized by ",
          ptype);

    } else if (isParallelTypeBlockDim(ptype)) {
      TORCH_CHECK(
          false,
          "Block-based parallelization of a halo-extended axis is not supported: ",
          axis);
    }
  }
  return;
}
575 | |
576 | Val* HaloInfo::getExtent(IterDomain* id) const { |
577 | auto it = extent_map_.find(id); |
578 | if (it != extent_map_.end()) { |
579 | return it->second; |
580 | } else { |
581 | return nullptr; |
582 | } |
583 | } |
584 | |
585 | int HaloInfo::getHaloWidth(IterDomain* id) const { |
586 | auto it = halo_width_map_.find(id); |
587 | TORCH_INTERNAL_ASSERT(it != halo_width_map_.end()); |
588 | return it->second; |
589 | } |
590 | |
591 | bool HaloInfo::hasHaloWidth(IterDomain* id) const { |
592 | return halo_width_map_.find(id) != halo_width_map_.end(); |
593 | } |
594 | |
595 | const std::unordered_set<IterDomain*>& HaloInfo::getChildDomains( |
596 | IterDomain* root_id) const { |
597 | auto it = inheritance_map_.find(root_id); |
598 | TORCH_INTERNAL_ASSERT( |
599 | it != inheritance_map_.end(), |
600 | "Domain not found in the inheritance map: " , |
601 | root_id); |
602 | return it->second; |
603 | } |
604 | |
605 | bool HaloInfo::isHaloInherited(IterDomain* root_id, IterDomain* id) const { |
606 | return getChildDomains(root_id).count(id) > 0; |
607 | } |
608 | |
609 | std::unordered_set<IterDomain*> HaloInfo::getRootDomains(IterDomain* id) const { |
610 | std::unordered_set<IterDomain*> id_set; |
611 | |
612 | for (const auto& kv : inheritance_map_) { |
613 | if (kv.second.count(id) > 0) { |
614 | id_set.insert(kv.first); |
615 | } |
616 | } |
617 | |
618 | return id_set; |
619 | } |
620 | |
621 | namespace { |
622 | |
623 | //! Prove if the comparison operator, cmp, is true with the extents of |
624 | //! id1 and id2, including their halo. The comparison is done |
625 | //! conservatively, meaning false negative is possible. |
626 | //! |
627 | //! It is assumed that id1 and id2 are mapped with the CA Loop map, so |
628 | //! what is checked here is only about halo |
629 | //! sizes using HaloInfo::halo_width_map_. Since it does not have |
630 | //! mappings for merged axes, each axis of merge inputs are |
631 | //! individually compared, and only when both of the input axes |
632 | //! return true, the merge output axis returns true. |
template <typename Cmp>
bool extentCompare(
    const HaloInfo& halo_map,
    IterDomain* id1,
    IterDomain* id2,
    Cmp cmp,
    const DisjointSets<IterDomain*>& permissive_map) {
  // Callers must only compare axes that are permissively mapped; the
  // base extents are then assumed equal and only halo is compared.
  TORCH_INTERNAL_ASSERT(
      permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare");

  // It's invalid to compare two axes and when only either of them has
  // halo.

  if (halo_map.hasHaloWidth(id1)) {
    TORCH_INTERNAL_ASSERT(
        halo_map.hasHaloWidth(id2), "Invalid comparison: ", id1, " and ", id2);
    // Both axes have halo. We assume the axes themselves have equal
    // extents, excluding halo, as they are mapped with the CA
    // map. So, we just need to compare the halo width of each axis.
    return cmp(halo_map.getHaloWidth(id1), halo_map.getHaloWidth(id2));
  } else {
    TORCH_INTERNAL_ASSERT(!halo_map.hasHaloWidth(id2));
    // Both don't have halo. The only case this can happen must be
    // both axes are the output of a merge expression, so each merge
    // input is recursively compared, and returns true only when both
    // inputs return.
    if (auto merge1 = dynamic_cast<Merge*>(id1->definition())) {
      auto merge2 = dynamic_cast<Merge*>(id2->definition());
      TORCH_INTERNAL_ASSERT(
          merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
      // Recurse into both input pairs; the comparison holds for the
      // merge outputs only if it holds for both input pairs.
      auto inner_le = extentCompare(
          halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map);
      auto outer_le = extentCompare(
          halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map);
      return inner_le && outer_le;
    } else {
      // This is not considered. Should never reach here.
      TORCH_INTERNAL_ASSERT(false, "Invalid comparison: ", id1, " and ", id2);
    }
  }
}
674 | |
675 | } // namespace |
676 | |
677 | bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const { |
678 | return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_); |
679 | } |
680 | |
681 | bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const { |
682 | return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_); |
683 | } |
684 | |
685 | std::string HaloInfo::toString() const { |
686 | std::stringstream ss; |
687 | |
688 | ss << "HaloInfo:\n" ; |
689 | |
690 | if (root_axis_map_.empty()) { |
691 | return ss.str(); |
692 | } |
693 | |
694 | Fusion* fusion = root_axis_map_.begin()->first->fusion(); |
695 | |
696 | auto used_vals = DependencyCheck::getAllValsBetween( |
697 | {fusion->inputs().begin(), fusion->inputs().end()}, fusion->outputs()); |
698 | |
699 | for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) { |
700 | const auto& root = tv->getRootDomain(); |
701 | ss << "TV" << tv->name() << " root domain: " ; |
702 | for (auto axis : root) { |
703 | ss << axis << " -> " << getRootAxisInfo(axis).toString() << ", " ; |
704 | } |
705 | ss << "\n" ; |
706 | } |
707 | |
708 | return ss.str(); |
709 | } |
710 | |
711 | bool HaloInfo::needsShiftPredicate(Expr* expr) const { |
712 | // In lowering shift and gather turn into a unary op. We really need the shift |
713 | // expr. Do a round about trick to grab it: |
714 | auto tv_out = ir_utils::getTvOutput(expr); |
715 | auto consumer_td = tv_out->domain(); |
716 | auto shift_expr = dynamic_cast<ShiftOp*>(tv_out->definition()); |
717 | auto gather_expr = dynamic_cast<GatherOp*>(tv_out->definition()); |
718 | for (const auto i : c10::irange(consumer_td->getRootDomain().size())) { |
719 | auto consumer_id = consumer_td->getRootDomain()[i]; |
720 | const auto consumer_halo_info = getRootAxisInfo(consumer_id); |
721 | if (consumer_halo_info.hasHalo() || |
722 | (shift_expr != nullptr && shift_expr->offset(i) != 0 && |
723 | !consumer_id->isBroadcast()) || |
724 | (gather_expr != nullptr && gather_expr->windowShape()[i] != 1 && |
725 | !consumer_id->isBroadcast())) { |
726 | return true; |
727 | } |
728 | } |
729 | return false; |
730 | } |
731 | |
732 | std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap( |
733 | const LoopIndexing& loop_indexing) const { |
734 | // Use a local workspace to avoid re-defining halo info. |
735 | HaloInfo local_halo_info = *GpuLower::current()->haloInfo(); |
736 | |
737 | auto global_halo_info = GpuLower::current()->haloInfo(); |
738 | |
739 | // Setup root: |
740 | for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) { |
741 | auto consumer_index_concrete_id = |
742 | GpuLower::current()->caMap()->getConcreteMappedID( |
743 | consumer_root_id, IdMappingMode::EXACT); |
744 | local_halo_info.setRootAxisInfo( |
745 | consumer_index_concrete_id, |
746 | global_halo_info->getRootAxisInfo(consumer_root_id)); |
747 | } |
748 | |
749 | // Track IDs that are generated by merging halo-extended IDs |
750 | std::unordered_set<IterDomain*> merged_shifted_ids; |
751 | |
752 | for (auto expr : loop_indexing.getForwardExprList()) { |
753 | if (auto split = dynamic_cast<Split*>(expr)) { |
754 | // Merge-then-split of halo-extended IDs is not allowed |
755 | TORCH_INTERNAL_ASSERT( |
756 | merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(), |
757 | "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed" ); |
758 | |
759 | auto in_id = GpuLower::current()->caMap()->getConcreteMappedID( |
760 | split->in(), IdMappingMode::EXACT); |
761 | |
762 | // If no halo info is found, nothing needs to be done. This ID |
763 | // must be an ancestor of a domain set by setRootAxisInfo. |
764 | if (!local_halo_info.hasHaloWidth(in_id)) { |
765 | continue; |
766 | } |
767 | |
768 | const auto halo_width = local_halo_info.getHaloWidth(in_id); |
769 | |
770 | if (halo_width == 0) { |
771 | local_halo_info.setHaloWidth( |
772 | GpuLower::current()->caMap()->getConcreteMappedID( |
773 | split->outer(), IdMappingMode::EXACT), |
774 | 0); |
775 | local_halo_info.setHaloWidth( |
776 | GpuLower::current()->caMap()->getConcreteMappedID( |
777 | split->inner(), IdMappingMode::EXACT), |
778 | 0); |
779 | continue; |
780 | } |
781 | |
782 | // propagate to inner domain |
783 | auto out_id = GpuLower::current()->caMap()->getConcreteMappedID( |
784 | split->inner(), IdMappingMode::EXACT); |
785 | |
786 | auto expanded_extent = |
787 | SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width); |
788 | local_halo_info.extent_map_.insert({out_id, expanded_extent}); |
789 | |
790 | local_halo_info.setHaloWidth( |
791 | GpuLower::current()->caMap()->getConcreteMappedID( |
792 | split->outer(), IdMappingMode::EXACT), |
793 | 0); |
794 | local_halo_info.setHaloWidth( |
795 | GpuLower::current()->caMap()->getConcreteMappedID( |
796 | split->inner(), IdMappingMode::EXACT), |
797 | halo_width); |
798 | |
799 | // TODO: add support for inheritance map |
800 | } else if (auto merge = dynamic_cast<Merge*>(expr)) { |
801 | // If either of the two inputs has halo extension, propagate it |
802 | // to the merged output ID |
803 | auto inner_extent = local_halo_info.getExtent( |
804 | GpuLower::current()->caMap()->getConcreteMappedID( |
805 | merge->inner(), IdMappingMode::EXACT)); |
806 | auto outer_extent = local_halo_info.getExtent( |
807 | GpuLower::current()->caMap()->getConcreteMappedID( |
808 | merge->outer(), IdMappingMode::EXACT)); |
809 | if (inner_extent != nullptr || outer_extent != nullptr) { |
810 | if (inner_extent == nullptr) { |
811 | inner_extent = merge->inner()->extent(); |
812 | } |
813 | if (outer_extent == nullptr) { |
814 | outer_extent = merge->outer()->extent(); |
815 | } |
816 | auto expanded_extent = |
817 | SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent); |
818 | local_halo_info.extent_map_.insert( |
819 | {GpuLower::current()->caMap()->getConcreteMappedID( |
820 | merge->out(), IdMappingMode::EXACT), |
821 | expanded_extent}); |
822 | // Splitting the output of this merge is not allowed, so |
823 | // remember it |
824 | merged_shifted_ids.insert( |
825 | GpuLower::current()->caMap()->getConcreteMappedID( |
826 | merge->out(), IdMappingMode::EXACT)); |
827 | // Note that halo_width_map_ is not updated |
828 | } else { |
829 | local_halo_info.setHaloWidth( |
830 | GpuLower::current()->caMap()->getConcreteMappedID( |
831 | merge->out(), IdMappingMode::EXACT), |
832 | 0); |
833 | } |
834 | } else if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(expr)) { |
835 | // Swizzle with halo not yet supported, just set the width |
836 | // to zero at the moment. |
837 | TORCH_INTERNAL_ASSERT( |
838 | local_halo_info.getHaloWidth( |
839 | GpuLower::current()->caMap()->getConcreteMappedID( |
840 | swizzle_2d->inX(), IdMappingMode::EXACT)) == 0 && |
841 | local_halo_info.getHaloWidth( |
842 | GpuLower::current()->caMap()->getConcreteMappedID( |
843 | swizzle_2d->inY(), IdMappingMode::EXACT)) == 0, |
844 | "Swizzle on ID with halo not yet supported." ); |
845 | TORCH_INTERNAL_ASSERT("Swizzle on ID with halo not yet supported." ); |
846 | local_halo_info.setHaloWidth( |
847 | GpuLower::current()->caMap()->getConcreteMappedID( |
848 | swizzle_2d->outX(), IdMappingMode::EXACT), |
849 | 0); |
850 | local_halo_info.setHaloWidth( |
851 | GpuLower::current()->caMap()->getConcreteMappedID( |
852 | swizzle_2d->outY(), IdMappingMode::EXACT), |
853 | 0); |
854 | } else { |
855 | TORCH_INTERNAL_ASSERT(false, "Unsupported expr: " , expr); |
856 | } |
857 | } |
858 | |
859 | return local_halo_info.extent_map_; |
860 | } |
861 | |
862 | } // namespace cuda |
863 | } // namespace fuser |
864 | } // namespace jit |
865 | } // namespace torch |
866 | |