#include <instrumentation.h>
#include <ir_utils.h>
#include <lower2device.h>

#include <lower_sync_information.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {
// Validate parallelization of a single tensor
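// Two properties are checked here:
//  (1) each thread/block ParallelType is used by at most one axis of the
//      tensor, and
//  (2) no axis is parallelized with a type that also thread-predicates the
//      tensor's writes (e.g. an axis bound to TIDx while the write is
//      guarded by threadIdx.x == 0).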
void validateParallelizationOfTensor(TensorView* tv) {
  // Each ParallelType can be used only once.
  ParallelTypeBitmap pt_map;
  for (size_t i = 0; i < tv->nDims(); ++i) {
    auto axis = tv->axis(i);
    auto ptype = axis->getParallelType();
    if (!isParallelTypeThread(ptype)) {
      continue;
    }

    // It doesn't matter if this axis is a non-concretized broadcast
    // TODO: merging broadcast and non-broadcast
    if (axis->isBroadcast() &&
        !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
            axis)) {
      continue;
    }

    TORCH_INTERNAL_ASSERT(
        !pt_map.get(ptype),
        "Multiple use of ",
        ptype,
        " in tensor t",
        tv->name(),
        ": ",
        tv);
    pt_map.set(ptype);
  }

  // If this tensor is predicated by a parallel type, that type should not be
  // used to parallelize any domain of this tensor.

  const auto thread_pred =
      GpuLower::current()->threadPredMap().getPredicateInfo(tv);

  auto predicated_parallel_types = pt_map & thread_pred.limited_types;

  TORCH_INTERNAL_ASSERT(
      predicated_parallel_types.none(),
      "Invalid parallelization of tensor t",
      tv->name(),
      ". The tensor is parallelized with ",
      predicated_parallel_types.toString(),
      ", but it's invalid to use the types as the tensor is also predicated with them.",
      " Thread pred: ",
      thread_pred.limited_types.toString());
}

//! Return true if axis is derived from a root axis that is an input
//! to a CA leaf axis.
bool derivedFromRootCAAxes(TensorView* tv, IterDomain* axis) {
  std::vector<IterDomain*> ca_axes(
      tv->domain()->domain().begin(),
      tv->domain()->domain().begin() + tv->getComputeAtPosition());

  auto ca_root_vals = IterVisitor::getInputsTo(
      std::vector<Val*>(ca_axes.begin(), ca_axes.end()));

  auto root_vals = IterVisitor::getInputsTo({axis});

  return std::any_of(
      root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
        return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
            ca_root_vals.end();
      });
}

} // namespace

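// For every tensor op in the fusion, record per producer TensorView the
// thread/block parallel types over which data is communicated to a consumer
// and a RAW sync (and a suitable memory type) is therefore required. Results
// are accumulated in needs_raw_sync_.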
void SyncMap::build(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validateParallelize");
  FusionGuard fg(fusion);

  const auto& ca_map = GpuLower::current()->caMap();
  const auto& pred_map = GpuLower::current()->threadPredMap();

  auto exprs = StmtSort::getExprs(fusion);

  // Run through expressions and check for communication across threads/blocks
  // occurring from producer to consumer of the expression
  for (auto expr : exprs) {
    if (!ir_utils::isTvOp(expr)) {
      continue;
    }

    // Validate parallelization of each consumer by itself
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      validateParallelizationOfTensor(consumer);
    }

    // It's probably enough to just check all producers against one consumer,
    // as multiple consumers are guaranteed to be transformed/parallelized the
    // same, but to be conservative for now we check every producer <->
    // consumer relationship.
    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
      // Parallelization on input tensors has no effect.
      if (producer->isFusionInput()) {
        continue;
      }

      ParallelTypeBitmap raw_dims;

      const auto parallel_bcast_doms =
          pred_map.getParallelBroadcastDomains(producer);

      // Stash information about parallelized producer iteration domains
      std::vector<IterDomain*> producer_parallel_ids(
          ParallelTypeBitmap::kNumParallelTypes, nullptr);
      ParallelTypeBitmap producer_parallel_bitmap;

      // Tracking for quick check later
      std::unordered_set<IterDomain*> producer_within_compute_at;

      // Get the parallel types that the producer's writes will be predicated
      // on. In this case we need a sync whether the producer-consumer axes
      // are mapped or not, since the predicate pass will generate a pattern
      // like the following to eliminate redundant writes:
      //   if (threadIdx.x == 0)
      //     shared[threadIdx.x + i] = ...
      // We will need a RAW sync after this pattern for correctness.
      auto producer_redundant_types = GpuLower::current()
                                          ->threadPredMap()
                                          .getPredicateInfo(producer)
                                          .redundant_types;
      // Get the parallel types that are inactive in consumer's use chains.
      auto producer_redundant_use_types = GpuLower::current()
                                              ->threadPredMap()
                                              .getPredicateInfo(producer)
                                              .redundant_use_types;

      // In the sync info pass we only consider the parallel types in the
      // producer that are redundantly produced but not redundantly consumed.
      producer_redundant_types =
          producer_redundant_types & (~producer_redundant_use_types);
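      // For example, a producer that is redundantly written across TIDy, but
      // whose consumer use chains are also redundant across TIDy, does not
      // need a TIDy RAW sync for that reason alone.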

      for (const auto producer_i : c10::irange(producer->nDims())) {
        auto producer_axis = producer->axis(producer_i);
        auto producer_ptype =
            ca_map->getConcreteMappedID(producer_axis, IdMappingMode::LOOP)
                ->getParallelType();

        if (!isParallelTypeThread(producer_ptype)) {
          continue;
        }

        // Producer reductions shouldn't map to consumers
        if (producer_axis->isReduction()) {
          continue;
        }

        if (producer_i < producer->getComputeAtPosition()) {
          producer_within_compute_at.emplace(producer_axis);
        }

        producer_parallel_bitmap.set(producer_ptype);
        producer_parallel_ids[getParallelTypeBitMapOffset(producer_ptype)] =
            producer_axis;
      }

      for (auto consumer :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        // Stash information about parallelized consumer iteration domains
        std::vector<IterDomain*> consumer_parallel_ids(
            ParallelTypeBitmap::kNumParallelTypes, nullptr);
        ParallelTypeBitmap consumer_parallel_bitmap;

        for (const auto consumer_i : c10::irange(consumer->nDims())) {
          auto consumer_axis = consumer->axis(consumer_i);
          auto consumer_ptype =
              ca_map->getConcreteMappedID(consumer_axis, IdMappingMode::LOOP)
                  ->getParallelType();

          if (!isParallelTypeThread(consumer_ptype)) {
            continue;
          }

          // When the consumer axis is a broadcast, it is not really
          // parallelized unless thread-predicated and eventually concretized
          if (consumer_axis->isBroadcast() &&
              (!parallel_bcast_doms.get(consumer_ptype) ||
               !GpuLower::current()
                    ->concretizedBroadcastDomains()
                    ->isConcretized(consumer_axis))) {
            continue;
          }

          consumer_parallel_bitmap.set(consumer_ptype);
          consumer_parallel_ids[getParallelTypeBitMapOffset(consumer_ptype)] =
              consumer_axis;
        }

        // At this point, any parallel type used by the producer or the
        // consumer appears in the corresponding *_parallel_ids vector, which
        // maps a parallel type's bitmap offset (only the six thread/block
        // dimensions) to the iteration domain parallelized with that type.
        for (auto parallel_type : kParallelTypeThreads) {
          // TIDx is reserved for lane_id in the case of mma ops.
          // It is swizzled and handled separately in validateMma.
          if (parallel_type == ParallelType::TIDx && expr->isA<MmaOp>()) {
            continue;
          }

          // When the parallel IDs are mapped by the CA map, we additionally
          // need to consider whether the producer is a redundant write. The
          // RAW dim can be skipped only if the consumer's use chains contain
          // only redundant uses.
          // TODO: We still lose a bit of precision here for expr-ordering
          // sensitive cases, but we can wait until that becomes a perf
          // limiter to fix it.
          if (producer_redundant_types.get(parallel_type)) {
            raw_dims.set(parallel_type);
            continue;
          }

          auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);

          auto p_id = producer_parallel_ids[parallel_type_i];
          auto c_id = consumer_parallel_ids[parallel_type_i];

          if (p_id == nullptr && c_id == nullptr) {
            continue;
          } else if (p_id != nullptr && c_id != nullptr) {
            if (GpuLower::current()->caMap()->areMapped(
                    p_id, c_id, IdMappingMode::PERMISSIVE)) {
              const auto halo_info = GpuLower::current()->haloInfo();

              if (halo_info->hasHaloWidth(p_id) !=
                      halo_info->hasHaloWidth(c_id) ||
                  (halo_info->hasHaloWidth(p_id) &&
                   halo_info->hasHaloWidth(c_id) &&
                   halo_info->getHaloWidth(p_id) !=
                       halo_info->getHaloWidth(c_id))) {
                raw_dims.set(parallel_type);
                continue;
              }
            }
          } else {
            if (p_id != nullptr) {
              auto it = std::find_if(
                  consumer->domain()->domain().begin(),
                  consumer->domain()->domain().end(),
                  [&](IterDomain* c_id) {
                    return GpuLower::current()->caMap()->areMapped(
                        p_id, c_id, IdMappingMode::PERMISSIVE);
                  });

              // If there isn't a mapping from the producer to a consumer
              // domain, we need to assume there's communication across this
              // parallel dimension.
              c_id = it == consumer->domain()->domain().end() ? nullptr : *it;
              // i.e. if the producer is parallelized across threadIdx.x in a
              // certain split and the consumer doesn't map to this split,
              // then we need to assume it has to be in smem with proper
              // syncs.
            } else {
              auto it = std::find_if(
                  producer->domain()->domain().begin(),
                  producer->domain()->domain().end(),
                  [&](IterDomain* p_id) {
                    return GpuLower::current()->caMap()->areMapped(
                        p_id, c_id, IdMappingMode::PERMISSIVE);
                  });
              if (it == producer->domain()->domain().end()) {
                // Can't infer anything if the producer doesn't have a
                // matching axis to the parallel consumer dim.
                continue;
              }
              p_id = *it;
            }
          }

          // Comm pattern options (when parallel types don't have matching
          // axes) and required memory. The chart below is indexed by the
          // producer parallel type (P) and the consumer parallel type (C).
          // Parallel types are Serial (S), threadIdx (T), blockIdx (B).
          // Memory required for the producer is Local (L), Shared (S),
          // Global (G). Sync is None (N/A), blockSync (B), gridSync (G).
          //
          //  P    C    Mem Req    Sync Type
          //  S    S       L          N/A
          //  S    T       L          N/A
          //  S    B       L          N/A
          //  T    S       S           B
          //  T    T       S           B
          //  T    B       S           B
          //  B    S       G           G
          //  B    T       G           G
          //  B    B       G           G
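          //
          // For example, a producer axis bound to TIDx whose mapped consumer
          // axis is serial falls in the (T, S) row: the producer must be
          // placed in shared memory and a block sync is needed before the
          // consumer reads it.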

          auto producer_ptype =
              ca_map->getConcreteMappedID(p_id, IdMappingMode::LOOP)
                  ->getParallelType();
          auto consumer_ptype = c_id == nullptr
              ? ParallelType::Serial
              : ca_map->getConcreteMappedID(c_id, IdMappingMode::LOOP)
                    ->getParallelType();

          if (!p_id->isBroadcast() && isParallelTypeThread(producer_ptype) &&
              !(isParallelTypeThread(consumer_ptype) &&
                parallel_bcast_doms.get(consumer_ptype)) &&
              // Being in compute at means consumer and producer rely on the
              // same loop size
              !producer_within_compute_at.count(p_id) &&
              // For usage of derivedFromRootCAAxes check
              // NVFuserTest.FusionAdvancedIndexing1_CUDA
              (c_id == nullptr || !derivedFromRootCAAxes(producer, p_id))) {
            // There must be a consumer axis that uses the same indexing
            // with the same parallel type as the producer axis. The index
            // map is used to find such an axis. In addition, even when
            // no mapped axis is found in the index map, but a mapped
            // axis exists in the loop map, the producer and consumer axes
            // may still use the same indexing. That only happens when the
            // producer is derived from a root axis that is an input to any
            // leaf CA axes. In such a case, the axis in the reference
            // tensor that maps to the producer axis is created based on the
            // consumer, so both the producer and consumer axes should have
            // the same indexing. See issue #995 as well as the
            // FusionValidateParallelize6 test for a concrete example.
            auto it = std::find_if(
                consumer->domain()->domain().begin(),
                consumer->domain()->domain().end(),
                [&](IterDomain* c_id_) {
                  return ca_map->areMapped(p_id, c_id_, IdMappingMode::EXACT);
                });
            if (it == consumer->domain()->domain().end()) {
              if (isParallelTypeThread(producer_ptype)) {
                raw_dims.set(producer_ptype);
              }
              if (isParallelTypeThread(consumer_ptype)) {
                raw_dims.set(consumer_ptype);
              }
            }
          }

          // If any leaf ID of the producer is block or grid parallel and is
          // involved in any swizzle pattern, track this parallel dim as a
          // communication dimension that requires the corresponding
          // synchronization and memory type.
          if (isParallelTypeThread(producer_ptype) &&
              producer->hasSwizzleOp()) {
            if (!ir_utils::getAllSwizzlesBetween(
                     producer->getMaybeRFactorDomain(), {p_id})
                     .empty()) {
              raw_dims.set(producer_ptype);
            }
          }

          // In shift or gather operations, if a thread or block
          // domain's root ID is shifted or gathered, it can overlap
          // in shared or global memory. This doesn't require a RAW sync
          // since each thread would still write every value it would read,
          // but it can require a WAR sync for shared memory. Since there
          // isn't a separate structure for WAR syncs for now, we flag it
          // on RAW, which will also trigger the WAR sync.
          // See test FusionValidateParallelizeShift_CUDA for a
          // concrete example where this sync is required.
          if ((expr->getExprType() == ExprType::GatherOp ||
               expr->getExprType() == ExprType::ShiftOp) &&
              producer->getMemoryType() == MemoryType::Shared &&
              isParallelTypeThreadDim(producer_ptype)) {
            std::unordered_set<Val*> shifted_rfactor_ids;
            if (expr->getExprType() == ExprType::GatherOp) {
              auto gather_op = expr->as<GatherOp>();
              for (auto root_i :
                   c10::irange(producer->getMaybeRFactorDomain().size())) {
                auto rfactor_id = producer->getMaybeRFactorDomain()[root_i];
                // If the window shape is 1, it just copies the
                // producer to the consumer
                if (gather_op->windowShape()[root_i] != 1) {
                  shifted_rfactor_ids.insert(rfactor_id);
                }
              }
            } else if (expr->getExprType() == ExprType::ShiftOp) {
              auto shift_op = expr->as<ShiftOp>();
              for (auto root_i :
                   c10::irange(producer->getMaybeRFactorDomain().size())) {
                auto rfactor_id = producer->getMaybeRFactorDomain()[root_i];
                // If the shift offset is 0, it doesn't actually shift
                if (shift_op->offsets()[root_i] != 0) {
                  shifted_rfactor_ids.insert(rfactor_id);
                }
              }
            }

            // Grab all values between shifted rfactor domains and p_id so we
            // can identify which rfactor domains are inputs to the p_id
            auto p_id_dep_vals =
                DependencyCheck::getAllValsBetween(shifted_rfactor_ids, {p_id});
            // If this shifted rfactor domain is an input to p_id, we
            // must have a WAR sync. Mark raw sync so it will be generated.
            if (!p_id_dep_vals.empty()) {
              raw_dims.set(producer_ptype);
            }
          }

          // When the producer axis is a broadcast, it is not really
          // parallelized unless thread-predicated and concretized
          if (isParallelTypeThread(producer_ptype) && p_id->isBroadcast() &&
              (!parallel_bcast_doms.get(producer_ptype) ||
               !GpuLower::current()
                    ->concretizedBroadcastDomains()
                    ->isConcretized(p_id))) {
            continue;
          }

          // If matching dims and matching parallel types, no comm is
          // necessary.
          if (producer_ptype == consumer_ptype &&
              GpuLower::current()->caMap()->areMapped(
                  p_id, c_id, IdMappingMode::PERMISSIVE)) {
            continue;
          }

          // Set parallel dimensions that communication is occurring over.
          if (isParallelTypeThread(producer_ptype)) {
            raw_dims.set(producer_ptype);
          }
        } // end for ptypes

        if (raw_dims.hasBID()) {
          TORCH_INTERNAL_ASSERT(
              producer->getMemoryType() == MemoryType::Global,
              "Inconsistent parallelization found between TV",
              producer->name(),
              " (",
              producer->toString(),
              ") and TV",
              consumer->name(),
              " (",
              consumer->toString(),
              "). Producer is required to be in Global Memory based on parallelization strategy.");
        } else if (raw_dims.hasTID()) {
          TORCH_INTERNAL_ASSERT(
              producer->getMemoryType() == MemoryType::Global ||
                  producer->getMemoryType() == MemoryType::Shared,
              "Inconsistent parallelization found between TV",
              producer->name(),
              " (",
              producer->toString(),
              ") and TV",
              consumer->name(),
              " (",
              consumer->toString(),
              "). Producer is required to be in Global or Shared Memory based on parallelization strategy.");
        }

      } // end for consumers

      if (raw_dims.any()) {
        needs_raw_sync_[producer] |= raw_dims;
      }

    } // end producer
  }
}

std::string SyncMap::toString() const {
  std::stringstream ss;
  ss << "SyncMap:";
  bool is_first = true;
  for (const auto& entry : needs_raw_sync_) {
    if (!is_first) {
      ss << ",";
    }
    ss << " " << entry.first->toString() << " -> " << entry.second.toString();
    is_first = false;
  }
  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch