1 | #include <arith.h> |
2 | #include <codegen.h> |
3 | #include <disjoint_set.h> |
4 | #include <fusion.h> |
5 | #include <fusion_segmenter.h> |
6 | #include <instrumentation.h> |
7 | #include <ir_all_nodes.h> |
8 | #include <ir_cloner.h> |
9 | #include <ir_printer.h> |
10 | #include <ir_utils.h> |
11 | #include <iter_visitor.h> |
12 | #include <kernel.h> |
13 | #include <lower2device.h> |
14 | #include <lower_bank_conflict.h> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | namespace fuser { |
19 | namespace cuda { |
20 | |
21 | static thread_local Fusion* ACTIVE_FUSION = nullptr; // NOLINT |
22 | |
23 | FusionGuard::FusionGuard(Fusion* fusion) { |
24 | prev_fusion = ACTIVE_FUSION; |
25 | ACTIVE_FUSION = fusion; |
26 | } |
27 | |
// Restores the fusion that was active when this guard was constructed.
FusionGuard::~FusionGuard() {
  ACTIVE_FUSION = prev_fusion;
}
31 | |
// Returns the fusion currently active on this thread, or nullptr if none.
Fusion* FusionGuard::getCurFusion() {
  return ACTIVE_FUSION;
}
// Sets the thread-local active fusion directly, bypassing guard semantics
// (no previous value is saved).
void FusionGuard::setCurFusion(Fusion* fusion) {
  ACTIVE_FUSION = fusion;
}
38 | |
// ADL swap for Fusion: exchanges the IrContainer bases and all
// Fusion-specific bookkeeping between `a` and `b`.
// NOTE(review): all_tv_uses_valid_ and is_during_update_uses_ are not
// swapped here — confirm that is intentional.
void swap(Fusion& a, Fusion& b) noexcept {
  FUSER_PERF_SCOPE("Fusion swap");

  using std::swap;

  // Swap the owned IR (vals/exprs) first.
  swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));

  swap(a.inputs_, b.inputs_);
  swap(a.outputs_, b.outputs_);

  swap(a.io_alias_, b.io_alias_);
  swap(a.permuted_input_map_, b.permuted_input_map_);
  swap(a.permuted_output_map_, b.permuted_output_map_);
}
53 | |
// Segments this fusion into separately-schedulable pieces for the given
// runtime arguments, delegating to SegmentCandidateFinder.
std::unique_ptr<SegmentedFusion> Fusion::segment(
    const KernelArgumentHolder& args) {
  FUSER_PERF_SCOPE("Segment Fusion");
  return SegmentCandidateFinder::segment(this, args);
}
59 | |
// Deep-copies `from` into `to` (clearing `to` first) and returns the
// IrCloner, which callers can use to map original Val*/Expr* pointers to
// their clones.
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
  to->clear();
  auto ir_cloner = IrContainer::copy(from, to);

  // IrContainer::copy clones the statements; their definition/uses links
  // still reference nodes of `from`, so re-wire them to the cloned nodes.
  for (auto val : from->vals_) {
    ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
    ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
  }

  to->inputs_ = ir_cloner.clone(from->inputs_);
  to->outputs_ = ir_cloner.clone(from->outputs_);
  for (auto inp : to->inputs_) {
    inp->setIsFusionInput(true);
  }
  for (auto out : to->outputs_) {
    out->setIsFusionOutput(true);
  }

  // TODO: put this into ir_cloner instead
  // Translate the output->input alias map through the cloner.
  for (const auto& entry : from->io_alias_) {
    Val* copied_output = ir_cloner.clone(entry.first);
    Val* copied_input = ir_cloner.clone(entry.second);
    to->io_alias_[copied_output] = copied_input;
  }

  // Permutation maps are index-based, not pointer-based, so a plain copy
  // suffices.
  to->permuted_input_map_ = from->permuted_input_map_;
  to->permuted_output_map_ = from->permuted_output_map_;

  to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
  // This should never be true on copy, but copying for completeness.
  to->is_during_update_uses_ = from->is_during_update_uses_;

  return ir_cloner;
}
94 | |
95 | // Clang tidy complains when using default constructor for IrContainer instead |
96 | // of copy constructor. Fusion::copy has a call to IrContainer::copy, so it's |
97 | // redundant to use the IrContainer copy constructor, but it is harmless since |
98 | // Fusion::copy starts by calling clear(). |
// Copy constructor: the heavy lifting is delegated to Fusion::copy, which
// clears *this before cloning `other` into it.
Fusion::Fusion(const Fusion& other) : IrContainer(other) {
  FUSER_PERF_SCOPE("Fusion copy");
  Fusion::copy(&other, this);
}
103 | |
// Move constructor: steal `other`'s contents by swapping with this
// default-constructed (empty) fusion; `other` is left empty but valid.
Fusion::Fusion(Fusion&& other) noexcept {
  FUSER_PERF_SCOPE("Fusion move");
  swap(*this, other);
}
108 | |
109 | Fusion& Fusion::operator=(const Fusion& other) { |
110 | FUSER_PERF_SCOPE("Fusion copy assign" ); |
111 | Fusion copy(other); |
112 | clear(); |
113 | swap(*this, copy); |
114 | return *this; |
115 | } |
116 | |
// Move assignment: release our current contents, then take `other`'s via
// swap; `other` is left empty but valid.
Fusion& Fusion::operator=(Fusion&& other) noexcept {
  FUSER_PERF_SCOPE("Fusion move assign");
  clear();
  swap(*this, other);
  return *this;
}
123 | |
// Destructor explicitly releases all owned IR via clear().
Fusion::~Fusion() {
  clear();
}
127 | |
128 | void Fusion::clear() noexcept { |
129 | FUSER_PERF_SCOPE("Fusion clear" ); |
130 | |
131 | IrContainer::clear(); |
132 | |
133 | inputs_.clear(); |
134 | outputs_.clear(); |
135 | |
136 | io_alias_.clear(); |
137 | |
138 | permuted_input_map_.clear(); |
139 | permuted_output_map_.clear(); |
140 | |
141 | all_tv_uses_valid_ = false; |
142 | is_during_update_uses_ = false; |
143 | } |
144 | |
145 | void Fusion::removeExpr(Expr* expr) { |
146 | assertInContainer(expr, "Cannot remove expr " ); |
147 | // If we hit this error too frequently, we could lighten the restrictions so |
148 | // that removing something that doesn't exist simply does nothing. For now, |
149 | // we're going with the strictest model which errors. |
150 | |
151 | for (auto out : expr->outputs()) { |
152 | out->setDefinition(nullptr); |
153 | } |
154 | |
155 | for (auto inp : expr->inputs()) { |
156 | auto uses_copy = inp->uses(); |
157 | auto it = std::find(uses_copy.begin(), uses_copy.end(), expr); |
158 | if (it != uses_copy.end()) { |
159 | uses_copy.erase(it); |
160 | inp->setUses(uses_copy); |
161 | } |
162 | } |
163 | |
164 | IrContainer::removeExpr(expr); |
165 | } |
166 | |
167 | void Fusion::removeVal(Val* val) { |
168 | assertInContainer(val, "Cannot remove val " ); |
169 | |
170 | TORCH_CHECK( |
171 | !val->isFusionInput(), |
172 | "Cannot remove val as it is an input of the fusion." ); |
173 | TORCH_CHECK( |
174 | !val->isFusionOutput(), |
175 | "Cannot remove val as it is an output of the fusion." ); |
176 | |
177 | Expr* orig = val->definition(); |
178 | if (orig != nullptr) |
179 | removeExpr(val->definition()); |
180 | |
181 | for (Expr* use : unordered_uses(val)) { |
182 | removeExpr(use); |
183 | } |
184 | IrContainer::removeVal(val); |
185 | } |
186 | |
// Registers `input` as a fusion input. Tensor inputs are forced to global
// memory; constant scalars are rejected (they should be inlined instead of
// passed in); Index-typed values are never valid inputs.
void Fusion::addInput(Val* input) {
  assertInContainer(input, "Cannot register input ");

  TORCH_INTERNAL_ASSERT(
      input->getDataType() != DataType::Index,
      "Data type Index is a local compile time data type only, it cannot be used as an input in case it was generated from another kernel.");

  if (input->getValType().value() == ValType::TensorView) {
    auto tv = input->as<TensorView>();
    // Kernel inputs always live in global memory.
    tv->setMemoryType(MemoryType::Global);
  } else if (input->getValType().value() == ValType::Scalar) {
    TORCH_CHECK(
        !input->isConst(),
        "Immediate scalar value cannot be added as an input. It is not necessary to pass it as an input.");
  }

  inputs_.push_back(input);
  input->setIsFusionInput(true);

  // Cached TensorView use-lists may now be stale.
  all_tv_uses_valid_ = false;
}
208 | |
// Registers `output` as a fusion output; tensor outputs are forced to
// global memory.
void Fusion::addOutput(Val* output) {
  // We currently don't support explicitly outputing aliased inputs. This is
  // because they are already marked as output for in-place update. It's tricky
  // to allow marking them explicitly as real output, since that requires us to
  // register/identify output not only by `Val*` pointer, but also by indices;
  // it also requires us to magically arrange `outputs_` entries in proper order
  // ^^^ this doesn't look intuitive on `outputs_` in fusion.
  // I think we can solve this by marking addOutput on io_alias_ keys after
  // fusion is fully defined. Tracking this in #1488
  // Apparently we can't do this neither at the time. I think segmentation
  // unfortunately would call addOutput after we marked io_alias_ map.
  // TORCH_CHECK(io_alias_.count(output) == 0,
  //             "can't register aliased output as real output");

  assertInContainer(output, "Cannot register output ");
  if (output->getValType().value() == ValType::TensorView) {
    auto tv = output->as<TensorView>();
    // Kernel outputs always live in global memory.
    tv->setMemoryType(MemoryType::Global);
  }
  outputs_.push_back(output);
  output->setIsFusionOutput(true);

  // Cached TensorView use-lists may now be stale.
  all_tv_uses_valid_ = false;
}
233 | |
234 | void Fusion::removeInput(Val* input) { |
235 | auto find_input = std::find(inputs_.begin(), inputs_.end(), input); |
236 | if (find_input != inputs_.end()) { |
237 | inputs_.erase(find_input); |
238 | } |
239 | input->setIsFusionInput(false); |
240 | all_tv_uses_valid_ = false; |
241 | } |
242 | |
243 | void Fusion::removeOutput(Val* output) { |
244 | auto find_output = std::find(outputs_.begin(), outputs_.end(), output); |
245 | if (find_output != outputs_.end()) { |
246 | outputs_.erase(find_output); |
247 | } |
248 | output->setIsFusionOutput(false); |
249 | all_tv_uses_valid_ = false; |
250 | } |
251 | |
252 | void Fusion::replaceOutput(Val* output, Val* replacement) { |
253 | auto find_output = std::find(outputs_.begin(), outputs_.end(), output); |
254 | TORCH_CHECK(find_output != outputs_.end(), "Unable to find output in Fusion" ); |
255 | |
256 | if (find_output != outputs_.end()) { |
257 | std::replace_if( |
258 | outputs_.begin(), |
259 | outputs_.end(), |
260 | [&output](Val* v) { return v == output; }, |
261 | replacement); |
262 | |
263 | if (replacement->getValType().value() == ValType::TensorView) { |
264 | replacement->setIsFusionOutput(true); |
265 | replacement->as<TensorView>()->setMemoryType(MemoryType::Global); |
266 | } |
267 | if (output->getValType().value() == ValType::TensorView) { |
268 | output->setIsFusionOutput(false); |
269 | output->as<TensorView>()->setMemoryType(MemoryType::Local); |
270 | } |
271 | resetTvUses(); |
272 | } |
273 | |
274 | // Temporary WAR for issue #1112 |
275 | // (https://github.com/csarofeen/pytorch/issues/1112) |
276 | if (io_alias_.count(output) != 0) { |
277 | auto input = io_alias_[output]; |
278 | io_alias_.erase(output); |
279 | io_alias_[replacement] = input; |
280 | } |
281 | } |
282 | |
// Returns this fusion's exprs as produced by StmtSort (topologically
// ordered traversal from the outputs).
std::vector<Expr*> Fusion::exprs() {
  return StmtSort::getExprs(this);
}
286 | |
// Returns the terminating input vals that `val` transitively depends on.
std::vector<Val*> Fusion::inputsOf(Val* val) {
  return InputsOf::output(this, val);
}
290 | |
// Verifies that every non-constant val the outputs depend on is accounted
// for: either registered as a fusion input or at least present in this
// container.
void Fusion::validateInputs() {
  // Gather the dependency roots of every output.
  std::unordered_set<Val*> all_inputs;
  for (Val* out : outputs()) {
    for (Val* input : inputsOf(out)) {
      all_inputs.insert(input);
    }
  }

  // Extents of input tensor (rfactor) domains. Currently unused by the
  // check below — see the TODO referencing issue #1365.
  std::unordered_set<Val*> input_dims;
  auto inp_tvs = ir_utils::filterByType<TensorView>(inputs());
  for (auto tv : inp_tvs) {
    for (auto id : tv->getMaybeRFactorDomain()) {
      input_dims.emplace(id->extent());
    }
  }
  for (Val* input : all_inputs) {
    if (!input->isConstScalar()) {
      TORCH_CHECK(
          input->isFusionInput() ||
              // TODO: Switch:
              inContainer(input),
          // to: input_dims.find(input) != input_dims.end(),
          // https://github.com/csarofeen/pytorch/issues/1365
          "Could not figure out how ",
          input->toString(),
          " is generated, however it was not specified as an input.");
    }
  }
}
320 | |
// Prints the full fusion IR to stdout: the math exprs followed by the
// per-tensor scheduling transformations.
void Fusion::print() {
  FUSER_PERF_SCOPE("Fusion::print");

  FusionGuard fg(this);
  std::cout << "\n%kernel {\n";
  IrMathPrinter op_exprs(std::cout);
  op_exprs.handle(this);
  std::cout << "\nTransformPrinter : \n";
  IrTransformPrinter t_exprs(std::cout);
  t_exprs.handle(this);
  std::cout << "}\n\n";
}
333 | |
// Lowers this fusion with GpuLower and prints the generated CUDA kernel
// source to stdout. Must not be called on an already-lowered kernel
// container.
void Fusion::printKernel(DataType index_type) {
  FUSER_PERF_SCOPE("Fusion::printKernel");
  TORCH_INTERNAL_ASSERT(
      !this->isA<kir::Kernel>(),
      "Cannot \"print kernel\" of a kernel container. ",
      "This would require lowering during lowering.");
  std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel());
}
342 | |
343 | std::unordered_map<std::string, std::pair<int, int>> Fusion::bankConflictInfo( |
344 | DataType index_type) { |
345 | GpuLower lower(this, index_type); |
346 | auto kernel = lower.kernel(); |
347 | auto info = getBankConflictInfo(kernel); |
348 | // The container of exprs goes out of scope, so we return a map of string here |
349 | std::unordered_map<std::string, std::pair<int, int>> result; |
350 | result.reserve(info.size()); |
351 | for (auto i : info) { |
352 | result[i.first->toString()] = i.second; |
353 | } |
354 | return result; |
355 | } |
356 | |
// Prints the fusion's inputs, outputs, and math exprs to stdout. With
// from_outputs_only == false, traversal also starts from every leaf val
// (vals with no uses), so exprs not reachable from the outputs are printed
// too.
void Fusion::printMath(bool from_outputs_only) {
  FUSER_PERF_SCOPE("Fusion::printMath");

  FusionGuard fg(this);
  auto exprs_for_print = exprs();
  std::cout << "Inputs:" << std::endl;
  for (auto inp : inputs()) {
    std::cout << "  " << inp << ", " << inp->getDataType().value() << std::endl;
  }

  std::cout << "Outputs:" << std::endl;
  for (auto out : outputs()) {
    std::cout << "  " << out << ", " << out->getDataType().value() << std::endl;
  }

  // If we want everything in the fusion, grab all values without uses to
  // traverse from.
  if (!from_outputs_only) {
    std::vector<Val*> leaf_vals;
    for (auto val : deterministic_vals()) {
      if (val->uses().empty()) {
        leaf_vals.push_back(val);
      }
    }
    exprs_for_print = StmtSort::getExprs(this, leaf_vals);
  }

  std::cout << "\n%kernel_math {\n";
  for (auto expr : exprs_for_print) {
    std::cout << expr;
  }
  std::cout << "}\n\n";
}
390 | |
391 | std::vector<Val*> Fusion::inputsAndCreated() { |
392 | auto result = inputs_; |
393 | for (auto expr : exprs()) { |
394 | auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs()); |
395 | if (tv_inputs.empty()) { |
396 | for (auto v : expr->outputs()) { |
397 | result.emplace_back(v); |
398 | } |
399 | } |
400 | } |
401 | return result; |
402 | } |
403 | |
// Prints only the per-tensor scheduling transformations to stdout.
void Fusion::printTransforms() {
  FUSER_PERF_SCOPE("Fusion::printTransforms");

  FusionGuard fg(this);
  IrTransformPrinter t_exprs(std::cout);
  t_exprs.handle(this);
}
411 | |
412 | void Fusion::registerVal(Val* val) { |
413 | if (inContainer(val)) { |
414 | return; |
415 | } |
416 | |
417 | if (val->fusion()) { |
418 | TORCH_CHECK( |
419 | val->fusion() == this, val, " was not found in the active fusion." ); |
420 | } |
421 | |
422 | IrContainer::registerVal(val); |
423 | } |
424 | |
// Registers `expr` with this fusion, adding it to the use-lists of its
// inputs and recording it as the definition of its outputs (enforcing SSA
// except in kernel containers).
void Fusion::registerExpr(Expr* expr) {
  if (inContainer(expr)) {
    return;
  }

  if (expr->fusion()) {
    TORCH_CHECK(
        expr->fusion() == this, expr, " was not found in the active fusion.");
  }

  IrContainer::registerExpr(expr);

  bool has_tv = false;

  // Add this expr to each input's use-list (deduplicated).
  for (Val* input : expr->inputs()) {
    has_tv = has_tv || input->isA<TensorView>();
    assertInContainer(input, "Input to expr is invalid, ");
    auto uses_copy = input->uses();
    if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
        uses_copy.end()) {
      uses_copy.push_back(expr);
      input->setUses(uses_copy);
    }
  }

  // Kernel is the only container type that is non-ssa. This is mainly (maybe
  // only) because of initialization expressions which would overwrite tensor
  // view definitions.
  bool is_ssa = !this->isA<kir::Kernel>();

  for (Val* output : expr->outputs()) {
    has_tv = has_tv || output->isA<TensorView>();
    assertInContainer(output, "Output to expr is invalid, ");
    // In SSA mode, a new producer displaces (removes) the old one.
    if (output->definition() != nullptr && is_ssa) {
      removeExpr(output->definition());
    }
    // In non-SSA mode, only the first producer is recorded as definition.
    if (is_ssa || (!is_ssa && output->definition() == nullptr)) {
      output->setDefinition(expr);
    }
  }

  // TensorView topology changed, so cached TV use-lists must be rebuilt.
  if (has_tv) {
    resetTvUses();
  }
}
470 | |
// Rebuilds the use-lists of all TensorViews from the currently live exprs
// and marks the cached uses valid.
void Fusion::resetTvUses() {
  FUSER_PERF_SCOPE("Fusion::resetTvUses");
  // Flag so re-entrant queries know uses are mid-rebuild.
  is_during_update_uses_ = true;

  // getExprs only uses definition, so even if we've modified uses already to
  // remove dead exprs, this could reinsert them. getExprs is also bounded by
  // inputs as registered inputs will return nullptr as their definition.
  const auto all_tvs = ir_utils::filterByType<TensorView>(vals_);
  const auto used_exprs = StmtSort::getExprs(this);

  // Start every TV from an empty use-list...
  for (auto tv : all_tvs) {
    tv->setUses({});
  }

  // Same as in register expr
  // ...then repopulate from each live expr's inputs (deduplicated).
  for (auto expr : used_exprs) {
    for (Val* input : expr->inputs()) {
      auto uses_copy = input->uses();
      if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
          uses_copy.end()) {
        uses_copy.push_back(expr);
        input->setUses(uses_copy);
      }
    }
  }

  all_tv_uses_valid_ = true;
  is_during_update_uses_ = false;
}
500 | |
501 | std::vector<Val*> Fusion::usedMathVals() { |
502 | // Note that using fusion->inputs() as the argument for the first |
503 | // parameter of getAllValsBetween does not grab all used vals as |
504 | // there can be vals that are created inside a fusion without using |
505 | // anything from inputs. See, for example, tv0 in the |
506 | // FusionOuterSplit test. |
507 | const auto inputs = InputsOf::outputs(this, outputs()); |
508 | auto used_math_vals = DependencyCheck::getAllValsBetween( |
509 | {inputs.begin(), inputs.end()}, outputs()); |
510 | // When an expre has multiple outputs and only some of them are |
511 | // used, the rest aren't included in used_math_vals as they are not |
512 | // used. However, we want them to be included as they must show up |
513 | // in the fusion. |
514 | std::vector<Val*> vals_to_add; |
515 | std::unordered_set<Val*> added_vals; |
516 | |
517 | for (auto val : used_math_vals) { |
518 | auto def = val->definition(); |
519 | if (def == nullptr || def->outputs().size() < 2) { |
520 | continue; |
521 | } |
522 | for (auto out : def->outputs()) { |
523 | if (std::find(used_math_vals.begin(), used_math_vals.end(), out) == |
524 | used_math_vals.end()) { |
525 | if (!added_vals.count(out)) { |
526 | vals_to_add.push_back(out); |
527 | added_vals.insert(out); |
528 | } |
529 | } |
530 | } |
531 | } |
532 | |
533 | used_math_vals.insert( |
534 | used_math_vals.end(), vals_to_add.begin(), vals_to_add.end()); |
535 | |
536 | return used_math_vals; |
537 | } |
538 | |
539 | std::vector<Val*> Fusion::terminatingMathVals() { |
540 | VectorOfUniqueEntries<Val*> result; |
541 | auto used_vals = usedMathVals(); |
542 | for (auto v : used_vals) { |
543 | // Locate the vals that are not expr outputs but have valid definitions. |
544 | if (unordered_uses(v).empty() && v->definition() != nullptr) { |
545 | result.pushBack(v); |
546 | } |
547 | } |
548 | return result.vector(); |
549 | } |
550 | |
// Returns a copy of `val`'s use-list as a deduplicated, unordered set.
std::unordered_set<Expr*> Fusion::unordered_uses(const Val* val) const {
  return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
}
554 | |
// Returns the expr that produces `val` (nullptr for inputs/leaf vals);
// errors if `val` is not registered in this container.
Expr* Fusion::definition(const Val* val) const {
  assertInContainer(val, "Cannot detect the definition of val, ");
  return val->definition();
}
559 | |
560 | // Indicate to kernel to set itself up to generate random numbers |
561 | bool Fusion::isStochastic() { |
562 | for (auto expr : exprs()) { |
563 | if (expr->getExprType() == ExprType::RNGOp) { |
564 | return true; |
565 | } |
566 | } |
567 | return false; |
568 | } |
569 | |
570 | std::vector<Val*> Fusion::getTerminatingOutputs() const { |
571 | FUSER_PERF_SCOPE("getTerminatingOutputs" ); |
572 | |
573 | auto is_reachable_to_output = [](Val* val) { |
574 | // traverse to consumers of val and see if there is an output |
575 | std::deque<Val*> consumers; |
576 | for (auto use : val->uses()) { |
577 | for (auto consumer : use->outputs()) { |
578 | consumers.push_back(consumer); |
579 | } |
580 | } |
581 | while (!consumers.empty()) { |
582 | auto consumer = consumers.back(); |
583 | consumers.pop_back(); |
584 | if (consumer->isFusionOutput()) { |
585 | return true; |
586 | } |
587 | // consumer is not an output; proceed to its consumers |
588 | for (auto use : consumer->uses()) { |
589 | for (auto consumer_of_consumer : use->outputs()) { |
590 | consumers.push_back(consumer_of_consumer); |
591 | } |
592 | } |
593 | } |
594 | return false; |
595 | }; |
596 | |
597 | std::vector<Val*> terminating_outputs; |
598 | |
599 | for (auto out : outputs()) { |
600 | // If there is another output reachable from this output, it's not |
601 | // terminating. |
602 | if (is_reachable_to_output(out)) { |
603 | continue; |
604 | } |
605 | terminating_outputs.push_back(out); |
606 | } |
607 | |
608 | return terminating_outputs; |
609 | } |
610 | |
611 | bool Fusion::isAliasCompatible(Val* left, Val* right) { |
612 | // Nullptr check |
613 | if (left == nullptr || right == nullptr) { |
614 | return false; |
615 | } |
616 | |
617 | // DataType check |
618 | if (!left->getDataType().has_value() || !right->getDataType().has_value() || |
619 | left->getDataType().value() != right->getDataType().value()) { |
620 | return false; |
621 | } |
622 | |
623 | // ValType check |
624 | if (!left->getValType().has_value() || !right->getValType().has_value() || |
625 | left->getValType().value() != right->getValType().value()) { |
626 | return false; |
627 | } |
628 | |
629 | // Check same number of dimensions if both values are TensorViews |
630 | if (ir_utils::isTV(left) && ir_utils::isTV(right)) { |
631 | return left->as<TensorView>()->nDims() == right->as<TensorView>()->nDims(); |
632 | } |
633 | return false; |
634 | } |
635 | |
// Records `output` as an in-place alias of `input` and registers `output`
// as a fusion output. If `input` is the result of a cast of a fusion input,
// the alias is recorded against the pre-cast value, and `output` is cast to
// match when dtypes differ.
void Fusion::aliasOutputToInput(Val* output, Val* input) {
  // Because we could cast output when input is cast.
  TORCH_INTERNAL_ASSERT(
      !output->isFusionOutput(),
      "Do NOT add aliased output to fusion output outside of `aliasOutputToInput");

  if (!input->isFusionInput()) {
    auto input_expr = input->definition();
    // TORCH_INTERNAL_ASSERT(input_def.etype() == ExprType::UnaryOp, "expected
    // unary op for aliased input");
    TORCH_INTERNAL_ASSERT(
        input_expr->isA<UnaryOp>(), "expected unary op for aliased input");
    auto input_uop = input_expr->as<UnaryOp>();
    TORCH_INTERNAL_ASSERT(
        input_uop->getUnaryOpType() == UnaryOpType::Cast,
        "expected aliased input to be output of cast op");
    // Alias against the original (pre-cast) fusion input.
    input = input_uop->in();
  }
  TORCH_INTERNAL_ASSERT(
      input->getDataType().has_value() && output->getDataType().has_value(),
      "requires DataType to be available for aliased output to input");

  // In-place update requires matching dtypes; cast the output side if the
  // types differ.
  if (input->getDataType().value() != output->getDataType().value()) {
    output = castOp(input->getDataType().value(), output);
  }
  // TODO: output should be marked at the end of fusion definition #1488
  addOutput(output);

  TORCH_INTERNAL_ASSERT(
      isAliasCompatible(input, output),
      "The input and output values are not alias-compatible.");
  io_alias_[output] = input;
}
669 | |
670 | Val* Fusion::getOutputAlias(Val* output) { |
671 | auto search = io_alias_.find(output); |
672 | if (search != io_alias_.end()) { |
673 | return search->second; |
674 | } |
675 | return nullptr; |
676 | } |
677 | |
678 | std::unordered_set<int> Fusion::getOutputAliasIndices() const { |
679 | if (io_alias_.empty()) { |
680 | return {}; |
681 | } |
682 | |
683 | std::unordered_set<int> alias_indices; |
684 | |
685 | for (const auto i : c10::irange(outputs_.size())) { |
686 | if (io_alias_.count(outputs_[i]) != 0) { |
687 | alias_indices.insert(i); |
688 | } |
689 | } |
690 | return alias_indices; |
691 | } |
692 | |
693 | std::vector<std::pair<int, int>> Fusion::getInputAliasIndices() const { |
694 | if (io_alias_.empty()) { |
695 | return {}; |
696 | } |
697 | |
698 | std::vector<std::pair<int, int>> alias_indices; |
699 | for (const auto i : c10::irange(outputs_.size())) { |
700 | if (io_alias_.count(outputs_[i]) != 0) { |
701 | bool found = false; |
702 | for (const auto j : c10::irange(inputs_.size())) { |
703 | if (io_alias_.at(outputs_[i]) == inputs_[j]) { |
704 | alias_indices.emplace_back(i, j); |
705 | found = true; |
706 | break; |
707 | } |
708 | } |
709 | TORCH_INTERNAL_ASSERT( |
710 | found, |
711 | "io_alias_ mapping failure, alias output is not present in inputs" ); |
712 | } |
713 | } |
714 | // can't assert here, we could have segmented fusion where not all alias |
715 | // outputs are present |
716 | |
717 | return alias_indices; |
718 | } |
719 | |
720 | } // namespace cuda |
721 | } // namespace fuser |
722 | } // namespace jit |
723 | } // namespace torch |
724 | |