#include <lower_utils.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <arith.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_thread_predicate.h>
#include <root_domain_map.h>

#include <algorithm>

// TODO: refactor this file (one per namespace)

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace scope_utils {

//! Create an **empty** ForLoop and copy the metadata.
kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop) {
  return IrBuilder::create<kir::ForLoop>(for_loop);
}

//! Create an **empty** IfThenElse and copy the metadata.
kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite) {
  return IrBuilder::create<kir::IfThenElse>(ite->predicate());
}

} // namespace scope_utils

namespace ir_utils {

TVDomainGuard::TVDomainGuard(TensorView* tv, TensorDomain* td)
    : tv_(tv), prev_domain_(tv_->domain()) {
  tv_->setDomain(td);
}

TVDomainGuard::TVDomainGuard(TVDomainGuard&& guard)
    : tv_(nullptr), prev_domain_(guard.prev_domain_) {
  std::swap(tv_, guard.tv_);
}

TVDomainGuard::~TVDomainGuard() {
  if (tv_ != nullptr) {
    tv_->setDomain(prev_domain_);
  }
}

ir_utils::TVDomainGuard overrideContiguityGuard(
    TensorView* tv,
    bool contiguity) {
  // Use a domain guard to temporarily override the contiguity of tv.
  TensorDomain* domain_with_specified_contiguity = nullptr;
  std::vector<bool> contiguity_vector(
      tv->getMaybeRFactorDomain().size(), contiguity);
  if (tv->hasRFactor()) {
    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
        tv->getRootDomain(),
        tv->getRFactorDomain(),
        tv->domain()->domain(),
        contiguity_vector);
  } else {
    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
  }

  return ir_utils::TVDomainGuard(tv, domain_with_specified_contiguity);
}

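//! Returns the IterDomains that the IDs in input_ids are derived from.
//! The vals in all_inputs are treated as the inputs of the traversal.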
std::vector<IterDomain*> iterDomainInputsOf(
    const std::vector<IterDomain*>& input_ids,
    const std::vector<IterDomain*>& all_inputs) {
  auto inputs = IterVisitor::getInputsTo(
      {input_ids.begin(), input_ids.end()},
      {all_inputs.begin(), all_inputs.end()});
  std::vector<IterDomain*> id_inputs(
      ir_utils::filterByType<IterDomain>(inputs).begin(),
      ir_utils::filterByType<IterDomain>(inputs).end());
  return id_inputs;
}

std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
    const std::vector<IterDomain*>& of,
    const std::vector<IterDomain*>& order) {
  auto inputs_vec = iterDomainInputsOf(of, order);

  std::unordered_set<IterDomain*> inputs_set(
      inputs_vec.begin(), inputs_vec.end());

  std::vector<IterDomain*> ordered_inputs;
  std::copy_if(
      order.begin(),
      order.end(),
      std::back_inserter(ordered_inputs),
      [&inputs_set](const auto& id) {
        return inputs_set.find(id) != inputs_set.end();
      });

  return ordered_inputs;
}

bool isTV(const Val* val) {
  return val->getValType().value() == ValType::TensorView ||
      val->getValType().value() == ValType::TensorIndex;
}

// Check if we're a TensorView op that we can generate code for.
bool isTvOp(const Expr* expr) {
  if (std::any_of(
          expr->outputs().begin(),
          expr->outputs().end(),
          [](Val* v) { return isTV(v); }) &&
      (expr->getExprType().value() == ExprType::UnaryOp ||
       expr->getExprType().value() == ExprType::BinaryOp ||
       expr->getExprType().value() == ExprType::TernaryOp ||
       expr->getExprType().value() == ExprType::RNGOp ||
       expr->getExprType().value() == ExprType::FullOp ||
       expr->getExprType().value() == ExprType::ARangeOp ||
       expr->getExprType().value() == ExprType::EyeOp ||
       expr->getExprType().value() == ExprType::ReductionOp ||
       expr->getExprType().value() == ExprType::GroupedReductionOp ||
       expr->getExprType().value() == ExprType::WelfordOp ||
       expr->getExprType().value() == ExprType::GroupedWelfordOp ||
       expr->getExprType().value() == ExprType::LoadStoreOp ||
       expr->getExprType().value() == ExprType::MmaOp ||
       expr->getExprType().value() == ExprType::BroadcastOp ||
       expr->getExprType().value() == ExprType::TransposeOp ||
       expr->getExprType().value() == ExprType::ExpandOp ||
       expr->getExprType().value() == ExprType::ShiftOp ||
       expr->getExprType().value() == ExprType::GatherOp ||
       expr->getExprType().value() == ExprType::ViewAsScalar ||
       expr->getExprType().value() == ExprType::ViewOp ||
       expr->getExprType().value() == ExprType::GridReduction ||
       expr->getExprType().value() == ExprType::GroupedGridReduction ||
       expr->getExprType().value() == ExprType::GridBroadcast ||
       expr->getExprType().value() == ExprType::GridWelford ||
       expr->getExprType().value() == ExprType::GroupedGridWelford)) {
    return true;
  }
  return false;
}

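//! Returns true if expr is a LoadStoreOp that lowers to an ldmatrix
//! instruction (transposed or not).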
bool isLdMatrixOp(const Expr* expr) {
  if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
    return ldst->opType() == LoadStoreOpType::LdMatrix ||
        ldst->opType() == LoadStoreOpType::LdMatrixTranspose;
  }
  return false;
}

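//! Returns true if expr is a LoadStoreOp that lowers to a cp.async copy.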
bool isCpAsyncOp(const Expr* expr) {
  if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
    return ldst->opType() == LoadStoreOpType::CpAsync;
  }
  return false;
}

bool isTensorScalarFillOp(const Expr* expr) {
  // Check that the input is a single scalar.
  if (expr->inputs().size() == 1 && expr->input(0)->isScalar()) {
    // Any load store op with a single scalar input is a scalar
    // filling op. Semantically it simply stores a scalar into a
    // tensor.
    if (expr->isA<LoadStoreOp>()) {
      return true;
    }
    // A unary copy op is also a scalar filling op.
    if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
      return uop->getUnaryOpType() == UnaryOpType::Set;
    }
  }
  // Ideally, any scalar expression that outputs to a tensor should be
  // considered by this function, but the scope is currently limited to
  // initialization patterns, so other scalar exprs are excluded here to
  // avoid confusion.
  return false;
}

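//! Returns the TensorView that val refers to: val itself if it is a
//! TensorView, the viewed TensorView if it is a kir::TensorIndex, and
//! nullptr otherwise.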
TensorView* getTv(Val* val) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  return const_cast<TensorView*>(getTv(const_cast<const Val*>(val)));
}

const TensorView* getTv(const Val* val) {
  if (val->isA<TensorView>()) {
    return val->as<TensorView>();
  } else if (val->isA<kir::TensorIndex>()) {
    return val->as<kir::TensorIndex>()->view();
  }
  return nullptr;
}

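//! Filters vals down to the TensorViews they refer to (see getTv).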
std::vector<TensorView*> getTvs(const std::vector<Val*>& vals) {
  std::vector<TensorView*> tvs;
  for (auto val : vals) {
    auto tv = ir_utils::getTv(val);
    if (tv) {
      tvs.emplace_back(tv);
    }
  }
  return tvs;
}

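//! Returns the first output of expr that refers to a TensorView, or
//! nullptr if there is none. getTvInput below does the same for inputs.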
TensorView* getTvOutput(const Expr* expr) {
  for (auto out : expr->outputs()) {
    if (auto tv = getTv(out)) {
      return tv;
    }
  }
  return nullptr;
}

TensorView* getTvInput(const Expr* expr) {
  for (auto inp : expr->inputs()) {
    if (auto tv = getTv(inp)) {
      return tv;
    }
  }
  return nullptr;
}

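//! Returns true if all outputs of expr are scalars.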
bool isScalarOp(const Expr* expr) {
  for (auto out : expr->outputs())
    if (!out->isScalar())
      return false;
  return true;
}

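//! If the reduction producing output from input can be lowered to a warp
//! reduction, returns its TIDx-parallelized reduction IterDomain;
//! otherwise returns nullopt. Both tensors must live in registers, and
//! the TIDx extent must either be padded to a multiple of the warp size
//! or be a compile-time constant divisible by it.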
c10::optional<IterDomain*> getMaybeWarpReductionDim(
    const Val* output,
    const Val* input) {
  auto tv_out = getTv(output);
  if (tv_out == nullptr) {
    return c10::nullopt;
  }

  auto tv_in = getTv(input);
  // Only reductions into registers are supported for now.
  if (tv_in->getMemoryType() != MemoryType::Local ||
      tv_out->getMemoryType() != MemoryType::Local) {
    return c10::nullopt;
  }

  IterDomain* reduction_on_xdim = nullptr;
  for (auto id : tv_out->domain()->domain()) {
    // Currently warp reductions only allow serial and
    // threadIdx.x-parallelized reduction axes.
    if (id->isReduction() && id->isParallelized()) {
      if (id->getParallelType() == ParallelType::TIDx) {
        reduction_on_xdim = id;
      } else if (id->isThread()) {
        return c10::nullopt;
      }
    }
  }
  if (!reduction_on_xdim) {
    return c10::nullopt;
  }

  if (!reduction_on_xdim->start()->isZeroInt()) {
    return c10::nullopt;
  }

  if (reduction_on_xdim->hasPaddingToMultipleOfWarp()) {
    return c10::optional<IterDomain*>(reduction_on_xdim);
  }

  if (reduction_on_xdim->extent()->isConstInt()) {
    auto extent_value = reduction_on_xdim->extent()->evaluateInt();
    if (extent_value % at::cuda::warp_size() == 0) {
      return c10::optional<IterDomain*>(reduction_on_xdim);
    }
  }

  return c10::nullopt;
}

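//! Returns true if axis is derived from any of the root IterDomains that
//! the computeAt axes of tv are derived from.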
bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis) {
  std::vector<IterDomain*> ca_axes(
      tv->domain()->domain().begin(),
      tv->domain()->domain().begin() + tv->getComputeAtPosition());

  auto ca_root_vals = IterVisitor::getInputsTo(
      std::vector<Val*>(ca_axes.begin(), ca_axes.end()));

  auto root_vals = IterVisitor::getInputsTo({axis});

  return std::any_of(
      root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
        return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
            ca_root_vals.end();
      });
}

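//! Maps each thread/block parallel type used in the domain of val (a
//! TensorView or TensorIndex) to the IterDomain parallelized that way.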
std::unordered_map<ParallelType, IterDomain*, TypeHash> getParallelDomains(
    const Val* val) {
  const TensorView* tv = nullptr;
  if (val->isA<TensorView>()) {
    tv = val->as<TensorView>();
  } else if (val->isA<kir::TensorIndex>()) {
    tv = val->as<kir::TensorIndex>()->view();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Provided val is not TensorIndex or TensorView.");
  }

  std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
  for (auto d : tv->domain()->domain()) {
    if (d->isThread()) {
      parallel_domains.insert(std::make_pair(d->getParallelType(), d));
    }
  }
  return parallel_domains;
}

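//! Returns true if expr is a scalar fill that initializes the output
//! buffer of a cp.async copy.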
bool isCpAsyncInit(const Expr* expr) {
  return isTensorScalarFillOp(expr) &&
      // FIXME:
      //  We'd need to add a flag to all the init
      //  exprs so we could robustly detect initialization
      //  in all cases.
      isCpAsyncOp(getTvOutput(expr)->definition());
}

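//! If expr is an IfThenElse with an empty else branch and exactly one
//! expression in its then branch, returns that expression; otherwise
//! returns nullopt.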
c10::optional<Expr*> getMaybePredicatedSingleton(Expr* expr) {
  if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
    if (ite->elseBody().empty()) {
      if (ite->thenBody().size() == 1) {
        return ite->thenBody().exprs()[0];
      }
    }
  }
  return c10::nullopt;
}

//! Short-cut for checking if the expression loads from global memory.
bool isGlobalLoad(const Expr* expr) {
  if (expr->isA<LoadStoreOp>() ||
      (expr->isA<UnaryOp>() &&
       expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set)) {
    if (auto in_tv = getTv(expr->input(0))) {
      return in_tv->getMemoryType() == MemoryType::Global;
    }
  }
  return false;
}

//! Short-cut for checking if the given expression initializes buffers
//! for global memory load.
bool isGlobalLoadInit(const Expr* expr) {
  if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
    if (uop->in()->isScalar()) {
      // FIXME:
      //  We'd need to add a flag to all the init
      //  exprs so we could robustly detect initialization
      //  in all cases.
      if (isGlobalLoad(getTvOutput(uop)->definition())) {
        return true;
      }
    }
  }
  return false;
}

namespace {

class ExprFlattener : private kir::IrVisitor {
 private:
  using kir::IrVisitor::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
    } else {
      flat_exprs_.push_back(expr);
    }
  }

 private:
  std::vector<Expr*> flat_exprs_;

 public:
  //! Flattens scopes extracting out a single ordered list of exprs.
  static std::vector<Expr*> flatten(const std::vector<Expr*>& loop_nests) {
    ExprFlattener flattener;
    for (auto expr : loop_nests) {
      flattener.handle(expr);
    }
    return flattener.flat_exprs_;
  }
};

} // namespace

std::vector<Expr*> flattenScopedExprs(const std::vector<Expr*>& loop_nests) {
  return ExprFlattener::flatten(loop_nests);
}

namespace {

class ReplaceExprInput : private kir::ExprMutator {
 public:
  static std::vector<Expr*> replace(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<Val*, Val*>& replacement_map) {
    ReplaceExprInput replacer(replacement_map);
    replacer.traverseAndInsert(exprs);
    return replacer.exprs_;
  }

 private:
  ReplaceExprInput(const std::unordered_map<Val*, Val*>& replacement_map)
      : replacement_map_(replacement_map) {}

  using kir::ExprMutator::handle;

  c10::optional<std::unordered_map<Val*, Val*>> getMaybeInputReplacementMap(
      Expr* expr) {
    bool need_replacement = false;

    std::unordered_map<Val*, Val*> replaced_val;
    for (auto in : expr->inputs()) {
      auto replace_it = replacement_map_.find(in);
      if (replace_it != replacement_map_.end()) {
        need_replacement = true;
        replaced_val[in] = replace_it->second;
      } else {
        replaced_val[in] = in;
      }
    }
    if (need_replacement) {
      return c10::optional<std::unordered_map<Val*, Val*>>(replaced_val);
    } else {
      return c10::nullopt;
    }
  }

  // Copy predicates and register expression replacement
  void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) {
    new_expr = new_expr->withPredicate(old_expr->predicate())
                   ->withWritePredicate(old_expr->writePredicate());
    registerReplace(old_expr, new_expr);
  }

  void handle(UnaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<UnaryOp>(
          node->getUnaryOpType(), node->out(), replaced_inputs->at(node->in()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(BinaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<BinaryOp>(
          node->getBinaryOpType(),
          node->out(),
          replaced_inputs->at(node->lhs()),
          replaced_inputs->at(node->rhs()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(TernaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<TernaryOp>(
          node->getTernaryOpType(),
          node->out(),
          replaced_inputs->at(node->in1()),
          replaced_inputs->at(node->in2()),
          replaced_inputs->at(node->in3()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(RNGOp* node) final {
    // RNGOp has no input
    return;
  }

  void handle(ReductionOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<ReductionOp>(
          node->getReductionOpType(),
          node->init(),
          node->out(),
          replaced_inputs->at(node->in()),
          node->isAllreduce());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(GroupedReductionOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      const auto& map = replaced_inputs.value();
      auto inputs = node->inputs();
      for (auto& input : inputs) {
        auto it = map.find(input);
        if (it != map.end()) {
          input = it->second;
        }
      }
      auto replacement = IrBuilder::create<GroupedReductionOp>(
          node->getReductionOpTypes(),
          node->initVals(),
          node->outputs(),
          inputs,
          node->isAllreduce());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(BroadcastOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<BroadcastOp>(
          node->out(),
          replaced_inputs->at(node->in()),
          node->getBroadcastDimFlags());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(WelfordOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<WelfordOp>(
          node->outAvg(),
          node->outVar(),
          node->outN(),
          node->initAvg(),
          node->initVar(),
          node->initN(),
          replaced_inputs->at(node->inAvg()),
          replaced_inputs->at(node->inVar()),
          replaced_inputs->at(node->inN()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(MmaOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<MmaOp>(
          node->out(),
          replaced_inputs->at(node->inA()),
          replaced_inputs->at(node->inB()),
          node->init(),
          node->options());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(LoadStoreOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      // Build the replacement op with the replaced input, not the
      // original one.
      auto replacement = IrBuilder::create<LoadStoreOp>(
          node->opType(), node->out(), replaced_inputs->at(node->in()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

 private:
  const std::unordered_map<Val*, Val*>& replacement_map_;
};

} // namespace

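//! Returns a copy of exprs where every input val found in replacement_map
//! is replaced by its mapped value. Predicates of the original
//! expressions are carried over to the replacements.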
std::vector<Expr*> replaceInputsInExpr(
    const std::vector<Expr*>& exprs,
    const std::unordered_map<Val*, Val*>& replacement_map) {
  return ReplaceExprInput::replace(exprs, replacement_map);
}

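//! Returns all Swizzle2D expressions found on the paths from the given
//! from IterDomains to the to IterDomains.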
std::vector<Expr*> getAllSwizzlesBetween(
    std::vector<IterDomain*> from,
    std::vector<IterDomain*> to) {
  auto all_expr = DependencyCheck::getAllExprsBetween(
      {from.begin(), from.end()}, {to.begin(), to.end()});

  std::vector<Expr*> all_swizzles;

  std::copy_if(
      all_expr.begin(),
      all_expr.end(),
      std::back_inserter(all_swizzles),
      [](Expr* expr) {
        return expr->getExprType().has_value() &&
            expr->getExprType().value() == ExprType::Swizzle2D;
      });

  return all_swizzles;
}

} // namespace ir_utils

namespace lower_utils {

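//! Returns true if lowering expr requires a block synchronization:
//! either it already is a kir::BlockSync, or it is a tensor op that
//! performs a block/grid reduction or a parallel broadcast.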
bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
  if (expr->isA<kir::BlockSync>()) {
    return true;
  }

  if (!ir_utils::isTvOp(expr)) {
    return false;
  }

  if (!(ir_utils::isReductionOp(expr) || expr->isA<BroadcastOp>() ||
        expr->isA<kir::GridBroadcast>())) {
    return false;
  }

  // GroupedReductionOp can have multiple output TVs, but they must be
  // parallelized in the same way, so just checking one of them is enough.
  auto tv = ir_utils::getTvOutput(expr);

  if (tv->hasBlockReduction() || tv->hasGridReduction()) {
    return true;
  } else if (expr->isA<BroadcastOp>()) {
    const ParallelTypeBitmap pt_map =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv);
    return pt_map.any();
  }

  return false;
}

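//! Allocates a 1D global-memory buffer of the given size and data type
//! for grid communication; optionally zero-initialized.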
kir::Allocate* allocGlobalBufferForGridComm(
    Val* buffer_size,
    DataType dtype,
    bool zero_init) {
  const std::vector<IterDomain*> new_buffer_ids = {
      IrBuilder::create<IterDomain>(IterDomainBuilder(
          GpuLower::current()->kernel()->zeroVal(), buffer_size))};
  const auto buffer_domain = IrBuilder::create<TensorDomain>(new_buffer_ids);
  const auto buffer_tv =
      IrBuilder::create<TensorView>(buffer_domain, dtype, MemoryType::Global);
  return IrBuilder::create<kir::Allocate>(
      buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init);
}

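//! Walks the given loop nest to decide where tv should be allocated and
//! initialized: the allocation position within tv's domain and the
//! for-loops the allocation and the initialization should be placed in.
//! If use_id_map is true, tv's axes are translated through id_map before
//! being compared against the loop iteration domains.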
BasicAllocInfo getAllocInformation(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& for_loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map,
    bool use_id_map) {
  BasicAllocInfo info;
  auto gpu_lower = GpuLower::current();

  bool outer_alloc_found = false;

  for (auto fl : for_loops) {
    if (info.alloc_pos == tv->getComputeAtPosition()) {
      break;
    }

    if (tv->axis(info.alloc_pos)->isReduction()) {
      const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs();
      TORCH_INTERNAL_ASSERT(
          std::find(outputs.begin(), outputs.end(), tv) != outputs.end(),
          "Invalid computeAt of T",
          tv->name(),
          ". A reduction axis is detected outside the computeAt point even though it is not an output tensor.");
      break;
    }

    auto fl_id = fl->iter_domain();

    if (fl_id->getParallelType() == ParallelType::Unroll) {
      break;
    }

    // Shared memory must be allocated outside of unswitched
    // domains. See issue #1133.
    if (fl_id->getParallelType() == ParallelType::Unswitch &&
        tv->getMemoryType() == MemoryType::Shared) {
      outer_alloc_found = true;
    }

    // Assume global memory is allocated at the outermost scope.
    if (tv->getMemoryType() == MemoryType::Global) {
      outer_alloc_found = true;
    }

    // Allocation of a double buffered tensor is placed outside its
    // double buffer axis.
    if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) &&
        tv->axis(info.alloc_pos) ==
            gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) {
      outer_alloc_found = true;
    }

    auto local_id = tv->axis(info.alloc_pos);

    if (use_id_map) {
      auto id_it = id_map.find(local_id);
      if (id_it != id_map.end()) {
        local_id = id_it->second;
      }
    }

    if (GpuLower::current()->caMap()->areMapped(
            local_id, fl_id, IdMappingMode::PERMISSIVE)) {
      info.alloc_pos++;
    }

    info.init_for_loop = fl;

    if (!outer_alloc_found) {
      info.alloc_for_loop = fl;
    }
  }

  return info;
}

//! Implementing this here to avoid including too many headers in
//! type.cpp. Conceptually this should be a generic definition
//! rather than a util.
bool supportInlinePredicate(Expr* expr) {
  if (ir_utils::isCpAsyncOp(expr)) {
    return true;
  }
  // TODO: build out support.
  return false;
}

} // namespace lower_utils

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch