1 | #include <index_compute.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/irange.h> |
5 | #include <arith.h> |
6 | #include <contiguity.h> |
7 | #include <expr_evaluator.h> |
8 | #include <instrumentation.h> |
9 | #include <ir_all_nodes.h> |
10 | #include <ir_iostream.h> |
11 | #include <ir_utils.h> |
12 | #include <kernel_expr_evaluator.h> |
13 | #include <lower2device.h> |
14 | #include <lower_double_buffer.h> |
15 | #include <lower_index_compute.h> |
16 | #include <lower_magic_zero.h> |
17 | #include <lower_shift.h> |
18 | #include <lower_unroll.h> |
19 | #include <lower_utils.h> |
20 | #include <lower_validation.h> |
21 | #include <root_domain_map.h> |
22 | #include <transform_iter.h> |
23 | #include <transform_replay.h> |
24 | |
25 | namespace torch { |
26 | namespace jit { |
27 | namespace fuser { |
28 | namespace cuda { |
29 | |
30 | namespace { |
31 | |
32 | //! Offset of an index of a producer axis with respect to its |
33 | //! corresponding consumer index |
34 | int getProducerHaloOffset( |
35 | const TensorView* producer_tv, |
36 | size_t producer_axis, |
37 | const TensorView* consumer_tv) { |
38 | auto p2c = |
39 | PairwiseRootDomainMap(producer_tv, consumer_tv) |
40 | .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain()); |
41 | |
42 | auto producer_id = producer_tv->getMaybeRFactorDomain()[producer_axis]; |
43 | |
44 | auto it = p2c.find(producer_id); |
45 | // p2c should always have a mapping for producer_id. The only case |
46 | // where no mapping exists for a producer axis is when it is a |
47 | // reduction axis. Since this function is only used for indexing |
48 | // producer tensors, where reduction axes are skipped, producer_id |
49 | // should never be a reduction axis. |
50 | TORCH_INTERNAL_ASSERT(it != p2c.end()); |
51 | IterDomain* consumer_id = it->second; |
52 | |
53 | const auto& halo_map = GpuLower::current()->haloInfo(); |
54 | const auto p_pad = halo_map->getRootAxisInfo(producer_id).width(0); |
55 | const auto c_pad = halo_map->getRootAxisInfo(consumer_id).width(0); |
56 | |
57 | auto offset = p_pad - c_pad; |
58 | |
59 | // If the consumer is a result of shifting the producer, adjust the |
60 | // producer index per the offsets argument of the shift op. |
61 | if (auto shift_op = dynamic_cast<const ShiftOp*>(consumer_tv->definition())) { |
62 | offset -= shift_op->offset(producer_axis); |
63 | } |
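  // Worked example (illustrative, assuming concrete pad widths): if
  // the producer axis has a halo pad of 2 at offset zero, the
  // consumer axis has a pad of 0, and the consumer is defined by a
  // shift op with an offset of 1 on this axis, then
  //   offset = p_pad - c_pad - shift_offset = 2 - 0 - 1 = 1,
  // so the producer is indexed at consumer_index + 1.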
64 | |
65 | return offset; |
66 | } |
67 | |
68 | //! Offset producer index when necessary |
69 | Val* getProducerIndexWithHalo( |
70 | const TensorView* producer_tv, |
71 | size_t producer_axis, |
72 | Val* producer_index, |
73 | const TensorView* consumer_tv) { |
74 | const auto offset = |
75 | getProducerHaloOffset(producer_tv, producer_axis, consumer_tv); |
76 | |
77 | if (offset == 0) { |
78 | return producer_index; |
79 | } |
80 | |
81 | producer_index = SimplifyingIrBuilder::addExpr(producer_index, offset); |
82 | |
83 | return producer_index; |
84 | } |
85 | |
86 | //! Create a producer offset based off a consumer index |
87 | //! |
88 | //! \param consumer_root_axis Position of corresponding consumer axis |
89 | //! \param consumer_tv Consumer TensorView |
90 | //! \param index_map Mappings from consumer or reference to indices |
91 | //! \param use_reference_map True when index_map maps reference domains |
92 | //! \param concrete_to_ref_map Mappings from concrete to reference domains |
93 | Val* getProducerOffsetWithGather( |
94 | size_t consumer_root_axis, |
95 | const TensorView* consumer_tv, |
96 | const std::unordered_map<IterDomain*, Val*>& index_map, |
97 | bool use_reference_map = false, |
98 | const std::unordered_map<IterDomain*, IterDomain*>& concrete_to_ref_map = |
99 | {}) { |
100 | const auto gpu_lower = GpuLower::current(); |
101 | |
102 | const auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition()); |
103 | |
104 | if (gather_expr == nullptr) { |
105 | return gpu_lower->kernel()->zeroVal(); |
106 | } |
107 | |
108 | // If the window extent is one, no specific offsetting |
109 | // is necessary |
110 | if (consumer_root_axis >= gather_expr->windowShape().size() || |
111 | gather_expr->windowShape()[consumer_root_axis] == 1) { |
112 | return gpu_lower->kernel()->zeroVal(); |
113 | } |
114 | |
115 | // Basically, the goal is to build an expression of producer_index + |
116 | // window_index, so we first need to locate the index expression |
117 | // that corresponds to the window axis of this producer axis. |
118 | |
119 | const auto window_axis = gather_expr->gatherAxis(consumer_root_axis); |
120 | auto window_id = consumer_tv->getRootDomain().at(window_axis); |
121 | |
122 | // When index_map maps a reference tensor, find the corresponding |
123 | // reference ID of window_id. |
124 | if (use_reference_map) { |
125 | auto concrete_window_id = gpu_lower->caMap()->getConcreteMappedID( |
126 | window_id, IdMappingMode::EXACT); |
127 | auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id); |
128 | TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end()); |
129 | window_id = concrete_2_ref_it->second; |
130 | } |
131 | |
132 | auto window_idx = index_map.at(window_id); |
133 | |
  // Positive padding at offset zero means the indexing is shifted in
  // the negative direction.
136 | auto pad_width = gather_expr->padWidth()[consumer_root_axis][0]; |
137 | |
138 | // producer offset: window_index - padding |
139 | auto producer_offset = SimplifyingIrBuilder::subExpr( |
140 | window_idx, SimplifyingIrBuilder::create<Int>(pad_width)); |
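  // Worked example (illustrative): with a window shape of {3} and a
  // pad width of {{1, 1}} on this axis, a window index of 0 yields a
  // producer offset of 0 - 1 = -1, i.e., the first window position
  // reads one element to the left of the consumer position.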
141 | return producer_offset; |
142 | } |
143 | |
//! Create a producer offset based off a consumer index, using
//! concrete-domain index mappings
//!
//! \param consumer_root_axis Position of corresponding consumer axis
//! \param consumer_tv Consumer TensorView
//! \param index_map Mappings from consumer or concrete domains to indices
//! \param use_concrete_map True when index_map maps concrete domains
151 | Val* getConcreteProducerOffsetWithGather( |
152 | size_t consumer_root_axis, |
153 | const TensorView* consumer_tv, |
154 | const std::unordered_map<IterDomain*, Val*>& index_map, |
155 | bool use_concrete_map = false) { |
156 | const auto gpu_lower = GpuLower::current(); |
157 | |
158 | const auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition()); |
159 | |
160 | if (gather_expr == nullptr) { |
161 | return gpu_lower->kernel()->zeroVal(); |
162 | } |
163 | |
164 | // If the window extent is one, no specific offsetting |
165 | // is necessary |
166 | if (consumer_root_axis >= gather_expr->windowShape().size() || |
167 | gather_expr->windowShape()[consumer_root_axis] == 1) { |
168 | return gpu_lower->kernel()->zeroVal(); |
169 | } |
170 | |
171 | // Basically, the goal is to build an expression of producer_index + |
172 | // window_index, so we first need to locate the index expression |
173 | // that corresponds to the window axis of this producer axis. |
174 | |
175 | const auto window_axis = gather_expr->gatherAxis(consumer_root_axis); |
176 | auto window_id = consumer_tv->getRootDomain().at(window_axis); |
177 | |
178 | Val* window_idx = nullptr; |
179 | |
180 | if (use_concrete_map) { |
181 | window_idx = index_map.at(GpuLower::current()->caMap()->getConcreteMappedID( |
182 | window_id, IdMappingMode::EXACT)); |
183 | } else { |
184 | window_idx = index_map.at(window_id); |
185 | } |
186 | |
  // Positive padding at offset zero means the indexing is shifted in
  // the negative direction.
189 | auto pad_width = gather_expr->padWidth()[consumer_root_axis][0]; |
190 | |
191 | // producer offset: window_index - padding |
192 | auto producer_offset = SimplifyingIrBuilder::subExpr( |
193 | window_idx, SimplifyingIrBuilder::create<Int>(pad_width)); |
194 | return producer_offset; |
195 | } |
196 | |
197 | //! Offset a producer index of a gather expression |
198 | //! |
199 | //! Given an index of a producer root axis, build a new index |
200 | //! expression that accesses a window position that the current loop |
//! structure refers to. Use getConcreteProducerOffsetWithGather to
//! create an offset Val.
203 | Val* getProducerIndexWithGather( |
204 | Val* producer_index, |
205 | size_t producer_root_axis, |
206 | const TensorView* producer_tv, |
207 | const TensorView* consumer_tv, |
208 | const std::unordered_map<IterDomain*, Val*>& concrete_index_map) { |
209 | auto gather_op = dynamic_cast<const GatherOp*>(consumer_tv->definition()); |
210 | |
211 | // Just return the producer index as is if this is not a gather |
212 | if (gather_op == nullptr) { |
213 | return producer_index; |
214 | } |
215 | |
216 | // Consumer axis that corresponds to the producer axis |
217 | int consumer_axis = -1; |
218 | for (const auto i : c10::irange(producer_root_axis + 1)) { |
219 | if (producer_tv->getMaybeRFactorDomain()[i]->isReduction() || |
220 | producer_tv->getMaybeRFactorDomain()[i]->isStride()) { |
221 | continue; |
222 | } |
223 | ++consumer_axis; |
224 | } |
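  // E.g. (illustrative): for a producer rfactor domain [r0, i1, i2]
  // and producer_root_axis == 2, the reduction axis r0 is skipped,
  // so consumer_axis resolves to 1.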
225 | |
  TORCH_INTERNAL_ASSERT(
      consumer_axis >= 0 &&
          consumer_axis < (int)gather_op->windowShape().size(),
      "Invalid consumer axis ",
      consumer_axis,
      ", producer_axis: ",
      producer_root_axis);
233 | |
234 | auto offset = getConcreteProducerOffsetWithGather( |
235 | consumer_axis, consumer_tv, concrete_index_map, true); |
236 | return SimplifyingIrBuilder::addExpr(producer_index, offset); |
237 | } |
238 | |
239 | // Adjusts a global consumer index when its root domain is partially |
240 | // split. Note that non-global consumer indices don't need any |
241 | // adjustment. |
242 | Val* getGlobalConsumerOffsetWithPartialSplit(IterDomain* root_id) { |
243 | auto offset = GpuLower::current()->partialSplitMap().getStartOffset(root_id); |
244 | if (offset == nullptr) { |
245 | return GpuLower::current()->kernel()->zeroVal(); |
246 | } else { |
247 | return offset; |
248 | } |
249 | } |
250 | |
251 | // Adjusts a global producer index when its root domain and |
252 | // corresponding consumer root domain have non-matching split |
// offsets. Specifically, since producer_index is calculated based on
254 | // the consumer, if the consumer has a non-zero offset, |
255 | // it needs to be added to the index. Also, when the producer itself |
256 | // also has a non-zero split offset, that needs to be subtracted from |
257 | // the index. |
258 | Val* getProducerIndexWithPartialSplit( |
259 | Val* producer_index, |
260 | IterDomain* producer_root_id, |
261 | const TensorView* producer_tv, |
262 | const TensorView* consumer_tv) { |
263 | const auto gpu_lower = GpuLower::current(); |
264 | |
265 | auto p2c = |
266 | PairwiseRootDomainMap(producer_tv, consumer_tv) |
267 | .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain()); |
268 | |
269 | auto it = p2c.find(producer_root_id); |
270 | if (it == p2c.end()) { |
271 | return producer_index; |
272 | } |
273 | |
274 | auto consumer_root_id = it->second; |
275 | |
276 | auto consumer_offset = |
277 | gpu_lower->partialSplitMap().getStartOffset(consumer_root_id); |
278 | consumer_offset = consumer_offset == nullptr ? gpu_lower->kernel()->zeroVal() |
279 | : consumer_offset; |
280 | |
281 | auto producer_offset = |
282 | gpu_lower->partialSplitMap().getStartOffset(producer_root_id); |
283 | producer_offset = producer_offset == nullptr ? gpu_lower->kernel()->zeroVal() |
284 | : producer_offset; |
285 | |
286 | // If the producer is on global memory, it's always allocated |
287 | // without trimming the out-of-bounds region, so the consumer offset |
288 | // should be added to the index. |
289 | if (producer_tv->getMemoryType() == MemoryType::Global) { |
290 | if (consumer_offset->isZeroInt()) { |
291 | return producer_index; |
292 | } else { |
293 | return SimplifyingIrBuilder::addExpr(producer_index, consumer_offset); |
294 | } |
295 | } |
296 | |
  // Non-global case. The difference of the split offsets must be
  // accounted for.
299 | |
300 | auto diff = SimplifyingIrBuilder::subExpr(consumer_offset, producer_offset); |
301 | kir::ExpressionEvaluator ee; |
302 | auto diff_eval = ee.evaluate(diff); |
303 | // We currently only allow constant offsetting |
  TORCH_INTERNAL_ASSERT(diff_eval.has_value(), "Invalid partial split");
305 | |
306 | if (diff_eval.value() == 0) { |
307 | return producer_index; |
308 | } |
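  // E.g. (illustrative): if the consumer root domain has a start
  // offset of 2 and the producer's is 0, diff evaluates to 2 and the
  // producer index becomes producer_index + 2.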
309 | |
310 | return SimplifyingIrBuilder::addExpr( |
311 | producer_index, |
312 | SimplifyingIrBuilder::create<Int>(diff_eval->as<int64_t>())); |
313 | } |
314 | |
315 | } // namespace |
316 | |
317 | void IndexCompute::handle(Split* split) { |
318 | auto in_id = maybeGetExactMapConcreteID(split->in()->as<IterDomain>()); |
319 | auto outer_id = maybeGetExactMapConcreteID(split->outer()->as<IterDomain>()); |
320 | auto inner_id = maybeGetExactMapConcreteID(split->inner()->as<IterDomain>()); |
321 | |
322 | auto outer_it = index_map_.find(outer_id); |
323 | auto inner_it = index_map_.find(inner_id); |
324 | if (outer_it == index_map_.end() || inner_it == index_map_.end()) { |
325 | return; |
326 | } |
327 | |
328 | const auto outer_ind = outer_it->second; |
329 | const auto inner_ind = inner_it->second; |
330 | |
331 | const bool outer_zero = isZero(outer_id); |
332 | const bool inner_zero = isZero(inner_id); |
333 | |
334 | // We want to mark as zero merged in if we're working with shared or local |
335 | // memory, and the dimension we're working with is not part of the allocation, |
336 | // as we have special propagation rules for that scenario. |
337 | |
338 | // Maybe clear in_id as it could have been mapped over from another |
339 | // IndexCompute. Uncertain if this is needed but seems to be safe. |
340 | bool zero_merged_in = hasZeroMerged(in_id) || hasZeroMerged(inner_id) || |
341 | hasZeroMerged(outer_id); |
342 | |
343 | // If both are zero, the split input is also zero |
344 | if (inner_zero && outer_zero) { |
345 | zero_domains_.emplace(in_id); |
346 | } |
347 | |
348 | if (zero_merged_in) { |
349 | zero_merged_in_.emplace(in_id); |
350 | } |
351 | |
352 | if (isZero(in_id)) { |
353 | index_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); |
354 | extent_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); |
355 | } else if (zero_merged_in && outer_zero) { |
356 | index_map_[in_id] = inner_ind; |
357 | extent_map_[in_id] = getExtent(inner_id); |
358 | } else if (zero_merged_in && inner_zero) { |
359 | index_map_[in_id] = outer_ind; |
360 | extent_map_[in_id] = getExtent(outer_id); |
361 | } else { |
362 | index_map_[in_id] = SimplifyingIrBuilder::addExpr( |
363 | SimplifyingIrBuilder::mulExpr(outer_ind, getExtent(inner_id)), |
364 | inner_ind); |
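    // E.g. (illustrative): for a split by factor 4, an (outer, inner)
    // index pair of (2, 3) maps back to the input index
    // 2 * 4 + 3 = 11, inverting the forward split in / 4 and in % 4.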
365 | // The extent should be updated only when its allocation is |
366 | // partial, i.e., zero_merged_in is true. See PR #1270. |
367 | if (zero_merged_in) { |
368 | extent_map_[in_id] = SimplifyingIrBuilder::mulExpr( |
369 | getExtent(outer_id), getExtent(inner_id)); |
370 | } |
371 | } |
372 | } |
373 | |
374 | void IndexCompute::handle(Merge* merge) { |
375 | auto out_id = maybeGetExactMapConcreteID(merge->out()); |
376 | auto outer_id = maybeGetExactMapConcreteID(merge->outer()); |
377 | auto inner_id = maybeGetExactMapConcreteID(merge->inner()); |
378 | |
379 | auto out_it = index_map_.find(out_id); |
380 | if (out_it == index_map_.end()) { |
381 | return; |
382 | } |
383 | auto out_ind = out_it->second; |
384 | |
385 | auto zero = GpuLower::current()->kernel()->zeroVal(); |
386 | |
387 | if (isZero(out_id)) { |
388 | index_map_[outer_id] = zero; |
389 | index_map_[inner_id] = zero; |
390 | // TODO: Why do we set extent_map_ to zero? This has to be protected by zero |
391 | // merged in, but seems logical to me the extent would still be one. |
392 | extent_map_[outer_id] = zero; |
393 | extent_map_[inner_id] = zero; |
394 | zero_domains_.emplace(outer_id); |
395 | zero_domains_.emplace(inner_id); |
396 | return; |
397 | } |
398 | |
399 | if (!hasZeroMerged(out_id) && contig_ids_.find(out_id) != contig_ids_.end()) { |
400 | // Contiguous indexing path |
401 | auto input_ids = ir_utils::iterDomainInputsOfOrderedAs( |
402 | {merge->out()}, td_->getMaybeRFactorDomain()); |
403 | |
404 | // Shouldn't hit this, but don't want to segfault if somehow we do. |
405 | TORCH_INTERNAL_ASSERT(!input_ids.empty()); |
406 | |
    // Try to find the last non-broadcast entry to put the index in if
    // it's a contiguous merge. This isn't strictly necessary, but the
    // indexing logic implicitly assumes broadcasted root domains can
    // be ignored. This is just to try and match that assumption.
411 | // Initialize everything to zero. |
412 | for (auto root_id : input_ids) { |
413 | index_map_[root_id] = zero; |
414 | } |
415 | |
416 | // If all are broadcast we can just send the index to the last entry. |
417 | if (std::all_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { |
418 | // I don't think reductions can be in here, but strictly matching the |
419 | // logic in the indexing functions like |
420 | // getNonGlobalConsumerStridedIndices |
421 | return id->isBroadcast() || id->isReduction() || id->isStride(); |
422 | })) { |
423 | index_map_[*(input_ids.end() - 1)] = out_ind; |
424 | } else { |
425 | for (auto id_it = input_ids.rbegin(); id_it != input_ids.rend(); |
426 | id_it++) { |
427 | auto id = *id_it; |
428 | if (id->isBroadcast() || id->isReduction() || id->isStride()) { |
429 | continue; |
430 | } else { |
431 | index_map_[id] = out_ind; |
432 | break; |
433 | } |
434 | } |
435 | } |
436 | |
437 | return; |
438 | } |
439 | |
440 | Val* inner_extent = getExtent(inner_id); |
441 | |
442 | // When the reference has halo extent for inner_id, that extent needs to |
443 | // be used to un-merge |
444 | if (halo_extent_map_.find(inner_id) != halo_extent_map_.end()) { |
445 | inner_extent = halo_extent_map_[inner_id]; |
446 | } |
447 | |
448 | const auto outer_extent = getExtent(outer_id); |
449 | |
450 | if (inner_id->isBroadcast() && inner_extent->isOneInt()) { |
451 | // Propagate away from broadcast dims |
452 | index_map_[outer_id] = out_ind; |
453 | index_map_[inner_id] = zero; |
454 | |
455 | extent_map_[outer_id] = getExtent(out_id); |
456 | if (hasZeroMerged(out_id)) { |
457 | zero_merged_in_.insert(outer_id); |
458 | } |
459 | } else if (outer_id->isBroadcast() && outer_extent->isOneInt()) { |
460 | // Propagate away from broadcast dims |
461 | index_map_[outer_id] = zero; |
462 | index_map_[inner_id] = out_ind; |
463 | |
464 | extent_map_[inner_id] = getExtent(out_id); |
465 | if (hasZeroMerged(out_id)) { |
466 | zero_merged_in_.insert(inner_id); |
467 | } |
468 | } else if (hasZeroMerged(out_id)) { |
    // Don't propagate to inner id if it's composed of only broadcast
    // root domains, unless outer is also all broadcast domains. The
    // index shouldn't be anything but zero if both inner and outer are
    // all broadcast domains, but a hard check for this isn't added.
    // See FusionAdvancedIndexing5_CUDA
473 | if (!inner_id->isBroadcast() && !outer_id->isBroadcast()) { |
474 | // If neither dimension is a broadcast (should be true for reference |
475 | // indexing) pick the preferred path or the inner path. |
476 | if (preferred_paths_.find(outer_id) != preferred_paths_.end() && |
477 | preferred_paths_.find(inner_id) == preferred_paths_.end()) { |
478 | // Marked that we should prop through outer, not inner. |
479 | index_map_[outer_id] = out_ind; |
480 | extent_map_[outer_id] = getExtent(out_id); |
481 | index_map_[inner_id] = zero; |
482 | extent_map_[inner_id] = zero; |
483 | zero_domains_.emplace(inner_id); |
484 | } else { |
485 | // Prop through inner |
486 | index_map_[inner_id] = out_ind; |
487 | extent_map_[inner_id] = getExtent(out_id); |
488 | index_map_[outer_id] = zero; |
489 | extent_map_[outer_id] = zero; |
490 | zero_domains_.emplace(outer_id); |
491 | } |
492 | } else if (inner_id->isBroadcast() && !outer_id->isBroadcast()) { |
493 | // Inner is broadcast and outer isn't, prop through outer |
494 | index_map_[outer_id] = out_ind; |
495 | extent_map_[outer_id] = getExtent(out_id); |
496 | index_map_[inner_id] = zero; |
497 | extent_map_[inner_id] = zero; |
498 | zero_domains_.emplace(inner_id); |
499 | } else { |
500 | // Default to propagating through inner |
501 | index_map_[inner_id] = out_ind; |
502 | extent_map_[inner_id] = getExtent(out_id); |
503 | index_map_[outer_id] = zero; |
504 | extent_map_[outer_id] = zero; |
505 | zero_domains_.emplace(outer_id); |
506 | } |
507 | zero_merged_in_.emplace(inner_id); |
508 | zero_merged_in_.emplace(outer_id); |
509 | } else { |
510 | index_map_[outer_id] = SimplifyingIrBuilder::divExpr(out_ind, inner_extent); |
511 | index_map_[inner_id] = SimplifyingIrBuilder::modExpr(out_ind, inner_extent); |
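    // E.g. (illustrative): with an inner extent of 4, an output index
    // of 11 maps back to outer index 11 / 4 = 2 and inner index
    // 11 % 4 = 3, inverting the forward merge outer * 4 + inner.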
512 | } |
513 | } |
514 | |
515 | void IndexCompute::handle(Swizzle2D* swizzle_2d) { |
516 | auto out_x_id = maybeGetExactMapConcreteID(swizzle_2d->outX()); |
517 | auto out_y_id = maybeGetExactMapConcreteID(swizzle_2d->outY()); |
518 | auto in_x_id = maybeGetExactMapConcreteID(swizzle_2d->inX()); |
519 | auto in_y_id = maybeGetExactMapConcreteID(swizzle_2d->inY()); |
520 | |
521 | auto out_x_it = index_map_.find(out_x_id); |
522 | auto out_y_it = index_map_.find(out_y_id); |
523 | |
524 | if (out_x_it == index_map_.end() || out_y_it == index_map_.end()) { |
525 | return; |
526 | } |
527 | |
528 | const auto out_x_ind = out_x_it->second; |
529 | const auto out_y_ind = out_y_it->second; |
530 | |
531 | if (swizzle_mode_ == SwizzleMode::NoSwizzle || |
532 | swizzle_mode_ != swizzle_2d->swizzleMode()) { |
    // Handle inactive swizzles by just passing through the index
    // and extent information.
535 | |
    TORCH_INTERNAL_ASSERT(
        index_map_.count(in_x_id) == index_map_.count(in_y_id),
        "input index should be either both defined or both undefined");
539 | if (index_map_.count(in_x_id)) { |
540 | // Only propagate original index through if |
541 | // the input index hasn't been computed. |
542 | // TODO: |
543 | // This part should be cleaner once we remove the |
544 | // second index traversal pass. |
545 | return; |
546 | } |
547 | index_map_[in_x_id] = out_x_ind; |
548 | index_map_[in_y_id] = out_y_ind; |
549 | extent_map_[in_y_id] = getExtent(out_y_id); |
550 | extent_map_[in_x_id] = getExtent(out_x_id); |
551 | } else { |
552 | // Generate integer swizzle math if the |
553 | // swizzle is activated. See also |
554 | // [Note on swizzle mode]. |
555 | |
556 | auto out_pair = IrBuilder::swizzle2DIntExpr( |
557 | out_x_ind, |
558 | out_y_ind, |
559 | getExtent(out_x_id), |
560 | getExtent(out_y_id), |
561 | swizzle_2d->swizzleType()); |
562 | |
563 | index_map_[in_x_id] = |
564 | IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X); |
565 | index_map_[in_y_id] = |
566 | IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y); |
567 | } |
568 | } |
569 | |
570 | void IndexCompute::handle(Expr* e) { |
571 | switch (e->getExprType().value()) { |
572 | case (ExprType::Split): |
573 | case (ExprType::Merge): |
574 | case (ExprType::Swizzle2D): |
575 | break; |
576 | default: |
      TORCH_INTERNAL_ASSERT(
          false, "Invalid expr type found in transform traversal.");
579 | } |
580 | BackwardVisitor::handle(e); |
581 | } |
582 | |
583 | IndexCompute::IndexCompute( |
584 | const TensorDomain* _td, |
585 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
586 | std::unordered_map<IterDomain*, Val*> extent_map, |
587 | std::unordered_set<IterDomain*> zero_domains, |
588 | std::unordered_set<IterDomain*> zero_merged_in, |
589 | std::unordered_set<IterDomain*> preferred_paths, |
590 | std::unordered_map<IterDomain*, Val*> halo_extent_map) |
591 | : IndexCompute( |
592 | _td, |
593 | std::move(initial_index_map), |
594 | std::move(extent_map), |
595 | std::move(zero_domains), |
596 | std::move(zero_merged_in), |
597 | ContigIDs::getNonContigIDs(), |
598 | std::move(preferred_paths), |
599 | std::move(halo_extent_map)) {} |
600 | |
601 | IndexCompute::IndexCompute( |
602 | const TensorDomain* _td, |
603 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
604 | std::unordered_map<IterDomain*, Val*> extent_map, |
605 | std::unordered_set<IterDomain*> zero_domains, |
606 | std::unordered_set<IterDomain*> zero_merged_in, |
607 | const ContigIDs& contig_finder, |
608 | std::unordered_set<IterDomain*> preferred_paths, |
609 | std::unordered_map<IterDomain*, Val*> halo_extent_map) |
610 | : td_(_td), |
611 | index_map_(std::move(initial_index_map)), |
612 | extent_map_(std::move(extent_map)), |
613 | zero_domains_(std::move(zero_domains)), |
614 | zero_merged_in_(std::move(zero_merged_in)), |
615 | preferred_paths_(std::move(preferred_paths)), |
616 | halo_extent_map_(std::move(halo_extent_map)) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
618 | |
619 | // Make sure we recompute any indices we can that map to a contiguous access |
620 | // in physical memory. |
621 | contig_ids_ = contig_finder.contigIDs(); |
622 | root_to_indexed_id_ = contig_finder.rootToIndexedID(); |
623 | const auto& within_contig = contig_finder.withinContigIDs(); |
624 | for (auto contig_id : contig_ids_) { |
625 | if (index_map_.find(contig_id) != index_map_.end()) { |
626 | TORCH_INTERNAL_ASSERT( |
627 | within_contig.find(contig_id) != within_contig.end()); |
628 | for (auto id : within_contig.at(contig_id)) { |
629 | index_map_.erase(id); |
630 | } |
631 | } |
632 | } |
633 | } |
634 | |
635 | IndexCompute::IndexCompute( |
636 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
637 | std::unordered_set<IterDomain*> zero_domains, |
638 | std::unordered_set<IterDomain*> preferred_paths, |
639 | std::unordered_map<IterDomain*, Val*> halo_extent_map) |
640 | : index_map_(std::move(initial_index_map)), |
641 | zero_domains_(std::move(zero_domains)), |
642 | preferred_paths_(std::move(preferred_paths)), |
643 | halo_extent_map_(std::move(halo_extent_map)) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
645 | concrete_id_pass_ = true; |
646 | swizzle_mode_ = SwizzleMode::Loop; |
647 | } |
648 | |
649 | void IndexCompute::run(const LoopIndexing& loop_indexing) { |
  TORCH_INTERNAL_ASSERT(
      concrete_id_pass_, "concrete pass only for this option");
  // Apply loop swizzles if there are any that output to
  // the loop domains.
  // Currently we only support loop swizzles that directly output
  // to concrete loop domains, and these are validated in
  // the swizzle validation pass.
  // TODO:
  //  We will gradually enable replaying and mapping of loop
  //  swizzles in the IR infrastructure; once that's piped
  //  through, this part of the logic will be removed.
661 | std::unordered_set<Expr*> visited; |
662 | for (auto loop_id : loop_indexing.loopDomains()) { |
663 | auto loop_id_def = loop_id->definition(); |
664 | if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) { |
665 | if (visited.insert(loop_id_def).second) { |
666 | handle(loop_id_def); |
667 | } |
668 | } |
669 | } |
670 | |
671 | // Resolve the index vals that could be resolved with only |
672 | // the loops that consumer_tv doesn't share with any of its |
673 | // consumers, i.e. the not-inlined loops that define consumer_tv |
674 | // values. |
675 | collectIndexIntoPermissiveMap(loop_indexing); |
676 | |
677 | // Run through the loop indexing expressions and generate |
678 | // the indexing integer math for the concrete ids. |
679 | for (auto expr : loop_indexing.getBackwardExprList()) { |
680 | // Resolve missing values from permissive map. |
681 | updateIndexMapFromPermissiveMap(expr); |
682 | |
683 | handle(expr); |
684 | } |
685 | } |
686 | |
687 | void IndexCompute::collectIndexIntoPermissiveMap( |
688 | const LoopIndexing& loop_indexing) { |
  // Visit the expressions that only produce un-inlined iterdomains,
  // in reverse topological order.
691 | for (auto expr : loop_indexing.getBackwardOutOfLineExprList()) { |
692 | // Compute indexing vals for the expression inputs. |
693 | // |
    // This stage should run before any indexing computation so that
    // all index values computed at this stage are guaranteed to be
    // resolvable with only the not-inlined iterdomains.
698 | // |
699 | auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs()); |
700 | if (std::all_of( |
701 | id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) { |
702 | return index_map_.count( |
703 | GpuLower::current()->caMap()->getConcreteMappedID( |
704 | id, IdMappingMode::EXACT)); |
705 | })) { |
706 | // Visit this expression: |
707 | // LoopIndexingAnalysis::traverseFromDomainVals made sure that each |
708 | // concrete index is bound exactly once so computing these expressions |
709 | // early should still be consistent. |
710 | handle(expr); |
711 | |
712 | auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs()); |
713 | for (auto id : id_inputs) { |
        // Collect backward pass results from this expression if they
        // are made available by this expression.
716 | auto idx_it = |
717 | index_map_.find(GpuLower::current()->caMap()->getConcreteMappedID( |
718 | id, IdMappingMode::EXACT)); |
719 | |
720 | if (idx_it != index_map_.end()) { |
721 | permissive_index_map_ |
722 | [GpuLower::current()->caMap()->getConcreteMappedID( |
723 | id, IdMappingMode::PERMISSIVE)] = idx_it->second; |
724 | } |
725 | } |
726 | } |
727 | } |
728 | } |
729 | |
730 | void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) { |
731 | auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs()); |
732 | for (auto id : id_outputs) { |
733 | auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( |
734 | id, IdMappingMode::EXACT); |
735 | // Only try to copy index val from permissive map when |
736 | // the index is missing. |
737 | if (!index_map_.count(concrete_id)) { |
738 | auto permissive_id = GpuLower::current()->caMap()->getConcreteMappedID( |
739 | id, IdMappingMode::PERMISSIVE); |
740 | // Write the permissive index val into index_map_ if the |
741 | // missing value is found here. |
742 | auto permissive_it = permissive_index_map_.find(permissive_id); |
743 | if (permissive_it != permissive_index_map_.end()) { |
744 | index_map_[concrete_id] = permissive_it->second; |
745 | } |
746 | } |
747 | } |
748 | } |
749 | |
750 | void IndexCompute::run() { |
751 | const std::vector<Val*> domain_vals( |
752 | td_->domain().begin(), td_->domain().end()); |
753 | |
754 | traverseTo(td_->fusion(), domain_vals, false); |
755 | } |
756 | |
757 | IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) { |
758 | if (concrete_id_pass_) { |
759 | return GpuLower::current()->caMap()->getConcreteMappedID( |
760 | id, IdMappingMode::EXACT); |
761 | } |
762 | return id; |
763 | } |
764 | |
765 | Val* IndexCompute::getExtent(IterDomain* id) const { |
  // Pick from extent_map_ if available. Previously, parallel
  // dimensions were used (e.g., blockDim.x); however, that would
  // result in out-of-bounds errors when the extent of an IterDomain
  // is smaller than the threading dimension.
770 | if (extent_map_.find(id) != extent_map_.end()) { |
771 | return extent_map_.at(id); |
772 | } else { |
773 | return id->extent(); |
774 | } |
775 | } |
776 | |
777 | bool IndexCompute::hasZeroMerged(IterDomain* id) const { |
778 | return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id); |
779 | } |
780 | |
781 | bool IndexCompute::isZero(IterDomain* id) const { |
782 | return zero_domains_.find(id) != zero_domains_.end(); |
783 | } |
784 | |
785 | IndexCompute IndexCompute::updateIndexCompute( |
786 | const TensorDomain* new_td, |
787 | const std::unordered_map<IterDomain*, IterDomain*>& id_map, |
788 | const ContigIDs& contig_finder) const { |
  FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");
790 | |
791 | std::unordered_map<IterDomain*, Val*> updated_index_map; |
792 | std::unordered_map<IterDomain*, Val*> updated_extent_map; |
793 | std::unordered_set<IterDomain*> updated_zero_domains; |
794 | std::unordered_set<IterDomain*> updated_zero_merged_in; |
795 | std::unordered_map<IterDomain*, Val*> updated_halo_extent_map; |
796 | |
797 | for (auto id_entry : id_map) { |
798 | IterDomain* prev_id = id_entry.first; |
799 | IterDomain* new_id = id_entry.second; |
800 | |
801 | if (index_map_.find(prev_id) != index_map_.end()) { |
802 | updated_index_map[new_id] = index_map_.at(prev_id); |
803 | } |
804 | |
805 | updated_extent_map[new_id] = getExtent(prev_id); |
806 | |
807 | if (zero_domains_.find(prev_id) != zero_domains_.end()) { |
808 | updated_zero_domains.emplace(new_id); |
809 | } |
810 | |
811 | if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) { |
812 | updated_zero_merged_in.emplace(new_id); |
813 | } |
814 | |
815 | auto halo_extent_it = halo_extent_map_.find(prev_id); |
816 | if (halo_extent_it != halo_extent_map_.end()) { |
817 | updated_halo_extent_map[new_id] = halo_extent_it->second; |
818 | } |
819 | } |
820 | |
821 | IndexCompute updated_index_compute( |
822 | new_td, |
823 | updated_index_map, |
824 | updated_extent_map, |
825 | updated_zero_domains, |
826 | updated_zero_merged_in, |
827 | contig_finder, |
828 | {}, |
829 | updated_halo_extent_map); |
830 | |
831 | updated_index_compute.run(); |
832 | |
833 | return updated_index_compute; |
834 | } |
835 | |
836 | namespace { |
837 | // Map indices down to the leaf domains for applying swizzle |
838 | class UpdateLeafIndices : public IterVisitor { |
839 | public: |
840 | UpdateLeafIndices( |
841 | const TensorDomain* td, |
842 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
843 | std::unordered_map<IterDomain*, Val*> extent_map) |
844 | : td_(td), |
845 | index_map_(std::move(initial_index_map)), |
846 | extent_map_(std::move(extent_map)) { |
847 | const std::vector<Val*> domain_vals( |
848 | td_->domain().begin(), td_->domain().end()); |
849 | |
850 | traverseTo(td_->fusion(), domain_vals, false); |
851 | } |
852 | |
853 | const std::unordered_map<IterDomain*, Val*>& indexMap() const { |
854 | return index_map_; |
855 | } |
856 | |
857 | const std::unordered_map<IterDomain*, Val*>& extentMap() const { |
858 | return extent_map_; |
859 | } |
860 | |
861 | private: |
862 | using IterVisitor::handle; |
863 | |
864 | void handle(Split* split) override { |
865 | auto in_id = split->in(); |
866 | auto outer_id = split->outer(); |
867 | auto inner_id = split->inner(); |
868 | |
    // Nothing needs to be done when mappings for the output axes
    // already exist.
871 | if (index_map_.find(outer_id) != index_map_.end()) { |
      TORCH_INTERNAL_ASSERT(
          index_map_.find(inner_id) != index_map_.end(),
          "Outer exists but inner not found");
875 | return; |
876 | } |
877 | |
878 | if (!index_map_.count(in_id)) { |
      // Reduction axes on the producer side could be visited in the
      // forward propagation pass, and the current implementation does
      // not yet support reduction on swizzled iterdomains, so
      // un-indexed reduction iterdomains are just ignored for now.
      TORCH_INTERNAL_ASSERT(
          in_id->isReduction(), "Undefined index for ", in_id->toString());
885 | return; |
886 | } |
887 | |
888 | auto factor = split->factor(); |
889 | index_map_[inner_id] = |
890 | SimplifyingIrBuilder::modExpr(index_map_[in_id], factor); |
891 | extent_map_[inner_id] = factor; |
892 | index_map_[outer_id] = |
893 | SimplifyingIrBuilder::divExpr(index_map_[in_id], factor); |
894 | extent_map_[outer_id] = |
895 | SimplifyingIrBuilder::ceilDivExpr(getExtent(in_id), factor); |
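    // E.g. (illustrative): splitting an input of extent 12 by factor
    // 4 gives inner = in_idx % 4 and outer = in_idx / 4, with an
    // outer extent of ceilDiv(12, 4) = 3.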
896 | } |
897 | |
898 | void handle(Merge* merge) override { |
899 | auto out_id = merge->out(); |
900 | auto outer_id = merge->outer(); |
901 | auto inner_id = merge->inner(); |
902 | |
903 | if (!index_map_.count(outer_id) || !index_map_.count(inner_id)) { |
      // Reduction axes on the producer side could be visited in the
      // forward propagation pass, and the current implementation does
      // not yet support reduction on swizzled iterdomains, so
      // un-indexed reduction iterdomains are just ignored for now.
      TORCH_INTERNAL_ASSERT(
          outer_id->isReduction() && inner_id->isReduction(),
          "Undefined index for ",
          outer_id->toString(),
          " and ",
          inner_id->toString());
914 | return; |
915 | } |
916 | |
    // Nothing needs to be done when mappings for the output axes
    // already exist.
919 | if (index_map_.find(out_id) != index_map_.end()) { |
920 | return; |
921 | } |
922 | |
    TORCH_INTERNAL_ASSERT(
        index_map_.find(outer_id) != index_map_.end(), "Outer ID not found");
    TORCH_INTERNAL_ASSERT(
        index_map_.find(inner_id) != index_map_.end(), "Inner ID not found");
927 | |
928 | index_map_[out_id] = SimplifyingIrBuilder::addExpr( |
929 | index_map_[inner_id], |
930 | SimplifyingIrBuilder::mulExpr( |
931 | index_map_[outer_id], getExtent(inner_id))); |
932 | |
933 | extent_map_[out_id] = |
934 | SimplifyingIrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); |
935 | } |
936 | |
937 | void handle(Swizzle2D* swizzle_2d) override { |
938 | auto in_x = swizzle_2d->inX(); |
939 | auto in_y = swizzle_2d->inY(); |
940 | auto out_x = swizzle_2d->outX(); |
941 | auto out_y = swizzle_2d->outY(); |
942 | |
    // The forward propagation pass still just forwards the
    // indices through; the actual swizzle will be applied on
    // the backward pass in the IndexSwizzle class implementation.
947 | index_map_[out_x] = index_map_.at(in_x); |
948 | extent_map_[out_x] = getExtent(in_x); |
949 | index_map_[out_y] = index_map_.at(in_y); |
950 | extent_map_[out_y] = getExtent(in_y); |
951 | } |
952 | |
  // Returns extent_map_[id] if it exists; otherwise returns id->extent().
954 | Val* getExtent(IterDomain* id) { |
955 | if (extent_map_.find(id) != extent_map_.end()) { |
956 | return extent_map_.at(id); |
957 | } else { |
958 | return id->extent(); |
959 | } |
960 | } |
961 | |
962 | private: |
963 | const TensorDomain* td_; |
964 | std::unordered_map<IterDomain*, Val*> index_map_; |
965 | std::unordered_map<IterDomain*, Val*> extent_map_; |
966 | }; |
967 | |
968 | // Returns halo-extended extent if id has halo. Otherwise, just |
969 | // returns id->extent. |
970 | Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) { |
971 | if (normal_extent == nullptr) { |
972 | normal_extent = id->extent(); |
973 | } |
974 | |
975 | const auto& halo = GpuLower::current()->haloInfo()->getRootAxisInfo(id); |
976 | if (halo.hasHalo()) { |
977 | auto halo_extent = SimplifyingIrBuilder::addExpr( |
978 | normal_extent, SimplifyingIrBuilder::create<Int>(halo.width())); |
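    // E.g. (illustrative): a root axis of extent N with a total halo
    // width of 2 (1 on each side) gets an extended extent of N + 2.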
979 | return halo_extent; |
980 | } else { |
981 | return normal_extent; |
982 | } |
983 | } |
984 | |
985 | } // namespace |
986 | |
987 | IndexSwizzle::IndexSwizzle( |
988 | const TensorView* tv, |
989 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
990 | std::unordered_map<IterDomain*, Val*> extent_map, |
991 | std::unordered_set<IterDomain*> zero_domains, |
992 | std::unordered_set<IterDomain*> zero_merged_in) |
993 | : IndexCompute( |
994 | tv->domain(), |
995 | std::move(initial_index_map), |
996 | std::move(extent_map), |
997 | std::move(zero_domains), |
998 | std::move(zero_merged_in)), |
999 | tv_(tv), |
1000 | swizzle_type_(tv->swizzleType()), |
1001 | ids_to_swizzle_(tv->axesToSwizzle()) {} |
1002 | |
1003 | IndexSwizzle::IndexSwizzle( |
1004 | const TensorView* tv, |
1005 | const TensorDomain* domain, |
1006 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
1007 | std::unordered_map<IterDomain*, Val*> extent_map, |
1008 | std::unordered_set<IterDomain*> zero_domains, |
1009 | std::unordered_set<IterDomain*> zero_merged_in) |
1010 | : IndexCompute( |
1011 | domain, |
1012 | std::move(initial_index_map), |
1013 | std::move(extent_map), |
1014 | std::move(zero_domains), |
1015 | std::move(zero_merged_in)), |
1016 | tv_(tv), |
1017 | swizzle_type_(tv->swizzleType()), |
1018 | ids_to_swizzle_(tv->axesToSwizzle()) {} |
1019 | |
1020 | void IndexSwizzle::run() { |
  TORCH_INTERNAL_ASSERT(
      swizzle_type_ == SwizzleType::NoSwizzle ||
          swizzle_type_ == SwizzleType::Transpose,
      "Invalid swizzle type");
1025 | if (swizzle_type_ == SwizzleType::Transpose) { |
1026 | // Shifts the second axis by the first axis as ((idx_1 + idx_2) % |
1027 | // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also |
1028 | // work if ext is a power of two. Practically, ext should be 32 if |
1029 | // the data type of the tensor is float, so the latter approach |
1030 | // should also be fine. |
1031 | TORCH_INTERNAL_ASSERT(tv_->getMemoryType() == MemoryType::Shared); |
1032 | TORCH_INTERNAL_ASSERT(tv_->axesToSwizzle().size() == 2); |
1033 | |
1034 | UpdateLeafIndices update_leaves(td_, indexMap(), extentMap()); |
1035 | index_map_ = update_leaves.indexMap(); |
1036 | extent_map_ = update_leaves.extentMap(); |
1037 | |
1038 | IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0); |
1039 | IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1); |
1040 | |
1041 | if (indexMap().find(id_to_swizzle_i) != indexMap().end() && |
1042 | indexMap().find(id_to_swizzle_j) != indexMap().end()) { |
1043 | auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i); |
1044 | auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j); |
1045 | |
1046 | auto swizzled_idx = SimplifyingIrBuilder::modExpr( |
1047 | SimplifyingIrBuilder::addExpr(idx_to_swizzle_i, idx_to_swizzle_j), |
1048 | id_to_swizzle_j->extent()); |
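      // E.g. (illustrative): with an extent of 32, row i maps column
      // j to (i + j) % 32, so each row uses a different rotation and
      // column-wise accesses spread across shared-memory banks.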
1049 | index_map_[id_to_swizzle_j] = swizzled_idx; |
1050 | swizzled_ids_.insert(id_to_swizzle_j); |
1051 | IndexCompute::run(); |
1052 | } |
1053 | } else if (tv_->hasSwizzleOp()) { |
1054 | // Propagate backward for the annotated swizzle path. |
1055 | // TODO: |
1056 | // eventually will unify the two swizzling implementation |
1057 | // code path in a follow up. Currently just focusing on |
1058 | // getting the necessary implementation of the swizzle |
1059 | // operator ready. |
1060 | // |
1061 | // At this intermediate state, the legacy swizzle implementation |
1062 | // takes precedence, i.e. whenever swizzle_type_ is not NoSwizzle, |
1063 | // the new swizzle op pass is disabled. |
1064 | UpdateLeafIndices update_leaves(td_, indexMap(), extentMap()); |
1065 | index_map_ = update_leaves.indexMap(); |
1066 | extent_map_ = update_leaves.extentMap(); |
1067 | IndexCompute::swizzle_mode_ = SwizzleMode::Data; |
1068 | IndexCompute::run(); |
1069 | } |
1070 | } |
1071 | |
1072 | void IndexSwizzle::handle(Expr* e) { |
1073 | auto out_ids = ir_utils::filterByType<IterDomain>(e->outputs()); |
1074 | bool needs_update = |
1075 | std::any_of( |
1076 | out_ids.begin(), |
1077 | out_ids.end(), |
1078 | [this](IterDomain* id) { |
1079 | return swizzled_ids_.find(id) != swizzled_ids_.end(); |
1080 | }) || |
1081 | (e->isA<Swizzle2D>() && |
1082 | e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle && |
1083 | e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data); |
1084 | if (!needs_update) { |
1085 | return; |
1086 | } |
1087 | |
1088 | IndexCompute::handle(e); |
1089 | for (auto input : ir_utils::filterByType<IterDomain>(e->inputs())) { |
1090 | swizzled_ids_.insert(input); |
1091 | } |
1092 | } |
1093 | |
1094 | void IndexSwizzle::handle(Swizzle2D* swizzle_2d) { |
1095 | auto out_x_id = swizzle_2d->outX(); |
1096 | auto out_y_id = swizzle_2d->outY(); |
1097 | |
1098 | auto out_x_it = index_map_.find(out_x_id); |
1099 | auto out_y_it = index_map_.find(out_y_id); |
1100 | |
1101 | // TODO: unify the legacy path in all usage |
  TORCH_INTERNAL_ASSERT(
      swizzle_type_ == SwizzleType::NoSwizzle,
      "Cannot mix usage of two swizzle implementations");
1105 | |
  TORCH_INTERNAL_ASSERT(
      out_x_it != index_map_.end() && out_y_it != index_map_.end(),
      "Swizzle output indices were not propagated through");
1109 | |
1110 | IndexCompute::handle(swizzle_2d); |
1111 | } |
1112 | |
1113 | // Used for local and shared index mapping. Returns a map from loops |
1114 | // to loop indices as well as a set of loops that do not contribute to |
1115 | // indexing. |
1116 | std::pair< |
1117 | std::unordered_map<kir::ForLoop*, Val*>, |
1118 | std::unordered_set<kir::ForLoop*>> |
1119 | indexMapFromTV( |
1120 | const TensorView* tv, |
1121 | const std::vector<kir::ForLoop*>& loops, |
1122 | kir::ForLoop* alloc_loop, |
1123 | bool as_consumer, |
1124 | kir::ForLoop* double_buffer_loop) { |
1125 | bool within_alloc = false; |
1126 | if (alloc_loop == nullptr) { |
1127 | within_alloc = true; |
1128 | } |
1129 | |
1130 | const bool is_global = tv->getMemoryType() == MemoryType::Global; |
1131 | const bool is_shared = tv->getMemoryType() == MemoryType::Shared; |
1132 | const bool is_local = tv->getMemoryType() == MemoryType::Local; |
1133 | |
1134 | std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map; |
1135 | |
1136 | // Check if the current op has an implicit loop implemented |
1137 | // within an mma instruction. |
1138 | bool within_mma_loops = |
1139 | std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { |
1140 | return fl->iter_domain()->isMma(); |
1141 | }); |
1142 | |
  // When indexed as a producer, the parallel types of the
  // producer domains may not be the same as those of the loops, but
  // that's still valid parallelization. However, in that case, using
  // the parallel types of the loops to decide replacement of indices
  // with zero isn't valid. That's only valid when there's a matching
  // IterDomain in the producer tensor that has the same parallel
  // type.
1150 | auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool { |
1151 | const auto gpu_lower = GpuLower::current(); |
1152 | auto it = std::find_if( |
1153 | tv->domain()->domain().begin(), |
1154 | tv->domain()->domain().end(), |
1155 | [&](IterDomain* tv_id) { |
1156 | // Matching is done using the index and loop maps. See |
1157 | // validateParallelize as well. |
1158 | return gpu_lower->caMap()->areMapped( |
1159 | id, tv_id, IdMappingMode::EXACT) || |
1160 | (GpuLower::current()->caMap()->areMapped( |
1161 | id, tv_id, IdMappingMode::PERMISSIVE) && |
1162 | ir_utils::derivedFromRootCAAxes(tv, tv_id)); |
1163 | }); |
1164 | if (it == tv->domain()->domain().end()) { |
1165 | return false; |
1166 | } |
1167 | |
1168 | auto corresponding_domain = *it; |
1169 | return corresponding_domain->getParallelType() == id->getParallelType(); |
1170 | }; |
1171 | |
  // Track domains that do not contribute to the resulting
  // index. Previously, index->isZeroInt() was used to detect such
  // domains, but that's not a reliable method as we may set an
  // initial index to zero for unswitch.
1176 | std::unordered_set<kir::ForLoop*> zero_loops; |
1177 | |
1178 | for (auto loop : loops) { |
1179 | Val* idx = nullptr; |
1180 | const auto same_parallel_type = as_consumer || |
1181 | find_matching_parallel_domain(loop->iter_domain()) || |
        // Note & TODO:
        // mma-swizzled lane_id does not map naturally from producer
        // to consumer, but they should still be detected as the same
        // parallel type. In a follow-up we may want to extend
        // find_matching_parallel_domain to cover this case.
1187 | (within_mma_loops && |
1188 | loop->iter_domain()->getParallelType() == ParallelType::TIDx); |
1189 | // See also LoopNestGenerator::pushAlloc. |
1190 | // NOLINTNEXTLINE(bugprone-branch-clone) |
1191 | if (!within_alloc) { |
1192 | if ((loop->iter_domain()->isThreadDim() && is_shared) || |
1193 | (loop->iter_domain()->isThread() && is_global)) { |
1194 | idx = loop->index(); |
1195 | } else { |
1196 | idx = GpuLower::current()->kernel()->zeroVal(); |
1197 | zero_loops.insert(loop); |
1198 | } |
1199 | } else if ( |
1200 | // For shared-memory tensors, when a domain is parallelized by |
1201 | // BID, the index can be replaced with zero as long as the |
1202 | // tensor has a matching domain that has the same parallel |
1203 | // type. Matching can be omitted when indexed as a consumer |
1204 | // since it is always the case. When indexed as a producer, to |
1205 | // replace it with zero, the same parallel type of BID must be |
1206 | // used by the producer tensor. Thus, since this is a shared |
1207 | // memory tensor, when a producer domain is parallelized by |
1208 | // BID, there must be a matching consumer domain with the same |
1209 | // parallel type, which must be the IterDomain of the |
1210 | // loop. |
1211 | (loop->iter_domain()->isBlockDim() && is_shared && |
1212 | same_parallel_type) || |
1213 | // Similarly for local memory tensors, zero replacement can be |
1214 | // only done when there's a matching domain with the same |
1215 | // parallel type |
1216 | (loop->iter_domain()->isThread() && is_local && same_parallel_type) || |
1217 | // MMA operands are currently indexed in units of "fragments", |
1218 | // so each mma tensor domain would be zero-ed and the tensor index |
1219 | // calculated here would be the fragment index. |
1220 | // TODO: This is a quick WAR to enable iterating over a register array |
1221 | // of MMA fragments, so we could generate unrolled mma loops. |
1222 | // Eventually we still want IdGraph to be able to analyze the |
1223 | // in-register layout of mma fragments for more unified indexing math |
1224 | // as well as more flexibility in swizzling loops. |
1225 | (loop->iter_domain()->isMma() && !as_consumer)) { |
1226 | idx = GpuLower::current()->kernel()->zeroVal(); |
1227 | zero_loops.insert(loop); |
1228 | } else { |
1229 | idx = loop->index(); |
1230 | } |
1231 | |
1232 | // If the loop is trivial, the loop index can only be the loop |
1233 | // start value. |
1234 | if (idx == loop->index() && loop->isTrivial()) { |
1235 | idx = loop->start(); |
1236 | } |
1237 | |
1238 | if (loop == double_buffer_loop) { |
1239 | auto stage_depth = |
1240 | GpuLower::current()->doubleBufferInfo().getStageDepthFor( |
1241 | loop->iter_domain()); |
1242 | idx = SimplifyingIrBuilder::addExpr( |
1243 | idx, SimplifyingIrBuilder::create<Int>(stage_depth - 1)); |
1244 | } |
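    // E.g. (illustrative): with a stage depth of 2 (classic double
    // buffering), the index is offset by 2 - 1 = 1 so the producer
    // writes one stage ahead of the consumer read.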
1245 | |
1246 | loop_to_ind_map[loop] = idx; |
1247 | |
1248 | if (!within_alloc && loop == alloc_loop) { |
1249 | within_alloc = true; |
1250 | } |
1251 | } |
1252 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
1253 | return {loop_to_ind_map, zero_loops}; |
1254 | } |
1255 | |
1256 | //! Set "pragma unroll" required for loops that indexing of Local |
1257 | //! tensors depends on. |
1258 | //! |
1259 | //! \param tv Indexed tensor |
1260 | //! \param alloc_loop Allocation loop of tv |
1261 | //! \param loops The current loop structure |
1262 | //! \param id_map Producer-to-consumer map in case of indexing as producer |
1263 | void ensureStaticIndexing( |
1264 | const TensorView* tv, |
1265 | kir::ForLoop* alloc_loop, |
1266 | const std::vector<kir::ForLoop*>& loops, |
1267 | const std::unordered_map<IterDomain*, IterDomain*>& id_map) { |
1268 | if (tv->getMemoryType() != MemoryType::Local) { |
1269 | return; |
1270 | } |
1271 | |
1272 | bool within_alloc = false; |
1273 | if (alloc_loop == nullptr) { |
1274 | within_alloc = true; |
1275 | } |
1276 | |
1277 | for (auto loop : loops) { |
1278 | if (!within_alloc) { |
1279 | if (loop == alloc_loop) { |
1280 | within_alloc = true; |
1281 | } |
1282 | continue; |
1283 | } |
1284 | IterDomain* loop_id = loop->iter_domain(); |
1285 | if (loop->vectorize() || loop_id->isThread()) { |
1286 | continue; |
1287 | } |
1288 | // Look for a domain that is mapped with the loop. If mapped in |
1289 | // the loop map, the loop index should be used for indexing of the |
1290 | // tensor, except for broadcast and reduction domains. |
1291 | auto it = std::find_if( |
1292 | tv->domain()->domain().begin(), |
1293 | tv->domain()->domain().end(), |
1294 | [loop_id, &id_map](IterDomain* id) { |
1295 | if (id->isBroadcast() || id->isReduction() || id->isStride()) { |
1296 | return false; |
1297 | } |
1298 | auto id_replacement = id_map.find(id); |
1299 | if (id_replacement != id_map.end()) { |
1300 | id = id_replacement->second; |
1301 | } |
1302 | return GpuLower::current()->caMap()->areMapped( |
1303 | loop_id, id, IdMappingMode::PERMISSIVE); |
1304 | }); |
1305 | if (it != tv->domain()->domain().end()) { |
1306 | loop->requireUnroll(); |
1307 | } |
1308 | } |
1309 | } |
1310 | |
1311 | namespace { |
1312 | |
1313 | //! Returns an iterdomain that corresponds to the |
1314 | //! indexing sub-expression to hoist or a nullopt |
1315 | //! if the index should not be hoisted. |
1316 | c10::optional<IterDomain*> getMaybeIndexedIdToHoist( |
1317 | IterDomain* root_id, |
1318 | const TensorView* tv, |
1319 | const IndexCompute& indexing, |
1320 | Val* index) { |
1321 | if (isOptionDisabled(DisableOption::IndexHoist) || |
1322 | index->definition() == nullptr) { |
1323 | return c10::nullopt; |
1324 | } |
1325 | |
1326 | // The old swizzle interface, which should be deprecated, is not |
1327 | // supported. |
1328 | if (tv->swizzleType() != SwizzleType::NoSwizzle) { |
1329 | return c10::nullopt; |
1330 | } |
1331 | |
1332 | // New swizzle interface not yet supported |
1333 | if (tv->hasSwizzleOp()) { |
1334 | return c10::nullopt; |
1335 | } |
1336 | |
1337 | // Find the true indexed domain, which can be a merged contiguous domain. |
1338 | auto contig_id_it = indexing.rootToContigID().find(root_id); |
  TORCH_INTERNAL_ASSERT(
      contig_id_it != indexing.rootToContigID().end(),
      "Consumer indexed ID not found: ",
      root_id->toString());
1343 | auto indexed_id = contig_id_it->second; |
1344 | // Make sure this contig ID is indeed indexed |
  TORCH_INTERNAL_ASSERT(
      indexing.indexMap().find(contig_id_it->second) !=
          indexing.indexMap().end(),
      "Invalid contig index: ",
      contig_id_it->second->toString());
1350 | |
1351 | return indexed_id; |
1352 | } |
1353 | |
// Version of hoisting without using the reference tensor. This
// should eventually replace the other version once the reference
// tensor is completely deprecated.
1357 | Val* hoistConsumerIndex( |
1358 | IterDomain* consumer_root_id, |
1359 | const TensorView* consumer_tv, |
1360 | const IndexCompute& consumer_indexing, |
1361 | std::vector<IterDomain*> loop_domains, |
1362 | const std::unordered_map<IterDomain*, Val*> initial_loop_index_map, |
1363 | const std::vector<kir::ForLoop*>& loops, |
1364 | Val* index) { |
1365 | auto maybe_hoisted_consumer_id = getMaybeIndexedIdToHoist( |
1366 | consumer_root_id, consumer_tv, consumer_indexing, index); |
1367 | |
1368 | if (!maybe_hoisted_consumer_id.has_value()) { |
1369 | return index; |
1370 | } |
1371 | |
1372 | // Insert the index into the common index map. A previously inserted |
1373 | // val can be returned. |
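  // E.g. (illustrative): if several tensors indexed under the same
  // loop nest share a sub-expression such as i1 * 32 + i2, the first
  // insertion registers it and later insertions return the previously
  // created val, so the sub-expression is computed only once.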
1374 | auto common_index = GpuLower::current() |
1375 | ->commonIndexMap() |
1376 | .insert( |
1377 | maybe_hoisted_consumer_id.value(), |
1378 | consumer_tv->domain(), |
1379 | loop_domains, |
1380 | initial_loop_index_map, |
1381 | loops, |
1382 | index) |
1383 | .first; |
1384 | |
1385 | return common_index; |
1386 | } |
1387 | |
1388 | std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap( |
1389 | const std::unordered_map<IterDomain*, IterDomain*>& map) { |
1390 | std::unordered_map<IterDomain*, IterDomain*> inverted; |
1391 | for (const auto& kv : map) { |
1392 | bool inserted = inverted.emplace(kv.second, kv.first).second; |
    TORCH_INTERNAL_ASSERT(
        inserted,
        "Multiple mappings to the same value detected: ",
        kv.second->toString());
1397 | } |
1398 | return inverted; |
1399 | } |
1400 | |
1401 | Val* hoistProducerIndex( |
1402 | IterDomain* producer_root_id, |
1403 | const TensorView* producer_tv, |
1404 | const IndexCompute& producer_indexing, |
1405 | const TensorView* consumer_tv, |
1406 | const std::unordered_map<IterDomain*, IterDomain*>& p2c_map, |
1407 | std::vector<IterDomain*> loop_domains, |
1408 | const std::unordered_map<IterDomain*, Val*> initial_loop_index_map, |
1409 | const std::vector<kir::ForLoop*>& loops, |
1410 | Val* index) { |
1411 | auto maybe_indexed_producer_id = getMaybeIndexedIdToHoist( |
1412 | producer_root_id, producer_tv, producer_indexing, index); |
1413 | |
1414 | if (!maybe_indexed_producer_id.has_value()) { |
1415 | return index; |
1416 | } |
1417 | |
1418 | // Use the corresponding consumer domain to find matching |
1419 | // for-loops. Note that there's no CA mapping with the producer |
1420 | // domains as the producer TensorDomain is a temporary replay |
1421 | // domain. |
1422 | auto indexed_consumer_id_it = p2c_map.find(maybe_indexed_producer_id.value()); |
1423 | |
1424 | // There can be no corresponding consumer ID. For example, consider: |
1425 | // consumer: [b1, i2, i3] |
1426 | // producer: [i2, i3]. |
1427 | // Suppose the consumer is transformed as: |
1428 | // consumer: [(b1*i2)*i3] |
1429 | // Then the producer would be transformed when indexed: |
1430 | // producer: [i2*i3] |
  // Assuming i2 and i3 are contiguous, the producer indexing is done
  // with the merged i2*i3 domain, but there's no domain in the
  // consumer that maps with the producer indexed domain.
1434 | // It seems non-trivial to support patterns like this. Skip for now. |
1435 | if (indexed_consumer_id_it == p2c_map.end()) { |
1436 | return index; |
1437 | } |
1438 | |
1439 | IterDomain* indexed_consumer_id = indexed_consumer_id_it->second; |
1440 | |
1441 | auto common_index = GpuLower::current() |
1442 | ->commonIndexMap() |
1443 | .insert( |
1444 | indexed_consumer_id, |
1445 | consumer_tv->domain(), |
1446 | loop_domains, |
1447 | initial_loop_index_map, |
1448 | loops, |
1449 | index) |
1450 | .first; |
1451 | |
1452 | return common_index; |
1453 | } |
1454 | |
1455 | } // namespace |
1456 | |
1457 | std::vector<Val*> Index::getGlobalProducerStridedIndices( |
1458 | TensorView* producer_tv, |
1459 | const TensorView* consumer_tv, |
1460 | const std::vector<kir::ForLoop*>& loops) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
1462 | const auto gpu_lower = GpuLower::current(); |
1463 | |
1464 | // Replay producer to look like consumer so we can index on producer since |
1465 | // our loop nests look like consumer |
1466 | auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); |
1467 | auto producerAsC = |
1468 | TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map) |
1469 | .first; |
1470 | |
1471 | // Make the producer_tv look like consumer while performing indexing math |
1472 | ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC); |
1473 | |
1474 | // Map sent to best effort replay needs to match the exact incantation for |
1475 | // compute_at_mode.cpp with MappingMode::Index |
1476 | auto c2p_root_map = |
1477 | PairwiseRootDomainMap(producer_tv, consumer_tv, true) |
1478 | .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain()); |
1479 | |
1480 | // This replay has to be consistent with compute at index map. |
1481 | BestEffortReplay replay_producer_as_consumer( |
1482 | producer_tv->domain()->domain(), |
1483 | consumer_tv->domain()->domain(), |
1484 | c2p_root_map); |
1485 | |
1486 | const auto& c2p_map = replay_producer_as_consumer.getReplay(); |
1487 | const auto p2c_map = invertOneToOneMap(c2p_map); |
1488 | |
1489 | // Forward vectorized IDs to index into producer correctly |
1490 | // We want p_id to be vectorized like consumer just for the indexing, then we |
1491 | // need to switch it back later. Store previous state here when changing. We |
1492 | // need to do this as replaying producer as consumer can use replay best |
1493 | // effort which means some domains may be producer's original domains. |
1494 | std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup; |
1495 | for (auto entry : c2p_map) { |
1496 | auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( |
1497 | entry.first, IdMappingMode::EXACT); |
1498 | auto p_id = entry.second; |
1499 | if (ref_id->getParallelType() == ParallelType::Vectorize) { |
1500 | p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); |
1501 | p_id->parallelize(ParallelType::Vectorize); |
1502 | } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) { |
1503 | p_id->parallelize(ParallelType::MisalignedVectorize); |
1504 | } |
1505 | } |
1506 | |
1507 | auto producer_indexing_from_idgraph = |
1508 | getTensorIndexFromIdGraph(loops, consumer_tv, producer_tv, true, c2p_map); |
1509 | |
1510 | auto producer_indexing = producer_indexing_from_idgraph.index; |
1511 | |
1512 | // Revert p_ids |
1513 | for (auto entry : p_id_backup) { |
1514 | entry.first->parallelize(entry.second); |
1515 | } |
1516 | |
1517 | // Indices should now be mapped onto IterDomains in producer, so just grab |
1518 | // and use them. |
1519 | auto root_dom = producer_tv->getMaybeRFactorDomain(); |
1520 | |
1521 | // TODO: Abstract stride logic to reuse with consumer indexing |
1522 | std::vector<Val*> strides(root_dom.size(), nullptr); |
1523 | { |
1524 | int stride_i = 0; |
1525 | for (const auto i : c10::irange(root_dom.size())) { |
1526 | if (root_dom[i]->isReduction()) { |
1527 | strides[i] = GpuLower::current()->kernel()->oneVal(); |
1528 | continue; |
1529 | } |
1530 | std::stringstream ss; |
1531 | ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]" ; |
1532 | strides[i] = |
1533 | SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int); |
1534 | } |
1535 | } |
1536 | |
1537 | TORCH_INTERNAL_ASSERT( |
1538 | root_dom.size() == producer_tv->domain()->contiguity().size()); |
1539 | Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); |
1540 | for (const auto i : c10::irange(root_dom.size())) { |
1541 | auto dim = root_dom.size() - i - 1; |
1542 | if (root_dom[dim]->isReduction()) { |
1543 | continue; |
1544 | } |
1545 | |
1546 | Val* root_ind = nullptr; |
1547 | if (producer_indexing.indexMap().find(root_dom[dim]) != |
1548 | producer_indexing.indexMap().end()) { |
1549 | root_ind = producer_indexing.indexMap().at(root_dom[dim]); |
1550 | } else if (root_dom[dim]->isBroadcast()) { |
1551 | root_ind = GpuLower::current()->kernel()->zeroVal(); |
1552 | } |
1553 | |
    TORCH_INTERNAL_ASSERT(
        root_ind != nullptr,
        "Couldn't find root mapping for ",
        producer_tv->toString(),
        " dim: ",
        dim,
        " id: ",
        root_dom[dim]->toString());
1562 | |
1563 | if (producer_tv->domain()->contiguity()[dim]) { |
      // If contiguous, use the stored stride, which may be the previous
      // dimension's stride * the previous dimension's size
1566 | strides[dim] = cur_contig_stride; |
1567 | // Prepare for the next dimension which may also be contiguous, multiply |
1568 | // by extent of this dimension |
1569 | auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); |
1570 | cur_contig_stride = |
1571 | SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent); |
1572 | } else { |
      // If this is a non-contiguous dimension, keep the local stride
      // information; set the current stride to local stride * local raw extent
1575 | auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); |
1576 | cur_contig_stride = |
1577 | SimplifyingIrBuilder::mulExpr(strides[dim], root_dim_extent); |
1578 | } |
1579 | } |
1580 | |
1581 | auto vectorize_shift = |
1582 | loops.empty() ? nullptr : loops.back()->vectorize_shift(); |
1583 | |
1584 | // Global striding |
1585 | std::vector<Val*> strided_inds( |
1586 | root_dom.size(), GpuLower::current()->kernel()->zeroVal()); |
1587 | for (const auto i : c10::irange(root_dom.size())) { |
    // If the domain is a reduction, a broadcast, or derived from a
    // trivial reduction, there is no index to create.
1590 | if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || |
1591 | gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) { |
1592 | continue; |
1593 | } |
1594 | |
1595 | TORCH_INTERNAL_ASSERT( |
1596 | producer_indexing.indexMap().find(root_dom[i]) != |
1597 | producer_indexing.indexMap().end(), |
1598 | "Couldn't find root mapping for TV" , |
1599 | producer_tv->name(), |
1600 | " dim: " , |
1601 | i, |
1602 | " id: " , |
1603 | root_dom[i]->toString()); |
1604 | |
1605 | auto root_ind = producer_indexing.indexMap().at(root_dom[i]); |
1606 | |
1607 | // index hoist must be done before the adjustments for halo |
1608 | root_ind = hoistProducerIndex( |
1609 | root_dom[i], |
1610 | producer_tv, |
1611 | producer_indexing, |
1612 | consumer_tv, |
1613 | p2c_map, |
1614 | producer_indexing_from_idgraph.resolved_loop_domains, |
1615 | producer_indexing_from_idgraph.initial_concrete_index_map, |
1616 | loops, |
1617 | root_ind); |
1618 | |
1619 | root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv); |
1620 | |
1621 | root_ind = getProducerIndexWithGather( |
1622 | root_ind, |
1623 | i, |
1624 | producer_tv, |
1625 | consumer_tv, |
1626 | producer_indexing_from_idgraph.concrete_index.indexMap()); |
1627 | |
1628 | root_ind = getProducerIndexWithPartialSplit( |
1629 | root_ind, root_dom[i], producer_tv, consumer_tv); |
1630 | |
1631 | if (root_ind->isZeroInt()) { |
1632 | continue; |
1633 | } else { |
1634 | auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); |
1635 | if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { |
1636 | strided_inds[i] = |
1637 | SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); |
1638 | } else { |
1639 | strided_inds[i] = strided_ind; |
1640 | } |
1641 | } |
1642 | } |
1643 | |
1644 | return strided_inds; |
1645 | } |
1646 | |
1647 | namespace { |
1648 | |
1649 | // Maps all producer domains to consumer with broadcast |
1650 | // forwarding. Used to find the allocation position. |
1651 | std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer( |
1652 | TensorView* producer_tv, |
1653 | const TensorView* consumer_tv) { |
1654 | // This map has forwarded broadcast axes, it should only be used to compute |
1655 | // the allocation position of the producer, and to figure out which producer |
1656 | // indices are mapped to consumer trivial reductions. |
1657 | std::unordered_map<IterDomain*, IterDomain*> p2c_alloc_map; |
1658 | |
1659 | // We want to replay producer as consumer instead of the other way around |
1660 | // since consumer may have some broadcasted axes producer doesn't have |
1661 | // merged into loops producer may use. If we did consumer as producer we |
1662 | // wouldn't have this information in the mapping. |
1663 | auto replay_PasC = BestEffortReplay::replayPasC( |
1664 | producer_tv, |
1665 | consumer_tv, |
1666 | -1, |
1667 | PairwiseRootDomainMap(producer_tv, consumer_tv)); |
1668 | |
1669 | // Grab consumer domain entries and reverse replay map. TODO: Maybe |
1670 | // TransformReplay::replayPasC could return this map |
1671 | for (auto id : consumer_tv->domain()->domain()) { |
1672 | const auto& c2p_map = replay_PasC.getReplay(); |
1673 | auto c2p_it = c2p_map.find(id); |
1674 | if (c2p_it != c2p_map.end()) { |
1675 | auto c_id = c2p_it->first; |
1676 | auto p_id = c2p_it->second; |
1677 | p2c_alloc_map[p_id] = c_id; |
1678 | } |
1679 | } |
1680 | |
1681 | return p2c_alloc_map; |
1682 | } |
1683 | |
1684 | } // namespace |
1685 | |
1686 | // Producer index for either shared or local memory |
1687 | std::vector<Val*> Index::getNonGlobalProducerStridedIndices( |
1688 | TensorView* producer_tv, |
1689 | const TensorView* consumer_tv, |
1690 | const std::vector<kir::ForLoop*>& loops) { |
1691 | const auto gpu_lower = GpuLower::current(); |
1692 | |
1693 | // Replay producer to look like consumer so we can index on producer since our |
1694 | // loop nests look like consumer |
1695 | auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); |
1696 | auto producer_replayed_as_consumer = |
1697 | TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map) |
1698 | .first; |
1699 | |
1700 | ir_utils::TVDomainGuard domain_guard( |
1701 | producer_tv, producer_replayed_as_consumer); |
1702 | const auto p2c_alloc_map = |
1703 | mapAllProducerDomainsToConsumer(producer_tv, consumer_tv); |
1704 | |
  // Map everything we can from reference to producer using the compute at
  // index map. Not all producer IDs exist in the compute at map. The rfactor
  // axes all may be, but since that hasn't been proven, take the more
  // conservative approach of using the consumer as a proxy between producer
  // and reference.
1710 | std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_producer; |
1711 | std::unordered_map<IterDomain*, IterDomain*> c2p_index_map; |
1712 | std::unordered_map<IterDomain*, IterDomain*> p2c_index_map; |
1713 | |
1714 | // Map sent to best effort replay needs to match the exact incantation for |
1715 | // compute_at_mode.cpp with MappingMode::Index |
1716 | auto c2p_root_map = |
1717 | PairwiseRootDomainMap(producer_tv, consumer_tv, true) |
1718 | .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain()); |
1719 | |
1720 | // This replay has to be consistent with compute at index map. |
1721 | BestEffortReplay replay_producer_as_consumer( |
1722 | producer_tv->domain()->domain(), |
1723 | consumer_tv->domain()->domain(), |
1724 | c2p_root_map); |
1725 | |
1726 | c2p_index_map = replay_producer_as_consumer.getReplay(); |
1727 | p2c_index_map = invertOneToOneMap(c2p_index_map); |
1728 | |
1729 | // Forward vectorized IDs to index into producer correctly |
1730 | // We want p_id to be vectorized like consumer just for the indexing, then we |
1731 | // need to switch it back later. Store previous state here when changing. We |
1732 | // need to do this as replaying producer as consumer can use replay best |
1733 | // effort which means some domains may be the originals. |
1734 | std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup; |
1735 | for (auto entry : c2p_index_map) { |
1736 | auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( |
1737 | entry.first, IdMappingMode::EXACT); |
1738 | auto p_id = entry.second; |
1739 | if (ref_id->getParallelType() == ParallelType::Vectorize) { |
1740 | p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); |
1741 | p_id->parallelize(ParallelType::Vectorize); |
1742 | } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) { |
1743 | p_id->parallelize(ParallelType::MisalignedVectorize); |
1744 | } |
1745 | } |
1746 | |
1747 | auto producer_indexing_from_idgraph = getTensorIndexFromIdGraph( |
1748 | loops, consumer_tv, producer_tv, false, c2p_index_map); |
1749 | |
1750 | auto producer_indexing = producer_indexing_from_idgraph.index; |
1751 | |
1752 | // Revert p_ids |
1753 | for (auto entry : p_id_backup) { |
1754 | entry.first->parallelize(entry.second); |
1755 | } |
1756 | |
1757 | IndexSwizzle index_swizzle( |
1758 | producer_tv, |
1759 | producer_indexing.indexMap(), |
1760 | producer_indexing.extentMap(), |
1761 | producer_indexing.zeroDomains(), |
1762 | producer_indexing.zeroMergedIn()); |
1763 | |
1764 | index_swizzle.run(); |
1765 | |
1766 | auto producer_swizzled_index = index_swizzle; |
1767 | |
1768 | if (producer_tv->hasSwizzleOp()) { |
    // Special handling is needed for the new swizzle op pass: each
    // swizzle op is local to the tensor, so ReplayPasC will not include
    // the swizzle ops on the producer iterdomain. We therefore need to
    // traverse the producer domain forward before the replay to pick up
    // the swizzle ops.
1776 | IndexSwizzle producer_swizzle2d( |
1777 | producer_tv, |
1778 | domain_guard.prevDomain(), |
1779 | producer_indexing.indexMap(), |
1780 | producer_indexing.extentMap(), |
1781 | producer_indexing.zeroDomains(), |
1782 | producer_indexing.zeroMergedIn()); |
1783 | producer_swizzle2d.run(); |
1784 | producer_swizzled_index = producer_swizzle2d; |
1785 | } |
1786 | |
  // TODO: Merge the two swizzle compute paths once the new one is ready.
  // Will need to replace the cyclic shift swizzle with xor since swizzle2d
  // doesn't have cyclic shift.
1790 | const auto& index_map = producer_swizzled_index.indexMap(); |
1791 | |
1792 | const auto& extent_map = producer_indexing.extentMap(); |
1793 | const auto& zero_domain_map = producer_indexing.zeroDomains(); |
1794 | // Indices should now be mapped onto IterDomains in producer, so just grab |
1795 | // and use them. |
1796 | auto root_dom = producer_tv->getMaybeRFactorDomain(); |
1797 | |
1798 | // Figure out which root axes we don't need to index |
1799 | std::unordered_set<IterDomain*> skip_indexing; |
1800 | |
1801 | for (auto root_id : root_dom) { |
1802 | // Already taken care of because we can detect no indexing required |
1803 | if (root_id->isBroadcast() || root_id->isReduction() || |
1804 | gpu_lower->trivialReductionInfo().isDerived(root_id) || |
1805 | root_id->isStride()) { |
1806 | skip_indexing.insert(root_id); |
1807 | continue; |
1808 | } |
1809 | |
1810 | // Already an entry for this root domain, continue |
1811 | if (index_map.find(root_id) != index_map.end()) { |
1812 | continue; |
1813 | } |
1814 | |
    // Maps to a consumer's trivial reduction, don't index
1816 | if (p2c_alloc_map.find(root_id) != p2c_alloc_map.end() && |
1817 | gpu_lower->trivialReductionInfo().isDerived( |
1818 | p2c_alloc_map.at(root_id))) { |
1819 | skip_indexing.emplace(root_id); |
1820 | } |
1821 | } |
1822 | |
1823 | std::vector<Val*> strided_inds( |
1824 | root_dom.size(), GpuLower::current()->kernel()->zeroVal()); |
1825 | for (const auto i : c10::irange(root_dom.size())) { |
1826 | if (skip_indexing.count(root_dom[i])) { |
1827 | continue; |
1828 | } |
1829 | |
1830 | TORCH_INTERNAL_ASSERT( |
1831 | index_map.find(root_dom[i]) != index_map.end(), |
1832 | "Couldn't find root mapping for " , |
1833 | producer_tv->toString(), |
1834 | " dim: " , |
1835 | i, |
1836 | " id: " , |
1837 | root_dom[i]->toString()); |
1838 | |
1839 | auto root_ind_i = index_map.at(root_dom[i]); |
1840 | |
1841 | // index hoist must be done before the adjustments for halo |
1842 | root_ind_i = hoistProducerIndex( |
1843 | root_dom[i], |
1844 | producer_tv, |
1845 | producer_indexing, |
1846 | consumer_tv, |
1847 | p2c_index_map, |
1848 | producer_indexing_from_idgraph.resolved_loop_domains, |
1849 | producer_indexing_from_idgraph.initial_concrete_index_map, |
1850 | loops, |
1851 | root_ind_i); |
1852 | |
1853 | root_ind_i = |
1854 | getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv); |
1855 | |
1856 | root_ind_i = getProducerIndexWithGather( |
1857 | root_ind_i, |
1858 | i, |
1859 | producer_tv, |
1860 | consumer_tv, |
1861 | producer_indexing_from_idgraph.concrete_index.indexMap()); |
1862 | |
1863 | root_ind_i = getProducerIndexWithPartialSplit( |
1864 | root_ind_i, root_dom[i], producer_tv, consumer_tv); |
1865 | |
1866 | if (root_ind_i->isZeroInt()) { |
1867 | continue; |
1868 | } |
1869 | |
1870 | // Compute striding for this index. |
1871 | Val* stride = nullptr; |
1872 | for (const auto j : c10::irange(i + 1, root_dom.size())) { |
1873 | if (skip_indexing.count(root_dom[j])) { |
1874 | continue; |
1875 | } |
1876 | |
1877 | TORCH_INTERNAL_ASSERT( |
1878 | index_map.find(root_dom[j]) != index_map.end(), |
1879 | "Couldn't find root mapping for " , |
1880 | producer_tv->name(), |
1881 | " dim: " , |
1882 | j, |
1883 | " id: " , |
1884 | root_dom[j]->toString()); |
1885 | |
1886 | auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() |
1887 | ? root_dom[j]->extent() |
1888 | : extent_map.at(root_dom[j]); |
1889 | |
1890 | root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); |
1891 | |
1892 | if (zero_domain_map.count(root_dom[j]) == 0) { |
1893 | if (stride == nullptr) { |
1894 | stride = root_ext_j; |
1895 | } else { |
1896 | stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j); |
1897 | } |
1898 | } |
1899 | } |
1900 | |
1901 | if (stride != nullptr) { |
1902 | strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride); |
1903 | } else { |
1904 | strided_inds[i] = root_ind_i; |
1905 | } |
1906 | } |
1907 | |
1908 | if (producer_tv->isDoubleBuffered() || producer_tv->isCircularBuffered()) { |
1909 | auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( |
1910 | producer_tv, loops, true); |
1911 | if (db_loop != nullptr) { |
1912 | auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor( |
1913 | db_loop->iter_domain()); |
1914 | auto loop_index = |
1915 | db_loop->isTrivial() ? db_loop->start() : db_loop->index(); |
1916 | auto db_switch_index = SimplifyingIrBuilder::modExpr( |
1917 | loop_index, SimplifyingIrBuilder::create<Int>(stage_depth)); |
1918 | auto original_alloc_size = |
1919 | gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); |
1920 | auto db_strided_index = |
1921 | SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size); |
1922 | strided_inds.push_back(db_strided_index); |
1923 | } |
1924 | } |
1925 | |
1926 | return strided_inds; |
1927 | } |
1928 | |
1929 | std::vector<Val*> Index::getLinearLogicalIndex( |
1930 | TensorView* consumer_tv, |
1931 | const std::vector<kir::ForLoop*>& loops) { |
1932 | auto guard = ir_utils::overrideContiguityGuard(consumer_tv, true); |
1933 | return getGlobalConsumerStridedIndices(consumer_tv, loops); |
1934 | } |
1935 | |
1936 | std::vector<Val*> Index::getPerDimLogicalIndex( |
1937 | TensorView* consumer_tv, |
1938 | const std::vector<kir::ForLoop*>& loops) { |
1939 | auto guard = ir_utils::overrideContiguityGuard(consumer_tv, false); |
1940 | IndexFromIdGraph index_from_id_graph = |
1941 | getTensorIndexFromIdGraph(loops, consumer_tv); |
1942 | return getRootIndices(consumer_tv, loops, index_from_id_graph); |
1943 | } |
1944 | |
1945 | std::vector<Val*> Index::getStrides(const TensorView* tv) { |
1946 | // Indices should now be mapped onto IterDomains in consumer, so just grab |
1947 | // and use them. |
1948 | auto root_dom = tv->getMaybeRFactorDomain(); |
1949 | |
1950 | std::vector<Val*> strides( |
1951 | root_dom.size(), GpuLower::current()->kernel()->oneVal()); |
1952 | { |
1953 | int stride_i = 0; |
1954 | for (const auto i : c10::irange(root_dom.size())) { |
1955 | if (root_dom[i]->isReduction() || root_dom[i]->isStride()) { |
1956 | strides[i] = GpuLower::current()->kernel()->oneVal(); |
1957 | continue; |
1958 | } |
1959 | std::stringstream ss; |
1960 | ss << "T" << tv->name() << ".stride[" << stride_i++ << "]" ; |
1961 | strides[i] = |
1962 | SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int); |
1963 | } |
1964 | } |
1965 | |
1966 | TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size()); |
1967 | Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); |
1968 | for (const auto i : c10::irange(root_dom.size())) { |
1969 | auto dim = root_dom.size() - i - 1; |
1970 | if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) { |
1971 | continue; |
1972 | } |
1973 | |
1974 | if (tv->domain()->contiguity()[dim]) { |
      // If contiguous, use the stored stride, which may be the previous
      // dimension's stride * the previous dimension's size
1977 | strides[dim] = cur_contig_stride; |
1978 | // Prepare for the next dimension which may also be contiguous, multiply |
1979 | // by extent of this dimension |
1980 | auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); |
1981 | cur_contig_stride = |
1982 | SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent); |
1983 | } else { |
      // If this is a non-contiguous dimension, keep the local stride
      // information; set the current stride to local stride * local raw extent
1986 | cur_contig_stride = SimplifyingIrBuilder::mulExpr( |
1987 | strides[dim], getHaloExtentOfRootAxis(root_dom[dim])); |
1988 | } |
1989 | } |
1990 | return strides; |
1991 | } |
1992 | |
1993 | std::vector<Val*> Index::getRootIndices( |
1994 | const TensorView* tv, |
1995 | const std::vector<kir::ForLoop*>& loops, |
1996 | const IndexFromIdGraph& index_from_id_graph) { |
1997 | auto gpu_lower = GpuLower::current(); |
1998 | auto root_dom = tv->getMaybeRFactorDomain(); |
1999 | auto indexing = index_from_id_graph.index; |
2000 | |
2001 | std::vector<Val*> root_inds( |
2002 | root_dom.size(), GpuLower::current()->kernel()->zeroVal()); |
2003 | for (const auto i : c10::irange(root_dom.size())) { |
2004 | // See a comment in indexing to root domains in getGlobalProducerIndex. |
2005 | if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || |
2006 | gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) || |
2007 | root_dom[i]->isStride()) { |
2008 | continue; |
2009 | } |
2010 | |
2011 | TORCH_INTERNAL_ASSERT( |
2012 | indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(), |
2013 | "Couldn't find root mapping for " , |
2014 | tv->toString(), |
2015 | " dim: " , |
2016 | i, |
2017 | " id: " , |
2018 | root_dom[i]->toString()); |
2019 | |
2020 | auto root_ind = indexing.indexMap().at(root_dom[i]); |
2021 | |
2022 | // index hoist must be done before the adjustments for halo |
2023 | root_ind = hoistConsumerIndex( |
2024 | root_dom[i], |
2025 | tv, |
2026 | indexing, |
2027 | index_from_id_graph.resolved_loop_domains, |
2028 | index_from_id_graph.initial_concrete_index_map, |
2029 | loops, |
2030 | root_ind); |
2031 | |
2032 | root_ind = SimplifyingIrBuilder::addExpr( |
2033 | root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i])); |
2034 | root_inds[i] = root_ind; |
2035 | } |
2036 | return root_inds; |
2037 | } |
2038 | |
2039 | std::vector<Val*> Index::getGlobalConsumerStridedIndices( |
2040 | const TensorView* consumer_tv, |
2041 | const std::vector<kir::ForLoop*>& loops) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
2043 | |
2044 | auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); |
2045 | auto consumer_indexing = index_from_id_graph.index; |
2046 | auto strides = getStrides(consumer_tv); |
2047 | auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph); |
2048 | |
2049 | // Global striding |
2050 | auto vectorize_shift = |
2051 | loops.empty() ? nullptr : loops.back()->vectorize_shift(); |
2052 | std::vector<Val*> strided_inds( |
2053 | root_inds.size(), GpuLower::current()->kernel()->zeroVal()); |
2054 | for (const auto i : c10::irange(root_inds.size())) { |
2055 | if (root_inds[i]->isZeroInt()) { |
2056 | continue; |
2057 | } else { |
2058 | auto strided_ind = |
2059 | SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]); |
2060 | if (i == strides.size() - 1 && vectorize_shift != nullptr) { |
2061 | strided_inds[i] = |
2062 | SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); |
2063 | } else { |
2064 | strided_inds[i] = strided_ind; |
2065 | } |
2066 | } |
2067 | } |
2068 | |
2069 | TORCH_INTERNAL_ASSERT( |
2070 | strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); |
2071 | |
2072 | return strided_inds; |
2073 | } |
2074 | |
2075 | // Consumer index for either shared or local memory |
2076 | std::vector<Val*> Index::getNonGlobalConsumerStridedIndices( |
2077 | const TensorView* consumer_tv, |
2078 | const std::vector<kir::ForLoop*>& loops) { |
2079 | const auto gpu_lower = GpuLower::current(); |
2080 | |
2081 | auto consumer_indexing_from_idgraph = getTensorIndexFromIdGraph( |
2082 | loops, |
2083 | consumer_tv, |
2084 | // Producer tv |
2085 | nullptr, |
2086 | // Index global |
2087 | false); |
2088 | |
2089 | auto consumer_indexing = consumer_indexing_from_idgraph.index; |
2090 | |
2091 | IndexSwizzle index_swizzle( |
2092 | consumer_tv, |
2093 | consumer_indexing.indexMap(), |
2094 | consumer_indexing.extentMap(), |
2095 | consumer_indexing.zeroDomains(), |
2096 | consumer_indexing.zeroMergedIn()); |
2097 | |
2098 | index_swizzle.run(); |
2099 | |
2100 | const auto& index_map = index_swizzle.indexMap(); |
2101 | const auto& extent_map = consumer_indexing.extentMap(); |
2102 | const auto& zero_domain_map = consumer_indexing.zeroDomains(); |
2103 | |
2104 | // Indices should now be mapped onto IterDomains in consumer, so just grab |
2105 | // and use them. |
2106 | auto root_dom = consumer_tv->getMaybeRFactorDomain(); |
2107 | std::vector<Val*> strided_inds( |
2108 | root_dom.size(), GpuLower::current()->kernel()->zeroVal()); |
2109 | for (const auto i : c10::irange(root_dom.size())) { |
2110 | if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || |
2111 | gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) || |
2112 | root_dom[i]->isStride()) { |
2113 | continue; |
2114 | } |
2115 | |
2116 | TORCH_INTERNAL_ASSERT( |
2117 | index_map.find(root_dom[i]) != index_map.end(), |
2118 | "Couldn't find root mapping for " , |
2119 | consumer_tv->toString(), |
2120 | " dim: " , |
2121 | i, |
2122 | " id: " , |
2123 | root_dom[i]->toString()); |
2124 | |
2125 | auto root_ind_i = index_map.at(root_dom[i]); |
2126 | if (root_ind_i->isZeroInt()) { |
2127 | continue; |
2128 | } |
2129 | |
2130 | // index hoist must be done before the adjustments for halo |
2131 | root_ind_i = hoistConsumerIndex( |
2132 | root_dom[i], |
2133 | consumer_tv, |
2134 | consumer_indexing, |
2135 | consumer_indexing_from_idgraph.resolved_loop_domains, |
2136 | consumer_indexing_from_idgraph.initial_concrete_index_map, |
2137 | loops, |
2138 | root_ind_i); |
2139 | |
2140 | // Compute striding for this index. |
2141 | Val* stride = nullptr; |
2142 | for (const auto j : c10::irange(i + 1, root_dom.size())) { |
2143 | if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() || |
2144 | gpu_lower->trivialReductionInfo().isDerived(root_dom[j]) || |
2145 | root_dom[j]->isStride()) { |
2146 | continue; |
2147 | } |
2148 | |
2149 | TORCH_INTERNAL_ASSERT( |
2150 | index_map.find(root_dom[j]) != index_map.end(), |
2151 | "Couldn't find root mapping for " , |
2152 | consumer_tv->toString(), |
2153 | " dim: " , |
2154 | j, |
2155 | " id: " , |
2156 | root_dom[j]->toString()); |
2157 | |
2158 | auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() |
2159 | ? root_dom[j]->extent() |
2160 | : extent_map.at(root_dom[j]); |
2161 | |
2162 | root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); |
2163 | |
2164 | if (zero_domain_map.count(root_dom[j]) == 0) { |
2165 | if (stride == nullptr) { |
2166 | stride = root_ext_j; |
2167 | } else { |
2168 | stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j); |
2169 | } |
2170 | } |
2171 | } |
2172 | |
2173 | if (stride != nullptr) { |
2174 | strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride); |
2175 | } else { |
2176 | strided_inds[i] = root_ind_i; |
2177 | } |
2178 | } |
2179 | |
2180 | // This check was originally done in getConsumerStridedIndices, but |
2181 | // the number of strided index values depends on the loop where the |
2182 | // consumer tensor is located. If it's double buffered and not in |
2183 | // the prologue loop, strided_inds ends up having one more |
2184 | // index, so it's just much simpler to check here before adding the |
2185 | // additional index for double buffering. |
2186 | TORCH_INTERNAL_ASSERT( |
2187 | strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); |
2188 | |
2189 | if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) { |
2190 | auto db_loop = |
2191 | gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops); |
2192 | auto stage_depth = |
2193 | gpu_lower->doubleBufferInfo().getStageDepthFor(db_loop->iter_domain()); |
2194 | bool is_circular_buffer_loop = stage_depth > 2; |
2195 | bool is_prolog = |
2196 | db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog; |
2197 | |
2198 | Val* db_switch_index = nullptr; |
2199 | |
    // In the double buffered case we don't materialize the prolog loop, as
    // there will be only one iteration. In the circular buffer case we
    // materialize the prolog loop as well, covering the first N-1
    // iterations, N being the stage depth.
2204 | if (!is_prolog || is_circular_buffer_loop) { |
2205 | if (is_prolog && is_circular_buffer_loop) { |
2206 | // The buffer switching logic is the same as original index |
2207 | // in the case of circular buffer prolog. |
2208 | db_switch_index = db_loop->index(); |
2209 | } else { |
2210 | // Switching index generated for main loop or epilog component. |
2211 | db_switch_index = SimplifyingIrBuilder::modExpr( |
2212 | SimplifyingIrBuilder::addExpr( |
2213 | db_loop->index(), |
2214 | SimplifyingIrBuilder::create<Int>(stage_depth - 1)), |
2215 | SimplifyingIrBuilder::create<Int>(stage_depth)); |
2216 | } |
2217 | |
2218 | // Use the generated switching buffer index to access the buffer space. |
2219 | auto original_alloc_size = |
2220 | gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv); |
2221 | auto db_strided_index = |
2222 | SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size); |
2223 | strided_inds.push_back(db_strided_index); |
2224 | } |
2225 | } |
2226 | |
2227 | return strided_inds; |
2228 | } |
2229 | |
2230 | std::vector<Val*> Index::getProducerStridedIndices( |
2231 | TensorView* producer, |
2232 | const TensorView* consumer, |
2233 | const std::vector<kir::ForLoop*>& loops) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
2235 | if (producer->domain()->noReductions().size() == 0) { |
2236 | return std::vector<Val*>( |
2237 | producer->getMaybeRFactorDomain().size(), |
2238 | GpuLower::current()->kernel()->zeroVal()); |
2239 | } |
2240 | |
2241 | std::vector<Val*> strided_indices; |
2242 | if (producer->getMemoryType() == MemoryType::Global) { |
2243 | strided_indices = |
2244 | getGlobalProducerStridedIndices(producer, consumer, loops); |
2245 | } else { |
2246 | strided_indices = |
2247 | getNonGlobalProducerStridedIndices(producer, consumer, loops); |
2248 | } |
2249 | |
2250 | TORCH_INTERNAL_ASSERT( |
2251 | strided_indices.size() == |
2252 | producer->getMaybeRFactorDomain().size() + |
2253 | (producer->isDoubleBuffered() || producer->isCircularBuffered() ? 1 |
2254 | : 0)); |
2255 | |
2256 | return strided_indices; |
2257 | } |
2258 | |
2259 | // Producer is the inputs of an expression |
2260 | kir::TensorIndex* Index::getProducerIndex( |
2261 | TensorView* producer, |
2262 | const TensorView* consumer, |
2263 | const std::vector<kir::ForLoop*>& loops) { |
2264 | auto strided_indices = getProducerStridedIndices(producer, consumer, loops); |
2265 | return SimplifyingIrBuilder::create<kir::TensorIndex>( |
2266 | producer, strided_indices); |
2267 | } |
2268 | |
2269 | std::vector<Val*> Index::getConsumerStridedIndices( |
2270 | const TensorView* consumer, |
2271 | const std::vector<kir::ForLoop*>& loops) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
2273 | if (consumer->domain()->noReductions().size() == 0) { |
2274 | return std::vector<Val*>( |
2275 | consumer->getMaybeRFactorDomain().size(), |
2276 | GpuLower::current()->kernel()->zeroVal()); |
2277 | } |
2278 | |
2279 | std::vector<Val*> strided_indices; |
2280 | if (consumer->getMemoryType() == MemoryType::Global) { |
2281 | strided_indices = getGlobalConsumerStridedIndices(consumer, loops); |
2282 | } else { |
2283 | strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops); |
2284 | } |
2285 | |
2286 | return strided_indices; |
2287 | } |
2288 | |
2289 | // Consumer is the output of an expression |
2290 | kir::TensorIndex* Index::getConsumerIndex( |
2291 | const TensorView* consumer, |
2292 | const std::vector<kir::ForLoop*>& loops) { |
2293 | auto strided_indices = getConsumerStridedIndices(consumer, loops); |
2294 | return SimplifyingIrBuilder::create<kir::TensorIndex>( |
2295 | consumer, strided_indices); |
2296 | } |
2297 | |
2298 | namespace { |
2299 | |
2300 | struct PredicateDomainInfo { |
2301 | public: |
2302 | // Iteration domain to predicate |
2303 | IterDomain* id = nullptr; |
2304 | // The set of iteration domains that make up the id. If this is for |
2305 | // a non-divisible split, the set only contains the id itself. This |
2306 | // set is used to remove redundant predicates when gathering |
2307 | // unswitch predicates. |
2308 | std::unordered_set<IterDomain*> covered_ids; |
2309 | // True if this predicate is for a non-divisible split |
2310 | bool is_non_divisible_split = false; |
2311 | }; |
2312 | |
// Find iteration domains in the history of a consumer to predicate that are
// comprised only of merge operations. Only return iteration domains that are
// subsequently fed into a split, or are in the provided domain. In other
// words, we don't want to return every contiguous IterDomain, just the ones
// closest to the leaves. Predicates are not associated with physical memory,
// so we can treat all of them as contiguous merges.
2319 | // |
2320 | // TODO: This seems to have a large overlap with ContigIDs. Consider |
2321 | // refactoring. |
2322 | std::vector<PredicateDomainInfo> getPredicateContigIds( |
2323 | TensorView* consumer_tv, |
2324 | const std::unordered_map<IterDomain*, Val*>& consumer_index_map) { |
2325 | const auto gpu_lower = GpuLower::current(); |
2326 | |
2327 | const auto& consumer_root_domain = consumer_tv->getRootDomain(); |
2328 | |
2329 | if (consumer_root_domain.empty()) { |
2330 | return std::vector<PredicateDomainInfo>(); |
2331 | } |
2332 | |
2333 | std::unordered_map<IterDomain*, Val*> concrete_index_map; |
2334 | for (auto entry : consumer_index_map) { |
2335 | auto c_id = gpu_lower->caMap()->getConcreteMappedID( |
2336 | entry.first, IdMappingMode::EXACT); |
2337 | concrete_index_map[c_id] = entry.second; |
2338 | } |
2339 | |
2340 | std::vector<bool> predicate_contiguity(consumer_root_domain.size(), true); |
2341 | std::unordered_set<IterDomain*> final_ids; |
2342 | for (auto root_i : c10::irange(predicate_contiguity.size())) { |
2343 | auto root_id = consumer_root_domain[root_i]; |
2344 | if (root_id->maybePartial()) { |
2345 | final_ids.insert(root_id); |
2346 | continue; |
2347 | } |
2348 | // Shifted or gathered axes need to be predicated at the root domain |
2349 | auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition()); |
2350 | auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition()); |
2351 | if ((shift_expr && shift_expr->offset(root_i) != 0) || |
2352 | (gather_expr && root_i < gather_expr->windowShape().size() && |
2353 | gather_expr->windowShape().at(root_i) != 1)) { |
2354 | final_ids.insert(root_id); |
2355 | } |
2356 | } |
2357 | |
2358 | ContigIDs contig_finder( |
2359 | consumer_tv->domain()->domain(), |
2360 | consumer_root_domain, |
2361 | predicate_contiguity, |
2362 | final_ids, |
2363 | concrete_index_map, |
2364 | GpuLower::current()->divisbleSplitSet(), |
2365 | GpuLower::current()->caMap(), |
2366 | GpuLower::current()->haloInfo(), |
2367 | GpuLower::current()->concretizedBroadcastDomains(), |
2368 | {}, |
2369 | false, |
2370 | true); |
2371 | |
2372 | std::vector<PredicateDomainInfo> contig_id_infos; |
2373 | std::unordered_set<IterDomain*> covered_roots; |
2374 | |
2375 | // Create entries and return them |
2376 | for (auto root_id : consumer_root_domain) { |
2377 | if (covered_roots.count(root_id) > 0) { |
2378 | continue; |
2379 | } |
2380 | |
2381 | auto contig_id_it = contig_finder.rootToIndexedID().find(root_id); |
2382 | |
2383 | TORCH_INTERNAL_ASSERT( |
2384 | contig_id_it != contig_finder.rootToIndexedID().end(), |
2385 | "Error in predicate contiguity analysis, missing index for root " , |
2386 | root_id->toString()); |
2387 | |
2388 | auto contig_id = contig_id_it->second; |
2389 | |
    // Pick inputs from the starting root domains, i.e.,
    // consumer_root_domain.
2392 | auto contig_root_ids = contig_finder.indexedRootIDs(contig_id); |
2393 | covered_roots.insert(contig_root_ids.begin(), contig_root_ids.end()); |
2394 | PredicateDomainInfo contig_id_info; |
2395 | contig_id_info.id = contig_id; |
2396 | contig_id_info.covered_ids = std::unordered_set<IterDomain*>( |
2397 | contig_root_ids.begin(), contig_root_ids.end()); |
2398 | contig_id_infos.push_back(contig_id_info); |
2399 | } |
2400 | return contig_id_infos; |
2401 | } |
2402 | |
2403 | std::vector<PredicateDomainInfo> getNonDivisibleConsumerDomainsToPredicate( |
2404 | TensorView* consumer_tv) { |
2405 | const auto& non_divisible_split_info = |
2406 | GpuLower::current()->nonDivisibleSplitInfo(); |
2407 | |
2408 | std::vector<PredicateDomainInfo> pred_info_vec; |
2409 | |
2410 | auto it = non_divisible_split_info.splitsToPredicate().find(consumer_tv); |
2411 | if (it == non_divisible_split_info.splitsToPredicate().end()) { |
2412 | return {}; |
2413 | } |
2414 | |
2415 | const auto& splits_to_predicate = it->second; |
2416 | |
2417 | for (auto split : splits_to_predicate) { |
2418 | PredicateDomainInfo info{split->in(), {split->in()}, true}; |
2419 | pred_info_vec.emplace_back(info); |
2420 | } |
2421 | |
2422 | return pred_info_vec; |
2423 | } |
2424 | |
2425 | bool needsPadding(TensorView* tv) { |
2426 | auto shift_expr = dynamic_cast<ShiftOp*>(tv->definition()); |
2427 | auto gather_expr = dynamic_cast<GatherOp*>(tv->definition()); |
2428 | |
2429 | return (shift_expr != nullptr && shift_expr->hasPadding()) || |
2430 | (gather_expr != nullptr && gather_expr->hasPadding()); |
2431 | } |
2432 | |
2433 | // Get an additional offset of a stop index when building a predicate |
2434 | // for unswitch. Initial stop indices generated at |
2435 | // getPredicateIndexingFromIdGraph do not take halo into account, and the |
2436 | // adjustment for halo is done as an additional offset to the final index value |
2437 | // so that unswitch predicates can be compared with each other by just looking |
2438 | // at the additional offsets. |
2439 | // |
2440 | // consumer_root_id: the domain for which a stop predicate is being built. |
2441 | int getUnswitchStopOffset( |
2442 | IterDomain* consumer_root_id, |
2443 | TensorView* consumer_tv) { |
2444 | const auto gpu_lower = GpuLower::current(); |
2445 | |
2446 | AxisHaloInfo halo_info = |
2447 | gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id); |
2448 | |
2449 | // If the consumer root domain to predicate does not have halo, no |
2450 | // adjustment is required. |
2451 | if (!halo_info.hasHalo()) { |
2452 | return 0; |
2453 | } |
2454 | |
2455 | // Find if this contig_id is used in the unswitched domains |
2456 | auto unswitch_it = std::find_if( |
2457 | consumer_tv->domain()->domain().begin(), |
2458 | consumer_tv->domain()->domain().end(), |
2459 | [](IterDomain* id) { |
2460 | return id->getParallelType() == ParallelType::Unswitch || |
2461 | id->getParallelType() == ParallelType::Unroll || |
2462 | id->getParallelType() == ParallelType::Vectorize; |
2463 | }); |
2464 | |
2465 | // If any of the unswitched leaf domains inherits the halo from the |
2466 | // root domain, the halo width needs to be added to the stop offset |
2467 | if (std::any_of( |
2468 | unswitch_it, |
2469 | consumer_tv->domain()->domain().end(), |
2470 | [&gpu_lower, &consumer_root_id](auto leaf_id) { |
2471 | return gpu_lower->haloInfo()->isHaloInherited( |
2472 | consumer_root_id, leaf_id); |
2473 | })) { |
2474 | return halo_info.width(); |
2475 | } else { |
2476 | return 0; |
2477 | } |
2478 | } |
2479 | |
2480 | std::pair<Val*, Val*> getStartAndStopOffsetsForShift( |
2481 | TensorView* consumer_tv, |
2482 | IterDomain* consumer_id, |
2483 | bool padding_predicate) { |
2484 | TORCH_INTERNAL_ASSERT(consumer_id != nullptr); |
2485 | |
2486 | auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition()); |
2487 | |
  // Adjustment is not necessary if the definition is not a shift.
  // Even when it is, the padding predicate does not need any adjustment.
2490 | if (shift_expr == nullptr || padding_predicate) { |
2491 | return { |
2492 | GpuLower::current()->kernel()->zeroVal(), |
2493 | GpuLower::current()->kernel()->zeroVal()}; |
2494 | } |
2495 | |
2496 | const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); |
2497 | |
2498 | // The first or last N elements, where N is the padding width, |
2499 | // correspond to the padding predicate. |
2500 | |
2501 | const auto shift_offset = shift_expr->offset(root_axis_pos); |
2502 | const auto pad_width = shift_expr->padWidth().at(root_axis_pos); |
2503 | |
2504 | int start_offset = 0; |
2505 | int stop_offset = 0; |
2506 | |
2507 | if (shift_offset > 0) { |
2508 | start_offset = -pad_width; |
2509 | } else if (shift_offset < 0) { |
2510 | stop_offset = pad_width; |
2511 | } |
2512 | |
2513 | return { |
2514 | SimplifyingIrBuilder::create<Int>(start_offset), |
2515 | SimplifyingIrBuilder::create<Int>(stop_offset)}; |
2516 | } |
2517 | |
2518 | std::pair<Val*, Val*> getStartAndStopOffsetsForGather( |
2519 | TensorView* consumer_tv, |
2520 | IterDomain* consumer_id, |
2521 | const std::unordered_map<IterDomain*, Val*>& ref_start_index_map, |
2522 | const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map, |
2523 | bool padding_predicate) { |
2524 | TORCH_INTERNAL_ASSERT(consumer_id != nullptr); |
2525 | |
  // Adjustment is not necessary if the definition is not a gather. Even
  // when it is, the padding predicate does not need any adjustment.
2528 | if (!consumer_tv->definition()->isA<GatherOp>() || padding_predicate) { |
2529 | return { |
2530 | GpuLower::current()->kernel()->zeroVal(), |
2531 | GpuLower::current()->kernel()->zeroVal()}; |
2532 | } |
2533 | |
2534 | const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); |
2535 | |
2536 | auto producer_start_offset = getProducerOffsetWithGather( |
2537 | root_axis_pos, consumer_tv, ref_start_index_map); |
2538 | |
2539 | auto producer_stop_offset = getProducerOffsetWithGather( |
2540 | root_axis_pos, consumer_tv, ref_stop_index_map); |
2541 | |
2542 | auto consumer_start_offset = GpuLower::current()->kernel()->zeroVal(); |
2543 | auto consumer_stop_offset = GpuLower::current()->kernel()->zeroVal(); |
2544 | |
2545 | if (producer_start_offset->isZeroInt() && producer_stop_offset->isZeroInt()) { |
2546 | return {consumer_start_offset, consumer_stop_offset}; |
2547 | } |
2548 | |
2549 | Val* start_offset = nullptr; |
2550 | Val* stop_offset = nullptr; |
2551 | |
  // In the normal case, take the minimum of the start offsets and the
  // maximum of the stop offsets. If there's no padding, the producer
  // offset must always be larger than the consumer offset. So, the
  // consumer and producer offsets can always be used for the start and
  // stop offsets, respectively.
2557 | const auto pad_left = |
2558 | consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][0]; |
2559 | const auto pad_right = |
2560 | consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][1]; |
2561 | const auto window_size = |
2562 | consumer_tv->definition()->as<GatherOp>()->windowShape()[root_axis_pos]; |
2563 | |
2564 | // consumer index: index |
2565 | // producer index: index + window_index - pad_left |
2566 | // |
2567 | // consumer extent: ext |
2568 | // producer extent: ext + window_size - 1 - pad_left - pad_right |
2569 | // |
2570 | // consumer stop pred: index < ext |
2571 | // producer stop pred: index + window_index - pad_left < ext + window_size - 1 |
2572 | // - pad_left - pad_right |
2573 | // -> index + window_index - pad_left - (window_size - 1 - |
2574 | // pad_left - pad_right) < ext |
2575 | // -> index + window_index - (window_size - 1 - pad_right) < |
2576 | // ext |
2577 | // |
2578 | // consumer start pred: index >= 0 |
2579 | // producer start pred: index + window_index - pad_left >= 0 |
2580 | |
2581 | const auto producer_ext_adj = window_size - 1 - pad_left - pad_right; |
2582 | producer_stop_offset = SimplifyingIrBuilder::subExpr( |
2583 | producer_stop_offset, |
2584 | SimplifyingIrBuilder::create<Int>(producer_ext_adj)); |
2585 | |
2586 | // As commented above, when pad_left is zero, the consumer predicate |
2587 | // is always more restrictive than the producer predicate. |
2588 | if (pad_left == 0) { |
2589 | start_offset = consumer_start_offset; |
2590 | } else { |
2591 | start_offset = SimplifyingIrBuilder::minExpr( |
2592 | consumer_start_offset, producer_start_offset); |
2593 | } |
2594 | |
2595 | // As commented above, when pad_right is zero, the consumer |
2596 | // predicate is always more restrictive than the producer |
2597 | // predicate. |
2598 | if (pad_right == 0) { |
2599 | stop_offset = consumer_stop_offset; |
2600 | } else { |
2601 | stop_offset = SimplifyingIrBuilder::maxExpr( |
2602 | consumer_stop_offset, producer_stop_offset); |
2603 | } |
2604 | |
2605 | TORCH_INTERNAL_ASSERT(start_offset != nullptr); |
2606 | TORCH_INTERNAL_ASSERT(stop_offset != nullptr); |
2607 | |
2608 | return {start_offset, stop_offset}; |
2609 | } |
2610 | |
2611 | // Get the start and stop limit offsets that define the valid range to |
2612 | // compute. In the simplest case, they are just 0 and |
2613 | // IterDomain::extent. However, IterDomain may have non-zero start and |
2614 | // stop that's different from extent. Also, when IterDomain has halo, |
2615 | // the actual offsets of the logical start and stop positions are |
2616 | // shifted. |
2617 | std::pair<Val*, Val*> getStartAndStopLimitOffsets( |
2618 | IterDomain* consumer_id, |
2619 | bool padding_predicate, |
2620 | bool non_divisible_pred) { |
2621 | const auto gpu_lower = GpuLower::current(); |
2622 | |
2623 | TORCH_INTERNAL_ASSERT(consumer_id != nullptr); |
2624 | |
2625 | Val* start_limit = consumer_id->start(); |
2626 | Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset()); |
2627 | |
2628 | if (!non_divisible_pred) { |
2629 | AxisHaloInfo halo_info = |
2630 | gpu_lower->haloInfo()->getRootAxisInfo(consumer_id); |
2631 | |
2632 | // Below, "left" and "right" halo mean halo at offset zero and |
2633 | // axis extent, respectively. |
2634 | // |
2635 | // The consumer axis looks like this: |
2636 | // |
2637 | // [0, left halo)[start_limit, stop_limit)[0, right halo) |
2638 | // |
2639 | if (!padding_predicate) { |
2640 | start_limit = |
2641 | SimplifyingIrBuilder::addExpr(start_limit, halo_info.width(0)); |
2642 | stop_limit = |
2643 | SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width(0)); |
2644 | } else { |
2645 | // In case of the padding predicate, the whole range, including both left |
2646 | // and right halo regions, is computed. |
2647 | stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width()); |
2648 | } |
2649 | } else { |
2650 | // For non-divisible predicates, the index must be predicated such |
2651 | // that it is less than the extent of the predicated ID + |
2652 | // halo. Note that getRootAxisInfo doesn't work since consumer_id |
2653 | // isn't a root domain. |
2654 | if (gpu_lower->haloInfo()->hasHaloWidth(consumer_id)) { |
2655 | auto halo = gpu_lower->haloInfo()->getHaloWidth(consumer_id); |
2656 | stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo); |
2657 | } |
2658 | } |
2659 | |
2660 | return {start_limit, stop_limit}; |
2661 | } |
2662 | |
2663 | // Get the offsets for the start and stop predicates. The offsets |
2664 | // are to be added to the index. |
2665 | std::pair<Val*, Val*> getStartAndStopOffsets( |
2666 | IterDomain* consumer_id, |
2667 | TensorView* consumer_tv, |
2668 | const std::unordered_map<IterDomain*, Val*>& consumer_start_index_map, |
2669 | const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map, |
2670 | bool padding_predicate, |
2671 | bool unswitch, |
2672 | bool non_divisible_pred) { |
  // By default, the offsets for the start and stop predicates are
  // just zero. All halo-related adjustments are done at root domains,
  // so if consumer_id is not a root domain, no adjustment is required.
2676 | if (consumer_id->definition() != nullptr && !non_divisible_pred) { |
2677 | return { |
2678 | GpuLower::current()->kernel()->zeroVal(), |
2679 | GpuLower::current()->kernel()->zeroVal()}; |
2680 | } |
2681 | |
2682 | auto consumer_def = consumer_tv->definition(); |
2683 | |
2684 | Val* start_offset = GpuLower::current()->kernel()->zeroVal(); |
2685 | Val* stop_offset = GpuLower::current()->kernel()->zeroVal(); |
2686 | |
2687 | // These adjustments are not required when predicating non-divisible splits |
2688 | if (!non_divisible_pred) { |
2689 | if (consumer_def->isA<ShiftOp>()) { |
2690 | std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift( |
2691 | consumer_tv, consumer_id, padding_predicate); |
2692 | } else if (consumer_def->isA<GatherOp>()) { |
2693 | std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForGather( |
2694 | consumer_tv, |
2695 | consumer_id, |
2696 | consumer_start_index_map, |
2697 | consumer_stop_index_map, |
2698 | padding_predicate); |
2699 | } |
2700 | |
2701 | // Adjustment for partial split |
2702 | auto partial_split_offset = |
2703 | getGlobalConsumerOffsetWithPartialSplit(consumer_id); |
2704 | start_offset = |
2705 | SimplifyingIrBuilder::addExpr(start_offset, partial_split_offset); |
2706 | stop_offset = |
2707 | SimplifyingIrBuilder::addExpr(stop_offset, partial_split_offset); |
2708 | |
2709 | // If generating a predicate for unswitch, adjust the stop offset to |
2710 | // accommodate the addition of halo to the loop stop. See the |
2711 | // comment in getPredicateIndexingFromIdGraph as well. |
2712 | if (unswitch) { |
2713 | TORCH_INTERNAL_ASSERT( |
          !padding_predicate, "Unswitch should not use the padding predicate");
2715 | auto stop_unswitch_offset = |
2716 | getUnswitchStopOffset(consumer_id, consumer_tv); |
2717 | stop_offset = |
2718 | SimplifyingIrBuilder::addExpr(stop_offset, stop_unswitch_offset); |
2719 | } |
2720 | } |
2721 | |
2722 | // Get the boundaries of two ends |
2723 | auto limits = getStartAndStopLimitOffsets( |
2724 | consumer_id, padding_predicate, non_divisible_pred); |
2725 | |
2726 | // At this point, we have everything to create both start and stop |
2727 | // predicates as: |
2728 | // |
2729 | // index + start_offset >= start_limit |
2730 | // index + stop_offset < extent + stop_limit |
2731 | // |
2732 | // In order to enable consolidating unswitch predicates, organize |
2733 | // the predicates as: |
2734 | // |
2735 | // index + (start_offset - start_limit) >= 0 |
2736 | // index + (stop_offset - stop_limit) < extent |
2737 | |
2738 | start_offset = SimplifyingIrBuilder::subExpr(start_offset, limits.first); |
2739 | stop_offset = SimplifyingIrBuilder::subExpr(stop_offset, limits.second); |
2740 | |
2741 | return {start_offset, stop_offset}; |
2742 | } |
2743 | |
2744 | // A partial value of a start offset is returned if determined to be |
2745 | // safe. Nullptr is returned if it can be omitted completely. |
2746 | Val* simplifyStartOffset(Val* start_offset) { |
2747 | // Start predicate can be omitted when start_offset >= 0. |
2748 | auto offset_val = start_offset->as<Int>()->value(); |
2749 | if (offset_val.has_value() && offset_val.value() >= 0) { |
2750 | return nullptr; |
2751 | } |
2752 | |
2753 | // start_offset may look like min(0, window_index - pad). Then, can |
2754 | // remove min and leave the rhs only. |
2755 | auto def = dynamic_cast<BinaryOp*>(start_offset->definition()); |
2756 | if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min && |
2757 | def->lhs()->isZeroInt()) { |
2758 | return def->rhs(); |
2759 | } |
2760 | |
2761 | return start_offset; |
2762 | } |
2763 | |
2764 | bool canOmitStopPredicate( |
2765 | Val* stop_index, |
2766 | Val* stop_offset, |
2767 | IterDomain* contig_id) { |
2768 | bool index_simple = stop_index->definition() == nullptr; |
2769 | // The definition may be just adding the magic zero, which can be |
2770 | // effectively considered "simple" |
2771 | if (!index_simple && isProtectedWithMagicZero(stop_index)) { |
2772 | // Make sure the lhs of stop_index is simple. |
2773 | auto lhs = stop_index->definition()->as<BinaryOp>()->lhs(); |
2774 | if (lhs->definition() == nullptr) { |
2775 | index_simple = true; |
2776 | } |
2777 | } |
2778 | |
2779 | if (!index_simple) { |
2780 | return false; |
2781 | } |
2782 | |
2783 | const auto gpu_lower = GpuLower::current(); |
2784 | |
2785 | // Stop predicate: stop_index + stop_offset < extent, where |
2786 | // stop_index ranges from 0 to (extent + halo), so this can be |
2787 | // omitted if extent + halo + stop_offset < extent, i.e., halo + |
2788 | // stop_offset <= 0. |
2789 | |
2790 | auto stop_offset_val = stop_offset->as<Int>()->value(); |
2791 | |
2792 | // If they are not compile-time constant, can't prove the |
2793 | // condition. |
2794 | if (!stop_offset_val.has_value()) { |
2795 | return false; |
2796 | } |
2797 | |
2798 | // Note that when a root domain is halo extended, it is the domain |
2799 | // to be predicated, not its merged contig id even if it exists. So, |
2800 | // if contig_id does not have root axis info, contig_id is |
2801 | // guaranteed to have no halo. |
2802 | auto halo_ext = gpu_lower->haloInfo()->hasRootAxisInfo(contig_id) |
2803 | ? gpu_lower->haloInfo()->getRootAxisInfo(contig_id).width() |
2804 | : 0; |
2805 | |
2806 | if (halo_ext + stop_offset_val.value() > 0) { |
2807 | return false; |
2808 | } |
2809 | |
2810 | // When the domain is parallelized, the parallel dimension must be |
2811 | // exact. Otherwise, there would be extra threads/blocks that need |
2812 | // to be predicated out. |
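| // For example, if threadIdx.x spans 128 threads for some other |
| // domain but this domain's extent is only 100, the last 28 threads |
| // would still need to be predicated out. |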
2813 | if (isParallelTypeThread(contig_id->getParallelType())) { |
2814 | if (!gpu_lower->parallelDimensionMap().isExact( |
2815 | contig_id->getParallelType())) { |
2816 | return false; |
2817 | } |
2818 | // If the domain has halo, the loop is expanded by the halo |
2819 | // extent, so we can't prove the loop extent is the same as the |
2820 | // parallel dimension. |
2821 | if (halo_ext != 0) { |
2822 | return false; |
2823 | } |
2824 | } |
2825 | |
2826 | return true; |
2827 | } |
2828 | |
2829 | std::pair<Val*, Val*> hoistPredicates( |
2830 | Val* start_index, |
2831 | Val* stop_index, |
2832 | const std::vector<kir::ForLoop*>& loops, |
2833 | std::vector<IterDomain*> loop_domains, |
2834 | const std::unordered_map<IterDomain*, Val*>& start_initial_loop_index_map, |
2835 | const std::unordered_map<IterDomain*, Val*>& stop_initial_loop_index_map, |
2836 | kir::ForLoop* unswitch_or_vec_loop, |
2837 | IterDomain* predicated_consumer_id, |
2838 | TensorView* predicated_consumer_tv) { |
2839 | const std::pair<Val*, Val*> same_indices{start_index, stop_index}; |
2840 | |
2841 | if (isOptionDisabled(DisableOption::IndexHoist)) { |
2842 | return same_indices; |
2843 | } |
2844 | |
2845 | const auto start_is_same_as_stop = stop_index == start_index; |
2846 | |
2847 | Val* hoisted_stop_index = nullptr; |
2848 | |
2849 | if (stop_index->definition() == nullptr) { |
2850 | // If the index doesn't have an expression, there is nothing to hoist |
2851 | hoisted_stop_index = stop_index; |
2852 | } else { |
2853 | bool inserted = false; |
2854 | std::tie(hoisted_stop_index, inserted) = |
2855 | GpuLower::current()->commonIndexMap().insert( |
2856 | predicated_consumer_id, |
2857 | predicated_consumer_tv->domain(), |
2858 | loop_domains, |
2859 | stop_initial_loop_index_map, |
2860 | loops, |
2861 | stop_index); |
2862 | } |
2863 | |
2864 | Val* hoisted_start_index = nullptr; |
2865 | if (start_is_same_as_stop) { |
2866 | hoisted_start_index = hoisted_stop_index; |
2867 | } else if (start_index->definition() == nullptr) { |
2868 | hoisted_start_index = start_index; |
2869 | } else { |
2870 | bool inserted = false; |
2871 | std::tie(hoisted_start_index, inserted) = |
2872 | GpuLower::current()->commonIndexMap().insert( |
2873 | predicated_consumer_id, |
2874 | predicated_consumer_tv->domain(), |
2875 | loop_domains, |
2876 | start_initial_loop_index_map, |
2877 | loops, |
2878 | start_index); |
2879 | } |
2880 | |
2881 | return {hoisted_start_index, hoisted_stop_index}; |
2882 | } |
2883 | |
2884 | // Updates a loop index map with a loop index protected by magic zero |
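| // For example, if a loop index i was protected as (i + nvfuser_zero), |
| // the entry for the loop's concrete domain is remapped to the |
| // protected index so that the hoisted expression uses it as well. |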
2885 | std::unordered_map<IterDomain*, Val*> updateInitialLoopIndexMap( |
2886 | const std::unordered_map<IterDomain*, Val*>& initial_loop_index_map, |
2887 | const IndexMagicZeroInfo& magic_zero_info) { |
2888 | if (magic_zero_info.original_loop_index != nullptr) { |
2889 | TORCH_INTERNAL_ASSERT(magic_zero_info.protected_loop_index != nullptr); |
2890 | auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( |
2891 | magic_zero_info.loop_id, IdMappingMode::EXACT); |
2892 | auto updated_map = initial_loop_index_map; |
2893 | updated_map[concrete_loop_id] = magic_zero_info.protected_loop_index; |
2894 | return updated_map; |
2895 | } else { |
2896 | return initial_loop_index_map; |
2897 | } |
2898 | } |
2899 | |
2900 | } // namespace |
2901 | |
2902 | // Returns predicates and the concrete (by loop map) root domains they cover |
2903 | std::vector<RootPredicateInfo> Index::getReferenceRootPredicates( |
2904 | TensorView* consumer_tv, |
2905 | const std::vector<kir::ForLoop*>& loops, |
2906 | kir::ForLoop* unswitch_or_vec_loop, |
2907 | bool shift_padding) { |
2908 | FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates"); |
2909 | |
2910 | const auto gpu_lower = GpuLower::current(); |
2911 | |
2912 | const bool is_unswitch = unswitch_or_vec_loop != nullptr; |
2913 | |
2914 | // Nothing needs to be done when padding is not required. |
2915 | if (shift_padding && !needsPadding(consumer_tv)) { |
2916 | return {RootPredicateInfo::getFalseInfo()}; |
2917 | } |
2918 | |
2919 | auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv); |
2920 | |
2921 | // Generate start and stop indexing from idgraph. |
2922 | // |
2923 | // Both start and stop positions may need to be predicated. Indexing |
2924 | // differs when generating predicates for unswitch. |
2925 | // NOTE: If we could find-and-replace KIR nodes, we could just |
2926 | // generate one index map, clone it and replace the loop-to-index |
2927 | // mappings of unswitched loops for the start predicate. |
2928 | |
2929 | auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph( |
2930 | loops, consumer_tv, unswitch_or_vec_loop, db_axis, false); |
2931 | const auto consumer_stop_indexing = stop_indexing_from_idgraph.index; |
2932 | const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap(); |
2933 | |
2934 | // If not unswitch, the start predicate shares the same indexing |
2935 | // map as the stop predicate |
2936 | const auto start_indexing_from_idgraph = is_unswitch |
2937 | ? getPredicateIndexingFromIdGraph( |
2938 | loops, consumer_tv, unswitch_or_vec_loop, db_axis, true) |
2939 | : stop_indexing_from_idgraph; |
2940 | const auto consumer_start_indexing = start_indexing_from_idgraph.index; |
2941 | const auto& consumer_start_index_map = consumer_start_indexing.indexMap(); |
2942 | |
2943 | // Get the contiguous ids we need to generate predicates for |
2944 | auto contig_id_infos = |
2945 | getPredicateContigIds(consumer_tv, consumer_stop_index_map); |
2946 | |
2947 | auto non_divisible_splits = |
2948 | getNonDivisibleConsumerDomainsToPredicate(consumer_tv); |
2949 | contig_id_infos.insert( |
2950 | contig_id_infos.end(), |
2951 | non_divisible_splits.begin(), |
2952 | non_divisible_splits.end()); |
2953 | |
2954 | std::vector<RootPredicateInfo> pred_info_vec; |
2955 | |
2956 | for (auto contig_id_entry : contig_id_infos) { |
2957 | auto contig_id = contig_id_entry.id; |
2958 | // No predicates are needed for broadcast or trivially reduced domains. |
2959 | if (contig_id->isBroadcast() || |
2960 | gpu_lower->trivialReductionInfo().isDerived(contig_id)) { |
2961 | continue; |
2962 | } |
2963 | |
2964 | auto root_ids = contig_id_entry.covered_ids; |
2965 | |
2966 | const auto consumer_stop_indexing_it = |
2967 | consumer_stop_index_map.find(contig_id); |
2968 | |
2969 | // The first condition below happens with misaligned predicates, |
2970 | // where the inner-most vectorized loops are not included in the |
2971 | // loops parameter. Predicates involving vectorized loops are |
2972 | // separately generated in lower_misaligned_vectorization. |
2973 | // |
2974 | // The second condition simply avoids predicating broadcast axes, |
2975 | // as it is not required. |
2976 | if (consumer_stop_indexing_it == consumer_stop_index_map.end() || |
2977 | consumer_stop_indexing_it->second->isZeroInt()) { |
2978 | continue; |
2979 | } |
2980 | |
2981 | RootPredicateInfo info; |
2982 | |
2983 | // Compute offsets for the start and stop predicates. For non-shift, |
2984 | // non-gather ops, only the stop predicate is needed, as indices can |
2985 | // never be negative. However, for shift and gather, the index may |
2986 | // need to be predicated so that it is >= zero. |
2987 | // |
2988 | // Furthermore, in case of gather, both producer and consumer |
2989 | // positions may need to be predicated, so there can be multiple |
2990 | // offset values. |
2991 | // |
2992 | // The final predicates will look like: |
2993 | // (index + start_offset) >= 0 && (index + stop_offset) < extent. |
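| // |
| // For example, shifting a producer by +1 means the producer is |
| // accessed at (consumer_index - 1), which requires the start |
| // predicate (consumer_index - 1) >= 0 on top of the stop predicate. |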
2994 | |
2995 | std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets( |
2996 | contig_id, |
2997 | consumer_tv, |
2998 | consumer_start_index_map, |
2999 | consumer_stop_index_map, |
3000 | shift_padding, |
3001 | unswitch_or_vec_loop != nullptr, |
3002 | contig_id_entry.is_non_divisible_split); |
3003 | |
3004 | auto stop_index = consumer_stop_indexing_it->second; |
3005 | auto start_index = consumer_start_index_map.at(contig_id); |
3006 | |
3007 | IndexMagicZeroInfo start_magic_zero_info; |
3008 | IndexMagicZeroInfo stop_magic_zero_info; |
3009 | |
3010 | // When the start and stop indices are not the same, apply the |
3011 | // magic-zero protection separately for both of them. |
3012 | if (stop_index != start_index) { |
3013 | start_magic_zero_info = protectPredicateIndexWithMagicZero( |
3014 | start_index, start_indexing_from_idgraph, loops); |
3015 | stop_magic_zero_info = protectPredicateIndexWithMagicZero( |
3016 | stop_index, stop_indexing_from_idgraph, loops); |
3017 | } else { |
3018 | stop_magic_zero_info = protectPredicateIndexWithMagicZero( |
3019 | stop_index, stop_indexing_from_idgraph, loops); |
3020 | start_magic_zero_info = stop_magic_zero_info; |
3021 | } |
3022 | |
3023 | start_index = start_magic_zero_info.index; |
3024 | stop_index = stop_magic_zero_info.index; |
3025 | |
3026 | // Update the loop-index map with the magic-zero protection info |
3027 | // before passing it to the hoisting function |
3028 | std::tie(start_index, stop_index) = hoistPredicates( |
3029 | start_index, |
3030 | stop_index, |
3031 | loops, |
3032 | stop_indexing_from_idgraph.resolved_loop_domains, |
3033 | updateInitialLoopIndexMap( |
3034 | start_indexing_from_idgraph.initial_concrete_index_map, |
3035 | start_magic_zero_info), |
3036 | updateInitialLoopIndexMap( |
3037 | stop_indexing_from_idgraph.initial_concrete_index_map, |
3038 | stop_magic_zero_info), |
3039 | unswitch_or_vec_loop, |
3040 | contig_id, |
3041 | consumer_tv); |
3042 | |
3043 | // Build predicates for start positions as: |
3044 | // start_index + start_offset >= 0 |
3045 | auto start_offset = simplifyStartOffset(info.start_offset_); |
3046 | if (start_offset == nullptr) { |
3047 | info.start_predicate_ = GpuLower::current()->kernel()->trueVal(); |
3048 | } else { |
3049 | auto offsetted_start_index = |
3050 | SimplifyingIrBuilder::addExpr(start_index, start_offset); |
3051 | auto start_pred = |
3052 | SimplifyingIrBuilder::geExpr( |
3053 | offsetted_start_index, GpuLower::current()->kernel()->zeroVal()) |
3054 | ->as<Bool>(); |
3055 | info.start_predicate_ = start_pred; |
3056 | } |
3057 | |
3058 | // Build predicates for stop positions as: |
3059 | // stop_index + stop_offset < IterDomain::extent |
3060 | auto stop_offset = info.stop_offset_; |
3061 | if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) { |
3062 | info.stop_predicate_ = GpuLower::current()->kernel()->trueVal(); |
3063 | } else { |
3064 | auto offsetted_stop_index = |
3065 | SimplifyingIrBuilder::addExpr(stop_index, stop_offset); |
3066 | auto stop_pred = SimplifyingIrBuilder::ltExpr( |
3067 | offsetted_stop_index, contig_id->extent()) |
3068 | ->as<Bool>(); |
3069 | info.stop_predicate_ = stop_pred; |
3070 | } |
3071 | |
3072 | for (auto consumer_id : contig_id_entry.covered_ids) { |
3073 | info.root_ids_.insert(consumer_id); |
3074 | } |
3075 | pred_info_vec.emplace_back(info); |
3076 | } |
3077 | |
3078 | return pred_info_vec; |
3079 | } |
3080 | |
3081 | RootPredicateInfo RootPredicateInfo::getFalseInfo() { |
3082 | RootPredicateInfo info; |
3083 | info.start_predicate_ = GpuLower::current()->kernel()->falseVal(); |
3084 | info.stop_predicate_ = GpuLower::current()->kernel()->falseVal(); |
3085 | |
3086 | return info; |
3087 | } |
3088 | |
3089 | } // namespace cuda |
3090 | } // namespace fuser |
3091 | } // namespace jit |
3092 | } // namespace torch |
3093 | |