1 | #include <lower_thread_predicate.h> |
2 | |
3 | #include <arith.h> |
4 | #include <instrumentation.h> |
5 | #include <ir_iostream.h> |
6 | #include <ir_utils.h> |
7 | #include <lower2device.h> |
8 | #include <lower_utils.h> |
9 | |
10 | #include <c10/util/irange.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | namespace { |
18 | |
19 | Bool* getPredicatePerParallelType( |
20 | ParallelType pt, |
21 | const ThreadPredicateMap::PredicateInfo& pred_info) { |
22 | auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); |
23 | |
24 | // If pt is not used or is proven to be one, no need to predicate. |
25 | if (pt_dim == nullptr || pt_dim->isOneInt()) { |
26 | return GpuLower::current()->kernel()->trueVal(); |
27 | } |
28 | // When BID needs to be predicated, that means it's an output of a grid |
29 | // reduction and only the last block index in that dimension has the right |
30 | // value from the grid reduce. |
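  //
  // For example, for ParallelType::BIDx this builds a predicate of the
  // form (roughly, in the generated CUDA):
  //
  //   blockIdx.x == gridDim.x - 1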
31 | if (isParallelTypeBlockDim(pt) && pred_info.limited_types.get(pt)) { |
32 | return SimplifyingIrBuilder::eqExpr( |
33 | NamedScalar::getParallelIndex(pt), |
34 | SimplifyingIrBuilder::subExpr( |
35 | NamedScalar::getParallelDim(pt), |
36 | GpuLower::current()->kernel()->oneVal())) |
37 | ->as<Bool>(); |
38 | } |
39 | |
40 | // Otherwise, only thread of index 0 executes the computation |
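  // (e.g., threadIdx.y == 0 for ParallelType::TIDy)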
41 | return SimplifyingIrBuilder::eqExpr( |
42 | NamedScalar::getParallelIndex(pt), |
43 | GpuLower::current()->kernel()->zeroVal()) |
44 | ->as<Bool>(); |
45 | } |
46 | |
47 | } // namespace |
48 | |
49 | Bool* ThreadPredicateMap::getPredicateFromPredicateInfo( |
50 | const ThreadPredicateMap::PredicateInfo& pred_info) { |
51 | const auto pred_types = pred_info.limited_types | pred_info.redundant_types; |
52 | |
53 | if (pred_types.none()) { |
54 | return GpuLower::current()->kernel()->trueVal(); |
55 | } |
56 | |
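  // The result is the conjunction of the per-parallel-type predicates. For
  // example, if pred_types is {TIDx, BIDy}, with BIDy in limited_types
  // (i.e., predicated as a grid reduction output), the generated predicate
  // is roughly:
  //
  //   threadIdx.x == 0 && blockIdx.y == gridDim.y - 1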
57 | Bool* pred = nullptr; |
58 | for (const auto pt : pred_types) { |
59 | const auto tp = getPredicatePerParallelType(pt, pred_info); |
60 | pred = SimplifyingIrBuilder::andExpr(pred, tp)->as<Bool>(); |
61 | } |
62 | TORCH_INTERNAL_ASSERT(pred != nullptr); |
63 | |
64 | return pred; |
65 | } |
66 | |
67 | namespace { |
68 | |
69 | // Build redundant predicate flags. Will be stored as |
70 | // PredicateInfo.redundant_types for the given tensor. |
71 | ParallelTypeBitmap avoidRedundantWrites(const TensorView* out_tv) { |
  // If the memory type is Local, it is always safe to write to it since it
  // is thread local. If it is Global, it is also fine to let each thread do
  // its own write, unless out_tv is an output of a reduction. Standard
  // reductions (setting aside gridReduce for the sake of this argument) that
  // write directly into global memory buffers accumulate into the buffer
  // rather than just setting it. If that accumulation is done redundantly,
  // it can lead to incorrect results. Correctness issues here can come from
  // smem aliasing, smem reductions or gmem reductions because the reduction
  // itself performs an update to a value, not just a set. For performance
  // it is safe to omit the redundant writes to gmem or smem; this comment is
  // just noting that avoiding them is not always just a performance
  // optimization, but can also be a correctness requirement.
84 | // |
85 | // For now this is enabled for shared memory buffers, global memory buffers |
86 | // undergoing a reduction, and global memory buffers with terminating outputs. |
87 | // This could be extended to all global memory buffer transactions, but in the |
88 | // test AdvancedIndexing11 there's a case where an intermediate global buffer |
89 | // is set and used to perform a broadcast. At the moment a grid sync is not |
90 | // being inserted here, and it's generally safe since it's just a set. We |
91 | // could enable this more generally for global memory buffers, but would have |
92 | // to insert a sync or a grid broadcast in that example. For now the |
93 | // approach is to only do this on a grid buffer (not undergoing a reduction) |
94 | // if there are no other uses in the kernel. |
95 | // |
  // TODO: Revisit whether something like AdvancedIndexing11 could be
  // happening at the same time as a global reduction in a way that could
  // produce an incorrect result.
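  //
  // As an illustration: if out_tv is a shared memory tensor parallelized
  // only on TIDx, while the kernel also uses TIDy (i.e., TIDy is not
  // trivially 1), every threadIdx.y would redundantly write the same
  // values. TIDy is therefore set in the returned bitmap so that only
  // threadIdx.y == 0 performs the write.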
99 | const bool is_reduction = ir_utils::isReductionOp(out_tv->definition()); |
100 | if (!(out_tv->getMemoryType() == MemoryType::Shared || |
101 | (out_tv->getMemoryType() == MemoryType::Global && is_reduction) || |
102 | (out_tv->getMemoryType() == MemoryType::Global && |
103 | out_tv->uses().empty()))) { |
104 | return ParallelTypeBitmap(); |
105 | } |
106 | |
107 | ParallelTypeBitmap pred; |
  // Track which TID/BID types are not used, in order to find redundant
  // parallel types. Only TID types are checked if the tensor is on shared
  // memory; if it is on global memory, all TID and BID types are checked.
  ParallelTypeBitmap unused_types;
  // Initially, conservatively assume all types are unused; the loop below
  // clears the ones that actually appear.
  unused_types = ~unused_types;
114 | for (auto out_tv_id : out_tv->domain()->domain()) { |
115 | auto pt = out_tv_id->getParallelType(); |
116 | if (!isParallelTypeThread(pt)) { |
117 | continue; |
118 | } |
119 | // If the axis is a broadcast domain and is parallelized by TID, |
120 | // it is sufficient to use just one thread since the tensor is on |
121 | // shared memory. |
122 | if ((out_tv->getMemoryType() == MemoryType::Shared && |
123 | out_tv_id->isBroadcast() && isParallelTypeThreadDim(pt)) || |
124 | // Protect against global memory and is_reduction as we don't want to |
        // predicate grid dimensions as codegen will complain that predication
        // on block dimensions is not allowed in grid reductions. The old
127 | // grid reduction runtime kernel does not differentiate |
128 | // non-reduction and predicated parallel types, so the sync |
129 | // integer buffer would need to be expanded even for |
130 | // predicated parallel types, which is not what |
131 | // getGridSyncBufferSize does. The right thing here is either: |
132 | // retire the old grid reduction kernel, or update the kernel |
        // to properly ignore predicated types. The new kernel is
134 | // significantly complex and has not been tested, so the |
135 | // latter option seems more reasonable for now. See #1671. |
136 | (!is_reduction && out_tv->getMemoryType() == MemoryType::Global && |
137 | out_tv_id->isBroadcast() && isParallelTypeThread(pt))) { |
138 | pred.set(pt); |
139 | } |
140 | unused_types.clear(pt); |
141 | } |
142 | |
143 | const auto& par_dim_map = GpuLower::current()->parallelDimensionMap(); |
144 | |
145 | for (const auto pt : unused_types) { |
146 | // For shared memory tensors, unused BID isn't redundant |
147 | if (isParallelTypeBlockDim(pt) && |
148 | out_tv->getMemoryType() == MemoryType::Shared) { |
149 | continue; |
150 | } |
151 | // If the pt is not used or is proven to be one, it is not |
152 | // really redundant. |
153 | auto pt_dim = par_dim_map.get(pt); |
154 | if (pt_dim == nullptr || pt_dim->isOneInt()) { |
155 | continue; |
156 | } |
157 | pred.set(pt); |
158 | } |
159 | |
160 | return pred; |
161 | } |
162 | |
163 | // If tv is an output of a reduction with unused parallel types, those |
164 | // unused parallel types need to be predicated if the tensor is on |
165 | // global memory. |
166 | ParallelTypeBitmap getReductionPredicateForUnusedParallelTypes( |
167 | const TensorView* tv, |
168 | const ThreadPredicateMap::PredicateInfo& pred_info) { |
169 | auto tv_def = tv->definition(); |
170 | if (!(tv_def && ir_utils::isReductionOp(tv_def) && |
171 | tv->getMemoryType() == MemoryType::Global)) { |
172 | return {}; |
173 | } |
174 | |
175 | // Unused types are set as redundant types of tv |
176 | return pred_info.redundant_types; |
177 | } |
178 | |
179 | } // namespace |
180 | |
// Update the thread predicate bitsets for the outputs of the provided Expr
182 | void ThreadPredicateMap::updateBitSet(const Expr* expr) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap::updateBitSet");
184 | |
  // If none of the inputs have been updated and all of the outputs already
  // have mappings, don't do anything
187 | if (std::all_of( |
188 | ir_utils::filterByType<TensorView>(expr->inputs()).begin(), |
189 | ir_utils::filterByType<TensorView>(expr->inputs()).end(), |
190 | [this](TensorView* tv) { |
191 | return updated_tvs_.find(tv) == updated_tvs_.end(); |
192 | }) && |
193 | std::all_of( |
194 | ir_utils::filterByType<TensorView>(expr->outputs()).begin(), |
195 | ir_utils::filterByType<TensorView>(expr->outputs()).end(), |
196 | [this](TensorView* tv) { return find(tv) != end(); })) { |
197 | return; |
198 | } |
199 | |
200 | // Which predicates were set for the inputs |
201 | ParallelTypeBitmap input_preds; |
202 | |
203 | // Which dims are reductions in inputs |
204 | ParallelTypeBitmap input_reductions; |
205 | |
206 | // Run through inputs and update bitsets |
207 | for (const auto* inp : expr->inputs()) { |
208 | if (!ir_utils::isTV(inp)) |
209 | continue; |
210 | |
211 | auto tv_inp = inp->as<TensorView>(); |
212 | |
213 | // If tv_inp was an output of a multi-output expression, just change it to a |
214 | // consistent sibling to use a single predicate name. |
215 | if (auto tv_def = tv_inp->definition()) { |
216 | if (tv_def->outputs().size() > 1) { |
217 | tv_inp = ir_utils::getTvOutput(tv_def); |
218 | } |
219 | } |
220 | |
221 | TORCH_INTERNAL_ASSERT( |
222 | thread_predicates_.find(tv_inp) != thread_predicates_.end(), |
        "Thread predicate map was not initialized, couldn't find ",
224 | inp); |
225 | |
226 | const auto& pred_info = at(tv_inp); |
227 | |
228 | ParallelTypeBitmap id_reductions; |
229 | ParallelTypeBitmap id_bcasts; |
230 | ParallelTypeBitmap id_ptypes; |
231 | |
232 | for (auto id : tv_inp->domain()->domain()) { |
233 | if (id->isThread()) { |
234 | id_ptypes.set(id->getParallelType()); |
235 | if (id->isReduction() && |
236 | !GpuLower::current()->fusedReductionInfo().isAllreduce(id)) { |
237 | id_reductions.set(id->getParallelType()); |
238 | } |
239 | if (id->isBroadcast() && |
240 | GpuLower::current()->concretizedBroadcastDomains()->isConcretized( |
241 | id)) { |
242 | id_bcasts.set(id->getParallelType()); |
243 | } |
244 | } |
245 | } |
246 | |
247 | // Validate the combination of ptypes, reductions, bcasts |
248 | for (const auto i : c10::irange(ParallelTypeBitmap::kNumParallelTypes)) { |
249 | if (input_reductions[i]) { |
250 | if (id_ptypes[i]) { |
251 | TORCH_INTERNAL_ASSERT( |
252 | id_reductions[i], |
              "Mismatched parallelized reductions found on inputs of expr: ",
254 | expr); |
255 | TORCH_CHECK( |
256 | !id_bcasts[i], |
              "Invalid broadcast and reduction combination, tried to parallelize both with the same thread dim: ",
258 | inp); |
259 | } |
260 | } |
261 | } |
262 | |
263 | // Figure out which dims bcast wants to reset |
264 | auto this_input_preds = pred_info.limited_types; |
265 | const auto bcast_reset_mask = ~(this_input_preds & id_bcasts); |
266 | this_input_preds &= bcast_reset_mask; |
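    // For example, suppose T1 is produced by a TIDx reduction, so its
    // limited_types include TIDx, and T2 = broadcast(T1) re-broadcasts it
    // along a TIDx-parallelized (and concretized) broadcast domain. The
    // parallel broadcast makes the value valid on all threads, so when T2
    // is used as an input here its TIDx predicate is dropped.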
267 | |
268 | input_preds |= this_input_preds; |
269 | |
270 | id_reductions |= |
271 | getReductionPredicateForUnusedParallelTypes(tv_inp, at(tv_inp)); |
272 | |
273 | // Accumulate |
274 | input_reductions |= id_reductions; |
275 | } |
276 | |
  // Compute the predicates for the outputs: any predicates set on the
  // inputs, plus any parallel reduction types found on the inputs
279 | auto output_preds = input_preds | input_reductions; |
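  // For example, if an input is already limited to threadIdx.x == 0 and
  // this expression reduces over a TIDy-parallelized domain, the outputs
  // become predicated on both TIDx and TIDy.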
280 | |
281 | // Run through outputs and set bitset predicates |
282 | for (auto* out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) { |
283 | auto redundant_types = avoidRedundantWrites(out_tv); |
284 | update(out_tv, output_preds, redundant_types); |
285 | } |
286 | } |
287 | |
288 | namespace { |
289 | |
290 | //! A simple backward data flow pass: |
//! This pass propagates information backward to annotate "redundant use
//! chains".
//! The reason this is needed is that, for example, if we have a chain
//! of register-to-register ops that begins with a redundant shared mem write
//! and ends with an op that non-redundantly uses the result, we'd need to
//! insert a sync at the beginning of the register-to-register chain.
297 | //! |
298 | //! The same mechanism also applies in the case of a register/sharedmem chain |
299 | //! that starts and ends with global memory read/write. |
300 | //! |
301 | //! The propagation rule is summarized as follows: |
302 | //! |
303 | //! Shared TV val: |
304 | //! Reset all block redundant info to its own redundant write info |
305 | //! Backpropagate grid redundant info |
306 | //! Global TV val: |
307 | //! Reset all redundant info to its own redundant write info |
//!  Local TV val:
309 | //! Backpropagate all redundant info |
310 | //! Exprs: |
311 | //! Propagate redundant info backwards from outputs to inputs: |
312 | //! For each parallel type, |
313 | //! The parallel type is redundantly used in the expr input |
314 | //! only if all of the outputs redundantly use the same type. |
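//!
//! For example, for the "Exprs" rule: if an expression has two outputs and
//! only one of them redundantly uses TIDx, the expression (and hence its
//! inputs) is not considered redundant on TIDx, since the other output
//! still needs valid values from all threads in that dimension.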
315 | class RedundantUseAnalysis : BackwardVisitor { |
316 | public: |
317 | RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map) |
318 | : fusion_(fusion), pred_map_(pred_map) { |
319 | traverseTo(fusion, fusion->terminatingMathVals()); |
320 | } |
321 | |
322 | //! Returns a bit map signifying the parallel dimensions |
323 | //! on which the given tv is redundantly used. On these |
324 | //! dimensions not all threads/blocks are required to |
  //! hold a valid value for their dependent computations.
326 | ParallelTypeBitmap getRedundantUseBitMap(const TensorView* tv) { |
327 | // Since all tv's consumers are visited at this point, we |
328 | // can aggregate the final redundant use info for this tv. |
329 | if (fusion_->unordered_uses(tv).empty()) { |
      // Base case: an unused tv is also not redundantly used
331 | return ParallelTypeBitmap(); |
332 | } else { |
      // Aggregate redundant use as a conjunction of all consumers'
      // redundant-use info propagated backward from their consumer
      // chains.
336 | ParallelTypeBitmap redundant_use; |
337 | redundant_use.setAllBID(); |
338 | redundant_use.setAllTID(); |
339 | for (auto expr : fusion_->unordered_uses(tv)) { |
340 | redundant_use &= redundant_expr_use_map_.at(expr); |
341 | } |
342 | |
343 | return redundant_use; |
344 | } |
345 | } |
346 | |
347 | private: |
348 | using BackwardVisitor::handle; |
349 | |
350 | void handle(TensorView* tv) final { |
351 | auto redundant_tv_map = pred_map_.getPredicateInfo(tv).redundant_types; |
352 | |
    // Set up the info to propagate backward for the producer tv's and
354 | // expressions. |
355 | ParallelTypeBitmap& redundant_consumer_map = |
356 | redundant_consumer_parallel_type_map_[tv]; |
357 | |
358 | // Initialize the use map to the redundant pred result |
359 | redundant_consumer_map = redundant_tv_map; |
360 | |
361 | if (tv->getMemoryType() == MemoryType::Shared) { |
362 | backPropagateRedundantUse( |
363 | redundant_consumer_map, |
364 | tv, |
          false, // do not propagate TID redundant use for shared tv
366 | true // propagate BID redundant use |
367 | ); |
368 | |
369 | } else if (tv->getMemoryType() == MemoryType::Local) { |
370 | backPropagateRedundantUse( |
371 | redundant_consumer_map, |
372 | tv, |
373 | true, // propagate TID redundant use |
374 | true // propagate BID redundant use |
375 | ); |
376 | } |
377 | } |
378 | |
379 | void backPropagateRedundantUse( |
380 | ParallelTypeBitmap& use_map, |
381 | TensorView* tv, |
382 | bool propagate_tid, |
383 | bool propagate_bid) { |
    // Reset the propagated part of the original result to all-set, which is
    // the identity for the conjunction accumulated below
385 | if (propagate_bid) { |
386 | use_map.setAllBID(); |
387 | } |
388 | if (propagate_tid) { |
389 | use_map.setAllTID(); |
390 | } |
391 | |
392 | for (auto expr : fusion_->unordered_uses(tv)) { |
393 | // Assuming all consumer expressions have been |
394 | // visited at this point since we are traversing |
395 | // backward. |
396 | auto expr_use_map = redundant_expr_use_map_.at(expr); |
      // Mask out (set to all-ones) the part of the expression use map that
      // should not be propagated, so the conjunction below leaves it as is.
399 | if (!propagate_bid) { |
400 | expr_use_map.setAllBID(); |
401 | } |
402 | if (!propagate_tid) { |
403 | expr_use_map.setAllTID(); |
404 | } |
405 | |
406 | // Accumulate expression redundant usage |
407 | // This implements the `only if all` part in |
408 | // the discussion above. |
409 | use_map &= expr_use_map; |
410 | } |
411 | } |
412 | |
413 | void handle(Expr* expr) final { |
414 | if (ir_utils::isTvOp(expr)) { |
415 | // Initialize redundant info for current expr |
416 | c10::optional<ParallelTypeBitmap> maybe_expr_pred_map; |
417 | |
418 | for (auto consumer_tv : |
419 | ir_utils::filterByType<TensorView>(expr->outputs())) { |
420 | auto tv_redundant_bitmap = |
421 | redundant_consumer_parallel_type_map_.at(consumer_tv); |
422 | |
423 | if (maybe_expr_pred_map.has_value()) { |
424 | // Accumulate redundant info of this tv output. |
425 | maybe_expr_pred_map.value() &= tv_redundant_bitmap; |
426 | } else { |
427 | // Copy the tv's redundant info as the first valid case. |
428 | maybe_expr_pred_map = tv_redundant_bitmap; |
429 | } |
430 | } |
431 | |
432 | TORCH_INTERNAL_ASSERT( |
          maybe_expr_pred_map.has_value(), "TV op not having a tv output");
434 | redundant_expr_use_map_[expr] = maybe_expr_pred_map.value(); |
435 | } |
436 | } |
437 | |
438 | private: |
  // Populated redundant use information for the used tv's.
  // This map records whether the given tv does not require valid data from
  // its producer on some parallel dimensions.
  // For example:
  //   T1_local = T0_shared[...]
  //   if(tid.x == 0)
  //     T2_shared[...] = T1_local[...]
  // Then tidx would be a redundant consumer parallel type for T1: since T1
  // is a local tensor, only the threads satisfying tidx == 0 need to hold
  // valid data.
  // In this case, not all threads need to read correct data from T0_shared,
  // which helps remove some syncs.
451 | std::unordered_map<const TensorView*, ParallelTypeBitmap> |
452 | redundant_consumer_parallel_type_map_; |
453 | |
  // Populated redundant use information for the used tv expressions.
455 | std::unordered_map<const Expr*, ParallelTypeBitmap> redundant_expr_use_map_; |
456 | |
457 | // Short cut to the owning fusion of this analysis. |
458 | Fusion* fusion_ = nullptr; |
459 | |
460 | // Short cut to the active pred map analysis this pass is running as part of. |
461 | const ThreadPredicateMap& pred_map_; |
462 | }; |
463 | |
464 | } // namespace |
465 | |
466 | void ThreadPredicateMap::build(Fusion* fusion) { |
  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");
468 | |
469 | // Initialize mapping for input tensors |
470 | for (auto inp : fusion->inputs()) { |
471 | if (auto tv = dynamic_cast<const TensorView*>(inp)) { |
472 | update(tv, ParallelTypeBitmap(), ParallelTypeBitmap()); |
473 | } |
474 | } |
475 | for (auto expr : fusion->exprs()) { |
476 | updateBitSet(expr); |
477 | } |
478 | updated_tvs_.clear(); |
479 | populateRedundantUseMap(fusion); |
480 | } |
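
// A minimal usage sketch (illustrative only; the actual call sites live in
// the lowering passes, e.g. lower2device.cpp):
//
//   ThreadPredicateMap pred_map;
//   pred_map.build(fusion);
//   // Later, when generating the predicate for a tensor:
//   Bool* pred = pred_map.getPredicate(tv);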
481 | |
482 | void ThreadPredicateMap::populateRedundantUseMap(Fusion* fusion) { |
483 | RedundantUseAnalysis redundant_use(fusion, *this); |
484 | for (auto& it : thread_predicates_) { |
485 | it.second.redundant_use_types = |
486 | redundant_use.getRedundantUseBitMap(it.first); |
487 | } |
488 | } |
489 | |
490 | ThreadPredicateMap::const_iterator ThreadPredicateMap::find( |
491 | const TensorView* tv) const { |
492 | return thread_predicates_.find(tv); |
493 | } |
494 | |
495 | ThreadPredicateMap::const_iterator ThreadPredicateMap::end() const { |
496 | return thread_predicates_.end(); |
497 | } |
498 | |
499 | const ThreadPredicateMap::PredicateInfo& ThreadPredicateMap::at( |
500 | const TensorView* tv) const { |
501 | return thread_predicates_.at(tv); |
502 | } |
503 | |
504 | ThreadPredicateMap::PredicateInfo& ThreadPredicateMap::at( |
505 | const TensorView* tv) { |
506 | return thread_predicates_.at(tv); |
507 | } |
508 | |
509 | ThreadPredicateMap::PredicateInfo ThreadPredicateMap::getPredicateInfo( |
510 | const TensorView* tv) const { |
511 | auto pred_info = thread_predicates_.at(tv); |
  // Do not predicate a parallel type if it is a parallel bcast domain
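  // (e.g., if tv is defined by a BroadcastOp whose broadcast domain is bound
  // to TIDx and is concretized, the TIDx bit is dropped from limited_types:
  // the parallel broadcast itself makes the value valid for all threadIdx.x,
  // so no TIDx predicate is needed)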
513 | if (dynamic_cast<BroadcastOp*>(tv->definition())) { |
514 | auto parallel_bcast = getParallelBroadcastDomains(tv); |
515 | pred_info.limited_types ^= parallel_bcast; |
516 | } |
517 | return pred_info; |
518 | } |
519 | |
520 | ParallelTypeBitmap ThreadPredicateMap::getPredicatedParallelTypes( |
521 | const TensorView* tv) const { |
522 | auto pred_info = getPredicateInfo(tv); |
523 | return pred_info.limited_types | pred_info.redundant_types; |
524 | } |
525 | |
526 | bool ThreadPredicateMap::update( |
527 | const TensorView* tv, |
528 | const ParallelTypeBitmap& limited_types, |
529 | const ParallelTypeBitmap& redundant_types) { |
530 | return update(tv, {limited_types, redundant_types}); |
531 | } |
532 | |
533 | bool ThreadPredicateMap::update( |
534 | const TensorView* tv, |
535 | const PredicateInfo& pred_info) { |
536 | auto existing_mapping_it = thread_predicates_.find(tv); |
537 | if (existing_mapping_it != end()) { |
538 | PredicateInfo& existing_info = existing_mapping_it->second; |
539 | if (existing_info == pred_info) { |
540 | return false; |
541 | } else { |
542 | existing_info = pred_info; |
543 | markAsUpdated(tv); |
544 | return true; |
545 | } |
546 | } else { |
547 | thread_predicates_.insert({tv, pred_info}); |
548 | markAsUpdated(tv); |
549 | return true; |
550 | } |
551 | } |
552 | |
553 | Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const { |
  TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv);
555 | auto pred_info = getPredicateInfo(tv); |
556 | return getPredicateFromPredicateInfo(pred_info); |
557 | } |
558 | |
559 | ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains( |
560 | const TensorView* tv) const { |
561 | // If no pred is found for tv, no predicate is necessary |
562 | if (find(tv) == end()) { |
563 | return ParallelTypeBitmap(); |
564 | } |
565 | |
566 | ParallelTypeBitmap parallel_broadcast; |
567 | |
568 | const auto& iter_domains = tv->domain()->domain(); |
569 | |
570 | // If the output is on shared memory, assume that all subsequent |
571 | // reads from all threads in its CTA can be done with no parallel |
572 | // broadcast. Only one thread will write to shared memory followed |
  // by a proper __syncthreads().
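  // For example, a concretized broadcast domain bound to TIDy on a shared
  // memory output does not need a parallel (block) broadcast: the value
  // written by one thread is visible to the whole CTA after the sync, so
  // TIDy is not set in the returned bitmap in that case.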
574 | const bool output_smem = tv->getMemoryType() == MemoryType::Shared; |
575 | |
576 | for (auto id : iter_domains) { |
577 | if (!id->isBroadcast() || |
578 | !GpuLower::current()->concretizedBroadcastDomains()->isConcretized( |
579 | id)) { |
580 | continue; |
581 | } |
582 | if (id->isBlockDim() || (!output_smem && id->isThreadDim())) { |
583 | parallel_broadcast.set(id->getParallelType()); |
584 | } |
585 | } |
586 | |
587 | return parallel_broadcast & at(tv).limited_types; |
588 | } |
589 | |
590 | ParallelTypeBitmap ThreadPredicateMap::getRedundantConsumerType( |
591 | Expr* expr) const { |
592 | c10::optional<ParallelTypeBitmap> result; |
593 | for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) { |
594 | auto out_tv_redundant_map = getPredicateInfo(out_tv).redundant_use_types; |
595 | if (!result.has_value()) { |
596 | result = out_tv_redundant_map; |
597 | } else { |
598 | result.value() &= out_tv_redundant_map; |
599 | } |
600 | } |
601 | |
602 | TORCH_INTERNAL_ASSERT( |
      result.has_value(), "ThreadPredicateMap: TV op assumed");
604 | return result.value(); |
605 | } |
606 | |
607 | void ThreadPredicateMap::markAsUpdated(const TensorView* tv) { |
608 | updated_tvs_.insert(tv); |
609 | } |
610 | |
611 | void ThreadPredicateMap::print() const { |
  std::cout << "\nThreadPredicateMap\n";
  std::cout << "--------------------------------\n";
  for (const auto& kv : thread_predicates_) {
    std::cout << "T" << kv.first->name();
    std::cout << " {" << kv.second.limited_types.toString() << "}\n";
    std::cout << "{" << kv.second.redundant_types.toString() << "}\n";
    std::cout << "{" << kv.second.redundant_use_types.toString() << "}\n";
  }
  std::cout << "--------------------------------\n\n";
621 | } |
622 | |
623 | } // namespace cuda |
624 | } // namespace fuser |
625 | } // namespace jit |
626 | } // namespace torch |
627 | |