#include <instrumentation.h>
#include <ir_utils.h>
#include <lower2device.h>

#include <lower_sync_information.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Validate parallelization of a single tensor
void validateParallelizationOfTensor(TensorView* tv) {
  // Each ParallelType can be used only once.
  ParallelTypeBitmap pt_map;
  for (size_t i = 0; i < tv->nDims(); ++i) {
    auto axis = tv->axis(i);
    auto ptype = axis->getParallelType();
    if (!isParallelTypeThread(ptype)) {
      continue;
    }

    // It doesn't matter if this axis is a non-concretized broadcast
    // TODO: merging broadcast and non-broadcast
    if (axis->isBroadcast() &&
        !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
            axis)) {
      continue;
    }

    TORCH_INTERNAL_ASSERT(
        !pt_map.get(ptype),
        "Multiple use of ",
        ptype,
        " in tensor t",
        tv->name(),
        ": ",
        tv);
    pt_map.set(ptype);
  }

  // If this tensor is predicated by a parallel type, that type should not be
  // used to parallelize any domain of this tensor

  const auto thread_pred =
      GpuLower::current()->threadPredMap().getPredicateInfo(tv);

  auto predicated_parallel_types = pt_map & thread_pred.limited_types;

  TORCH_INTERNAL_ASSERT(
      predicated_parallel_types.none(),
      "Invalid parallelization of tensor t",
      tv->name(),
      ". The tensor is parallelized with ",
      predicated_parallel_types.toString(),
      ", but it is invalid to use these types as the tensor is also predicated with them.",
      " Thread pred: ",
      thread_pred.limited_types.toString());
}

//! Return true if axis is derived from a root axis that is an input
//! to a CA leaf axis.
bool derivedFromRootCAAxes(TensorView* tv, IterDomain* axis) {
  std::vector<IterDomain*> ca_axes(
      tv->domain()->domain().begin(),
      tv->domain()->domain().begin() + tv->getComputeAtPosition());

  auto ca_root_vals = IterVisitor::getInputsTo(
      std::vector<Val*>(ca_axes.begin(), ca_axes.end()));

  auto root_vals = IterVisitor::getInputsTo({axis});

  return std::any_of(
      root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
        return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
            ca_root_vals.end();
      });
}

} // namespace

void SyncMap::build(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validateParallelize");
  FusionGuard fg(fusion);

  const auto& ca_map = GpuLower::current()->caMap();
  const auto& pred_map = GpuLower::current()->threadPredMap();

  auto exprs = StmtSort::getExprs(fusion);

  // Run through expressions and check for communication across threads/blocks
  // occurring from producer to consumer of the expression
  for (auto expr : exprs) {
    if (!ir_utils::isTvOp(expr)) {
      continue;
    }

    // Validate parallelization of each consumer by itself
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      validateParallelizationOfTensor(consumer);
    }

    // It's probably enough to just check all producers to one consumer as
    // multi-consumers are guaranteed to be transformed/parallelized the same,
    // but to be conservative for now checking every producer <-> consumer
    // relationship.
    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
      // Parallelization of input tensors has no effect.
      if (producer->isFusionInput()) {
        continue;
      }

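      // Parallel dimensions over which this producer communicates with its
      // consumers and therefore needs a RAW sync.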
      ParallelTypeBitmap raw_dims;

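      // Parallel broadcast domains of the producer, i.e., parallel types on
      // which its broadcast domains are broadcast to other threads/blocks;
      // used below to decide whether a parallelized broadcast axis really
      // implies communication.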
      const auto parallel_bcast_doms =
          pred_map.getParallelBroadcastDomains(producer);

      // Stash information about parallelized producer iteration domains
      std::vector<IterDomain*> producer_parallel_ids(
          ParallelTypeBitmap::kNumParallelTypes, nullptr);
      ParallelTypeBitmap producer_parallel_bitmap;

      // Tracking for quick check later
      std::unordered_set<IterDomain*> producer_within_compute_at;

      // Get the parallel types that the producer will be predicated on for
      // its writes. In this case we need a sync whether the producer-consumer
      // axes are mapped or not, since the predicate pass will generate a
      // pattern like the following to eliminate redundant writes:
      //   if (threadIdx.x == 0)
      //     shared[threadIdx.x + i] = ...
      // We will need a RAW sync after this pattern for correctness.
      auto producer_redundant_types = GpuLower::current()
                                          ->threadPredMap()
                                          .getPredicateInfo(producer)
                                          .redundant_types;
      // Get the parallel types that are inactive in the consumer's use chains.
      auto producer_redundant_use_types = GpuLower::current()
                                              ->threadPredMap()
                                              .getPredicateInfo(producer)
                                              .redundant_use_types;

      // In the sync info pass we only consider the parallel types in the
      // producer that are redundantly produced but not redundantly consumed.
      producer_redundant_types =
          producer_redundant_types & (~producer_redundant_use_types);

      for (const auto producer_i : c10::irange(producer->nDims())) {
        auto producer_axis = producer->axis(producer_i);
        auto producer_ptype =
            ca_map->getConcreteMappedID(producer_axis, IdMappingMode::LOOP)
                ->getParallelType();

        if (!isParallelTypeThread(producer_ptype)) {
          continue;
        }

        // Producer reductions shouldn't map to consumers
        if (producer_axis->isReduction()) {
          continue;
        }

        if (producer_i < producer->getComputeAtPosition()) {
          producer_within_compute_at.emplace(producer_axis);
        }

        producer_parallel_bitmap.set(producer_ptype);
        producer_parallel_ids[getParallelTypeBitMapOffset(producer_ptype)] =
            producer_axis;
      }

      for (auto consumer :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        // Stash information about parallelized consumer iteration domains
        std::vector<IterDomain*> consumer_parallel_ids(
            ParallelTypeBitmap::kNumParallelTypes, nullptr);
        ParallelTypeBitmap consumer_parallel_bitmap;

        for (const auto consumer_i : c10::irange(consumer->nDims())) {
          auto consumer_axis = consumer->axis(consumer_i);
          auto consumer_ptype =
              ca_map->getConcreteMappedID(consumer_axis, IdMappingMode::LOOP)
                  ->getParallelType();

          if (!isParallelTypeThread(consumer_ptype)) {
            continue;
          }

          // When the consumer axis is a broadcast, it is not really
          // parallelized unless thread-predicated and eventually concretized
          if (consumer_axis->isBroadcast() &&
              (!parallel_bcast_doms.get(consumer_ptype) ||
               !GpuLower::current()
                    ->concretizedBroadcastDomains()
                    ->isConcretized(consumer_axis))) {
            continue;
          }

          consumer_parallel_bitmap.set(consumer_ptype);
          consumer_parallel_ids[getParallelTypeBitMapOffset(consumer_ptype)] =
              consumer_axis;
        }

        // At this point each parallel type that's present in the consumer or
        // the producer will be present in their corresponding `_parallel_ids`
        // map going from parallel index type (only size 6 for grid/block
        // dims) to the iteration domain of that parallel type.
        for (auto parallel_type : kParallelTypeThreads) {
          // TIDx is reserved for lane_id in the case of mma ops.
          // It is swizzled and handled separately in validateMma.
          if (parallel_type == ParallelType::TIDx && expr->isA<MmaOp>()) {
            continue;
          }

          // When the parallel ids are mapped by the CA map, we additionally
          // need to consider whether the producer write is redundant. The raw
          // dim can be skipped only if the consumer use chains contain only
          // redundant uses.
          // TODO:
          // We are still losing a bit of precision here for expr-ordering
          // sensitive cases, but we could wait until that becomes a perf
          // limiter to fix.
          if (producer_redundant_types.get(parallel_type)) {
            raw_dims.set(parallel_type);
            continue;
          }

          auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);

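          // Leaf IterDomains of the producer and consumer, if any, that are
          // parallelized with this parallel type.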
          auto p_id = producer_parallel_ids[parallel_type_i];
          auto c_id = consumer_parallel_ids[parallel_type_i];

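          // Case analysis: if neither side is parallelized with this type,
          // there is nothing to check. If both are and they map through the
          // CA map, mismatched halo forces a sync. Otherwise, look for a
          // matching axis on the other side through the permissive map.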
          if (p_id == nullptr && c_id == nullptr) {
            continue;
          } else if (p_id != nullptr && c_id != nullptr) {
            if (GpuLower::current()->caMap()->areMapped(
                    p_id, c_id, IdMappingMode::PERMISSIVE)) {
              const auto halo_info = GpuLower::current()->haloInfo();

              if (halo_info->hasHaloWidth(p_id) !=
                      halo_info->hasHaloWidth(c_id) ||
                  (halo_info->hasHaloWidth(p_id) &&
                   halo_info->hasHaloWidth(c_id) &&
                   halo_info->getHaloWidth(p_id) !=
                       halo_info->getHaloWidth(c_id))) {
                raw_dims.set(parallel_type);
                continue;
              }
            }
          } else {
            if (p_id != nullptr) {
              auto it = std::find_if(
                  consumer->domain()->domain().begin(),
                  consumer->domain()->domain().end(),
                  [&](IterDomain* c_id) {
                    return GpuLower::current()->caMap()->areMapped(
                        p_id, c_id, IdMappingMode::PERMISSIVE);
                  });

              // If there isn't a mapping from producer to a consumer domain,
              // need to assume there's communication across this parallel
              // dimension.
              c_id = it == consumer->domain()->domain().end() ? nullptr : *it;
              // i.e. if the producer is parallelized across threadIdx.x in a
              // certain split, and the consumer doesn't map to this split,
              // then we need to assume it has to be in smem with proper
              // syncs.
            } else {
              auto it = std::find_if(
                  producer->domain()->domain().begin(),
                  producer->domain()->domain().end(),
                  [&](IterDomain* p_id) {
                    return GpuLower::current()->caMap()->areMapped(
                        p_id, c_id, IdMappingMode::PERMISSIVE);
                  });
              if (it == producer->domain()->domain().end()) {
                // Can't infer anything if the producer doesn't have a
                // matching axis to the parallelized consumer dim.
                continue;
              }
              p_id = *it;
            }
          }

          // Comm pattern options (when parallel types don't have matching
          // axes) and required memory. The chart below is indexed by the
          // producer parallel type and the consumer parallel type. Parallel
          // types are Serial (S), threadIdx (T), blockIdx (B). Memory
          // required for the producer is Local (L), Shared (S), Global (G).
          // Sync is None (N/A), blockSync (B), gridSync (G).
          //
          // P  C  Mem Req  Sync Type
          // S  S  L        N/A
          // S  T  L        N/A
          // S  B  L        N/A
          // T  S  S        B
          // T  T  S        B
          // T  B  S        B
          // B  S  G        G
          // B  T  G        G
          // B  B  G        G

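          // Resolve the effective parallel types through the loop map; an
          // unmatched consumer axis is treated as Serial.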
          auto producer_ptype =
              ca_map->getConcreteMappedID(p_id, IdMappingMode::LOOP)
                  ->getParallelType();
          auto consumer_ptype = c_id == nullptr
              ? ParallelType::Serial
              : ca_map->getConcreteMappedID(c_id, IdMappingMode::LOOP)
                    ->getParallelType();

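          // A parallelized producer axis implies cross-thread communication
          // unless the consumer is guaranteed to use the same index, e.g.,
          // via a shared compute-at loop or an exactly-mapped consumer axis
          // (checked below).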
          if (!p_id->isBroadcast() && isParallelTypeThread(producer_ptype) &&
              !(isParallelTypeThread(consumer_ptype) &&
                parallel_bcast_doms.get(consumer_ptype)) &&
              // Being in compute at means consumer and producer rely on the
              // same loop size
              !producer_within_compute_at.count(p_id) &&
              // For usage of derivedFromRootCAAxes check
              // NVFuserTest.FusionAdvancedIndexing1_CUDA
              (c_id == nullptr || !derivedFromRootCAAxes(producer, p_id))) {
            // There must be a consumer axis that uses the same indexing
            // with the same parallel type as the producer axis. The index
            // map is used to find such an axis. In addition, even when
            // no mapped axis is found in the index map, but a mapped
            // axis exists in the loop map, the producer and consumer axes
            // may still use the same indexing. That only happens when the
            // producer is derived from a root axis that is an input to any
            // leaf CA axes. In such a case, the axis in the reference
            // tensor that maps to the producer axis is created based on the
            // consumer, so both the producer and consumer axes should have
            // the same indexing. See issue #995 as well as the
            // FusionValidateParallelize6 test for a concrete example.
            auto it = std::find_if(
                consumer->domain()->domain().begin(),
                consumer->domain()->domain().end(),
                [&](IterDomain* c_id_) {
                  return ca_map->areMapped(p_id, c_id_, IdMappingMode::EXACT);
                });
            if (it == consumer->domain()->domain().end()) {
              if (isParallelTypeThread(producer_ptype)) {
                raw_dims.set(producer_ptype);
              }
              if (isParallelTypeThread(consumer_ptype)) {
                raw_dims.set(consumer_ptype);
              }
            }
          }

          // If any leaf id of the producer is block or grid parallel and is
          // involved in any swizzle pattern, track this parallel dim as a
          // communication dimension that requires the corresponding
          // synchronization and memory type.
          if (isParallelTypeThread(producer_ptype) &&
              producer->hasSwizzleOp()) {
            if (!ir_utils::getAllSwizzlesBetween(
                     producer->getMaybeRFactorDomain(), {p_id})
                     .empty()) {
              raw_dims.set(producer_ptype);
            }
          }

          // In shift or gather operations, if a thread or block
          // domain's root ID is shifted or gathered, it can overlap
          // in shared or global memory. This doesn't
          // require a RAW sync since each thread would still write every
          // value it would read, but it can require a WAR sync for shared
          // memory. Since we don't have a separate structure for WAR apart
          // from RAW for now, we'll flag it on RAW, which will trigger the
          // WAR.
          // See test FusionValidateParallelizeShift_CUDA for a
          // concrete example where this sync is required.
          if ((expr->getExprType() == ExprType::GatherOp ||
               expr->getExprType() == ExprType::ShiftOp) &&
              producer->getMemoryType() == MemoryType::Shared &&
              isParallelTypeThreadDim(producer_ptype)) {
            std::unordered_set<Val*> shifted_rfactor_ids;
            if (expr->getExprType() == ExprType::GatherOp) {
              auto gather_op = expr->as<GatherOp>();
              for (auto root_i :
                   c10::irange(producer->getMaybeRFactorDomain().size())) {
                auto rfactor_id = producer->getMaybeRFactorDomain()[root_i];
                // If the window shape is 1, it just copies the
                // producer to the consumer
                if (gather_op->windowShape()[root_i] != 1) {
                  shifted_rfactor_ids.insert(rfactor_id);
                }
              }
            } else if (expr->getExprType() == ExprType::ShiftOp) {
              auto shift_op = expr->as<ShiftOp>();
              for (auto root_i :
                   c10::irange(producer->getMaybeRFactorDomain().size())) {
                auto rfactor_id = producer->getMaybeRFactorDomain()[root_i];
                // If the shift offset is 0, it doesn't actually shift
                if (shift_op->offsets()[root_i] != 0) {
                  shifted_rfactor_ids.insert(rfactor_id);
                }
              }
            }

            // Grab all values between shifted rfactor domains and p_id so we
            // can identify which rfactor domains are inputs to the p_id
            auto p_id_dep_vals =
                DependencyCheck::getAllValsBetween(shifted_rfactor_ids, {p_id});
            // If this shifted rfactor domain is an input to p_id, we
            // must have a WAR sync. Mark raw sync so it will be generated.
            if (!p_id_dep_vals.empty()) {
              raw_dims.set(producer_ptype);
            }
          }

          // When the producer axis is a broadcast, it is not really
          // parallelized unless thread-predicated and concretized
          if (isParallelTypeThread(producer_ptype) && p_id->isBroadcast() &&
              (!parallel_bcast_doms.get(producer_ptype) ||
               !GpuLower::current()
                    ->concretizedBroadcastDomains()
                    ->isConcretized(p_id))) {
            continue;
          }

          // If matching dims and matching parallel types, no comm is
          // necessary.
          if (producer_ptype == consumer_ptype &&
              GpuLower::current()->caMap()->areMapped(
                  p_id, c_id, IdMappingMode::PERMISSIVE)) {
            continue;
          }

          // Set parallel dimensions that communication is occurring over.
          if (isParallelTypeThread(producer_ptype)) {
            raw_dims.set(producer_ptype);
          }
        } // end for ptypes

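        // The parallel types that require a RAW sync also constrain where
        // the producer must live: block/grid (BID) communication requires
        // global memory, while thread (TID) communication requires shared or
        // global memory.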
        if (raw_dims.hasBID()) {
          TORCH_INTERNAL_ASSERT(
              producer->getMemoryType() == MemoryType::Global,
              "Inconsistent parallelization found between TV",
              producer->name(),
              " (",
              producer->toString(),
              ") and TV",
              consumer->name(),
              "(",
              consumer->toString(),
              "). Producer is required to be in Global Memory based on parallelization strategy.");
        } else if (raw_dims.hasTID()) {
          TORCH_INTERNAL_ASSERT(
              producer->getMemoryType() == MemoryType::Global ||
                  producer->getMemoryType() == MemoryType::Shared,
              "Inconsistent parallelization found between TV",
              producer->name(),
              " (",
              producer->toString(),
              ") and TV",
              consumer->name(),
              "(",
              consumer->toString(),
              "). Producer is required to be in Global or Shared Memory based on parallelization strategy.");
        }

      } // end for consumers

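      // Record the accumulated RAW sync requirement for this producer.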
      if (raw_dims.any()) {
        needs_raw_sync_[producer] |= raw_dims;
      }

    } // end producer
  }
}

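// Dump the producer tensors that need a RAW sync and the parallel types
// over which the sync is required.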
std::string SyncMap::toString() const {
  std::stringstream ss;
  ss << "SyncMap:";
  bool is_first = true;
  for (auto entry : needs_raw_sync_) {
    if (!is_first) {
      ss << ",";
    }
    ss << " " << entry.first->toString() << " -> " << entry.second.toString();
    is_first = false;
  }
  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch