#include <c10/util/irange.h>
#include <arith.h>
#include <compute_at.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <inlining.h>
#include <ir_all_nodes.h>
#include <ir_builder.h>
#include <ir_cloner.h>
#include <ir_interface_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_double_buffer.h>
#include <scheduler/mma_utils.h>

// Cleanup
#include <transform_iter.h>
#include <transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {
DataType aten_opt_type_map(const c10::optional<at::ScalarType>& scalar_type) {
  return scalar_type.has_value() ? aten_to_data_type(scalar_type.value())
                                 : DataType::Null;
}
} // namespace

TensorView::TensorView(
    IrBuilderPasskey passkey,
    TensorDomain* domain,
    DataType dtype,
    MemoryType mtype)
    : Val(passkey, ValType::TensorView, dtype),
      domain_(domain),
      memory_type_(mtype) {}

TensorView::TensorView(
    IrBuilderPasskey passkey,
    const std::shared_ptr<c10::TensorType>& tensor_type)
    : Val(passkey,
          ValType::TensorView,
          aten_opt_type_map(tensor_type->scalarType())) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  std::vector<IterDomain*> sizes;

  TORCH_CHECK(
      tensor_type->dim().has_value(), "Requires static rank for Tensor");

  for (const auto i : c10::irange(tensor_type->dim().value())) {
    if (tensor_type->sizes()[i].has_value() &&
        tensor_type->sizes()[i].value() == 1) {
      // If size is known to be 1, assume it needs to be broadcasted.
      sizes.push_back(
          IterDomainBuilder(
              passkey.ir_container_->zeroVal(), passkey.ir_container_->oneVal())
              .iter_type(IterType::Broadcast)
              .build());
    } else {
      sizes.push_back(
          IterDomainBuilder(
              passkey.ir_container_->zeroVal(), IrBuilder::create<Int>())
              .build());
    }
  }
  // [ Note -- stride_properties in tensor type ]
  //
  // `stride_properties()` returns a vector<optional<Stride>>, where
  //   Stride {
  //     optional<size_t> stride_index_;
  //     optional<bool> contiguous_;
  //     optional<size_t> stride_;
  //   };
  // To keep things simple, we ignore all the optional wrappers, as in
  // reality they would always be available unless we start doing multiple
  // profiling runs.
  //
  // `stride_properties()` returns the vector of Stride ordered from the
  // fastest to the slowest dimension, i.e. stride_properties()[i] gives us
  // the i-th fastest dimension, where:
  //   1. `Stride::stride_index_` gives the index of the dimension;
  //   2. `Stride::contiguous_` indicates whether this dimension is
  //      memory-dense*;
  //   3. `Stride::stride_` is the actual stride for the given dimension.
  // * Note that memory-dense means different things depending on the order
  //   of the dimension. Check out `TensorType::computeStrideProps` for
  //   details.

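  // Illustrative example (not from the original comment; the numbers are
  // hypothetical): for a contiguous tensor of shape [4, 8] with strides
  // [8, 1], stride_properties() ordered fastest-to-slowest would give
  // {stride_index_ = 1, contiguous_ = true, stride_ = 1} followed by
  // {stride_index_ = 0, contiguous_ = true, stride_ = 8}, and the loop below
  // would mark contig_info = {true, true}. For a transposed view with strides
  // [1, 4], the fastest dimension has stride_index_ = 0, which is not the
  // last dim, so neither entry gets collapsed and contig_info stays
  // {false, false}.
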
  // default to non_contiguous;
  std::vector<bool> contig_info(tensor_type->dim().value(), false);

  // We iterate through stride_index_, which goes from the fastest changing
  // dimension to the slowest, instead of iterating through sizes. This allows
  // an easier contiguity check;
  for (const auto i : c10::irange(tensor_type->dim().value())) {
    // if we don't have a contiguous dimension at the current stride index,
    // don't bother;
    const auto& stride_property_i = tensor_type->stride_properties()[i];
    if (stride_property_i.has_value() &&
        stride_property_i->stride_index_.has_value() &&
        stride_property_i->contiguous_.has_value() &&
        stride_property_i->contiguous_.value() == true) {
      const size_t index = stride_property_i->stride_index_.value();
      if (i == 0) {
        // mark the fastest changing dimension collapsible only when it's the
        // last dim;
        contig_info[index] = (index == tensor_type->dim().value() - 1);
      } else {
        // check the neighboring faster dimension; collapse if it is considered
        // as the inner dimension per stride_index
        auto inner_index_opt =
            tensor_type->stride_properties()[static_cast<int>(i) - 1]
                ->stride_index_;
        if (inner_index_opt.has_value() &&
            inner_index_opt.value() == (index + 1)) {
          // collapse if the inner dimension has non-broadcasted strides
          auto inner_stride_opt =
              tensor_type->stride_properties()[static_cast<int>(i) - 1]
                  ->stride_;
          contig_info[index] =
              inner_stride_opt.has_value() && inner_stride_opt.value() != 0;
        }
      }
    }
  }

  domain_ = IrBuilder::create<TensorDomain>(sizes, contig_info);
}

TensorView::TensorView(
    IrBuilderPasskey passkey,
    const std::shared_ptr<Value>& jit_value)
    : TensorView(passkey, jit_value->type()->cast<c10::TensorType>()) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
}

void TensorView::convertRfactorToRootDomain() {
  // For a given TensorView, does its domain (root / rfactor) contain any
  // concrete sized extents?
  auto is_concrete_tensor = [](TensorView* tv) {
    for (auto id : tv->getMaybeRFactorDomain()) {
      if (!id->extent()->isConstScalar()) {
        return false;
      }
    }
    return true;
  };

  // Create a new root domain and replacement TensorDomain.
  // Given an rfactor domain, create a new IterDomain.
  // Otherwise, clone the previous IterDomain
  auto createReplacementDomain =
      [this](const std::vector<Val*>& replacement_extents) {
        TORCH_INTERNAL_ASSERT(
            !replacement_extents.empty() &&
            getMaybeRFactorDomain().size() == replacement_extents.size());
        size_t idx = 0;
        std::vector<IterDomain*> new_root_domain(
            getMaybeRFactorDomain().size());
        for (const auto& id : getMaybeRFactorDomain()) {
          if (replacement_extents[idx] != nullptr) {
            new_root_domain[idx] = IterDomainBuilder(id)
                                       .extent(replacement_extents[idx])
                                       .resetSchedulingParams()
                                       .build();
            ++idx;
          } else {
            TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
            new_root_domain[idx++] = id->cloneWithoutRFactor();
          }
        }

        TORCH_INTERNAL_ASSERT(
            new_root_domain.size() == domain()->contiguity().size());
        setDomain(IrBuilder::create<TensorDomain>(
            container(), new_root_domain, domain()->contiguity()));
      };

  std::vector<Val*> rfactor_extents;
  std::unordered_map<Val*, Val*> replacement_map;
  const auto kThisIsConcreteTensor = is_concrete_tensor(this);
  for (const auto& id : getMaybeRFactorDomain()) {
    if (id->isRFactorProduct()) {
      // Create new symbolic extents for rfactor iterDomains
      auto domain_extent = (!kThisIsConcreteTensor)
          ? IrBuilder::create<Int>(container())
          : id->extent();
      rfactor_extents.push_back(domain_extent);
      replacement_map.emplace(id->extent(), domain_extent);
    } else {
      rfactor_extents.push_back(nullptr);
    }
  }
  createReplacementDomain(rfactor_extents);

  // Propagate new extent throughout fusion using ValReplacementMutator
  ir_utils::replaceValue(fusion(), replacement_map);
}

TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner),
      domain_(ir_cloner->clone(src->domain_)),
      compute_at_pos_(src->compute_at_pos_),
      max_producer_pos_(src->max_producer_pos_),
      memory_type_(src->memory_type_),
      swizzle_type_(src->swizzle_type_),
      is_double_buffered_(src->is_double_buffered_),
      is_circular_buffered_(src->is_circular_buffered_),
      circular_buffer_stage_(src->circular_buffer_stage_),
      cpu_scalar_(src->cpu_scalar_),
      has_swizzle_op_(src->has_swizzle_op_) {
  for (const auto id : src->axesToSwizzle()) {
    axes_to_swizzle_.push_back(ir_cloner->clone(id));
  }
}

bool TensorView::hasAnyReduction() const {
  return domain()->noReductions().size() != domain()->domain().size();
}

bool TensorView::hasReduction() const {
  return domain()->hasReduction();
}

bool TensorView::hasBlockReduction() const {
  return domain()->hasBlockReduction();
}

bool TensorView::hasGridReduction() const {
  return domain()->hasGridReduction();
}

bool TensorView::hasBroadcast() const {
  return domain()->hasBroadcast();
}

bool TensorView::hasRFactor() const {
  return domain()->hasRFactor();
}

c10::optional<unsigned int> TensorView::getReductionAxis() const {
  return domain()->getReductionAxis();
}

const std::vector<IterDomain*>& TensorView::getRootDomain() const {
  return domain()->getRootDomain();
}

const std::vector<IterDomain*>& TensorView::getRFactorDomain() const {
  return domain()->getRFactorDomain();
}

const std::vector<IterDomain*>& TensorView::getMaybeRFactorDomain() const {
  return domain()->getMaybeRFactorDomain();
}

std::vector<IterDomain*>::size_type TensorView::nDims() const {
  return domain()->nDims();
}

// Sets the cpu_scalar_ value, which is special handling for CPU based
// zero-dim tensors (i.e. CPU tensors that only have one value). This is only
// used on input values; otherwise it is ignored. The special handling is
// important because these "scalars" should be type promoted as a tensor, but
// we want to avoid explicitly copying the data, so we pass the data value as
// a standard kernel argument value.
void TensorView::setCpuScalar(bool is_cpu_scalar) {
  TORCH_INTERNAL_ASSERT(
      nDims() == 0, "Only 0-dim tensors can be marked as a cpu scalar.");
  cpu_scalar_ = is_cpu_scalar;
}
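
// A hypothetical illustration (not from the original source): a fusion input
// created from a 0-dim CPU tensor such as at::scalar_tensor(3.14) could be
// flagged with setCpuScalar(true) so that its single value is passed directly
// as a kernel argument instead of being copied to the device.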

IterDomain* TensorView::axis(int pos) const {
  TORCH_INTERNAL_ASSERT(
      nDims() > 0, "Tried to access an axis in a 0-dim TensorView");
  if (pos < 0)
    pos += domain()->nDims();
  TORCH_CHECK(
      pos >= 0 && (unsigned int)pos < domain()->nDims(),
      "Tried to access position ",
      pos,
      " in domain: ",
      domain());
  return domain()->axis(pos);
}

void TensorView::inlineAt(
    int64_t pos,
    bool best_effort,
    MaxPosCalculator* calc) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");

  std::unique_ptr<MaxPosCalculator> calc_owner;
  if (calc == nullptr) {
    calc_owner = std::make_unique<MaxPosCalculator>();
    calc = calc_owner.get();
  }

  if (pos < 0) {
    pos += int64_t(nDims()) + 1;
  }

  TORCH_INTERNAL_ASSERT(
      pos >= 0 && pos <= (int64_t)nDims(),
      "Invalid inline position for T",
      name(),
      ": ",
      pos);

  auto max_inline_pos = calc->getMaxPosAll(this, best_effort);

  if (best_effort) {
    pos = std::min<int64_t>(max_inline_pos, pos);
  }

  // Hoist the innermost broadcast axes out of the inlined region.
  while (pos > 0 && axis(pos - 1)->isBroadcast()) {
    pos--;
  }

  TORCH_INTERNAL_ASSERT(
      pos <= (int64_t)max_inline_pos,
      "Invalid inline position for T",
      name(),
      ": ",
      pos,
      ". Maximum allowed value: ",
      max_inline_pos);

  if (isFusionInput()) {
    return;
  }

  if (pos > compute_at_pos_) {
    compute_at_pos_ = pos;
    for (auto consumer : ir_utils::consumerTvsOf(this)) {
      consumer->updateMaxProducerPosition();
    }
  }
}

namespace {

// Try to find the aligned position on the consumer's domain corresponding to
// the compute at position of the producer domain. No checking on the actual
// producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
    TensorView* consumer,
    TensorView* producer) {
  // Locate the consumer's position that aligns with the producer's new
  // compute at axis. We need broadcast axes forwarded, so we need to replay
  // PasC as CasP will not forward broadcast dims. For example, if we have:
  //   T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) =
  //       broadcast( T1[ iS1{3} ] ca_pos( 1 ) produce_pos( 1 ) )
  // CasP will have the mapping iS1{3} -> iS2{3} and PasC will have the
  // mapping iS22{( 3 * 1 )} <- iS1{3}. We need the latter. Refer to
  // NVFuserTest.FusionComplexBCast1_CUDA.

  auto disjoint_sets =
      BestEffortReplay::replayPasC(
          producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
          .getDisjointSets();

  // Find the innermost position of the consumer that has been mapped within
  // the producer's compute at axis.
  unsigned int consumer_pos = consumer->nDims();
  while (consumer_pos > 0) {
    auto consumer_id = consumer->axis((int)consumer_pos - 1);
    auto p_dom = producer->domain()->domain();
    if (std::any_of(
            p_dom.begin(),
            p_dom.begin() + producer->getComputeAtPosition(),
            [&consumer_id, &disjoint_sets](IterDomain* p_id) {
              return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
            })) {
      break;
    }
    consumer_pos--;
  }

  return consumer_pos;
}

} // namespace

void TensorView::updateMaxProducerPosition() {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  for (auto producer : ir_utils::producerTvsOf(this)) {
    max_producer_pos_ = std::max(
        max_producer_pos_, getConsumerPosAlignedToProducerCA(this, producer));
  }
}

TensorView* TensorView::computeAt(
    TensorView* consumer,
    int position,
    ComputeAtMode mode) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Make sure this and consumer are not the same tensor, that's illegal
  TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");

  // We support negative axes, so increment it by consumer->nDims() + 1 and
  // make sure the result is within consumer->nDims() + 1. Being at
  // consumer->nDims() means producer will be computed inline with consumer,
  // hence the +1.
  if (position < 0)
    position += int(consumer->nDims()) + 1;

  TORCH_CHECK(
      (position >= 0 && (unsigned int)position < consumer->nDims() + 1) ||
          mode == ComputeAtMode::BestEffort,
      "Compute at called on a position outside the valid range.");

  if (mode == ComputeAtMode::BestEffort) {
    position = std::max(-1, position);
    position = std::min((int)consumer->nDims(), position);
  }

  ComputeAt::runAt(this, consumer, (unsigned int)position, mode);

  return this;
}
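
// Illustrative usage sketch (hypothetical tensors, not from the original
// source): with tv1 defined from tv0, calling
//   tv0->computeAt(tv1, -1);
// requests that tv0 be computed inline within tv1's innermost loop, since a
// negative position is offset by consumer->nDims() + 1 as described above.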

TensorView* TensorView::computeWith(
    TensorView* consumer,
    int position,
    ComputeAtMode mode) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Make sure this and consumer are not the same tensor, that's illegal
  TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeWith(this, ...)");

  // We support negative axes, so increment it by this->nDims() + 1 and make
  // sure the result is within this->nDims() + 1. Being at this->nDims()
  // means producer will be computed inline with this, hence the +1.
  if (position < 0)
    position += int(this->nDims()) + 1;
  TORCH_CHECK(
      position >= 0 && (unsigned int)position < this->nDims() + 1,
      "Compute with called on a position outside the valid range.");

  ComputeAt::runWith(this, consumer, (unsigned int)position, mode);

  return this;
}

TensorView* TensorView::split(
    int axis_,
    Val* factor,
    bool inner_split,
    bool trim_out_of_bounds) {
  // Only check things associated with axis, factor will be validated in
  // IterDomain
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");

  if (axis_ < 0)
    axis_ += domain()->nDims();

  TORCH_INTERNAL_ASSERT(
      axis_ >= 0,
      "Split axis is less than 0 even after adjusting for nDims: ",
      axis_);

  TORCH_CHECK(
      axis_ >= (int)getComputeAtPosition(),
      "Cannot split axis within compute at position. Axis = ",
      axis_,
      " computeAtPosition = ",
      getComputeAtPosition());

  TORCH_CHECK(
      axis_ >= (int)getMaxProducerPosition(),
      "Cannot split axis within max producer position. Axis = ",
      axis_,
      " maxProducerPosition = ",
      getMaxProducerPosition());

  TORCH_CHECK(
      axis(axis_)->getParallelType() == ParallelType::Serial,
      "Splitting an axis of non-Serial parallel type is not supported at this time."
      " Parallelization strategy must be set after calling split.");

  domain()->split(axis_, factor, inner_split, trim_out_of_bounds);
  return this;
}

TensorView* TensorView::split(
    int axis,
    unsigned int factor,
    bool inner_split,
    bool trim_out_of_bounds) {
  split(axis, IrBuilder::create<Int>(factor), inner_split, trim_out_of_bounds);
  return this;
}
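
// A small illustrative sketch (hypothetical tensor, not from the original
// source): starting from tv[I0{N}, I1{M}],
//   tv->split(1, 128);                       // -> [I0{N}, I1/128, 128]
//   tv->split(1, 4, /*inner_split=*/false);  // -> [I0{N}, 4, (I1/128)/4, 128]
// With inner_split = true the factor becomes the new inner extent; with
// inner_split = false it becomes the outer extent.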

// Merge "axis_o" and "axis_i" into 1 dimension
TensorView* TensorView::merge(int axis_o, int axis_i) {
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView");

  if (axis_o < 0)
    axis_o += domain()->nDims();

  if (axis_i < 0)
    axis_i += domain()->nDims();

  TORCH_CHECK(
      axis_o >= (int)getComputeAtPosition() &&
          axis_i >= (int)getComputeAtPosition(),
      "Cannot merge axes within compute at position. Either axis ",
      axis_o,
      " or ",
      axis_i,
      " is within computeAtPosition = ",
      getComputeAtPosition());

  TORCH_CHECK(
      axis_o >= (int)getMaxProducerPosition() &&
          axis_i >= (int)getMaxProducerPosition(),
      "Cannot merge axes within max producer position. Either axis ",
      axis_o,
      " or ",
      axis_i,
      " is within maxProducerPosition = ",
      getMaxProducerPosition());

  TORCH_CHECK(
      axis(axis_o)->getParallelType() == ParallelType::Serial ||
          axis(axis_i)->getParallelType() == ParallelType::Serial,
      "Merging axes of non-Serial parallel type is not supported at this time."
      " Parallelization strategy must be set after calling merge.");

  domain()->merge(axis_o, axis_i);
  return this;
}
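
// Illustrative sketch (hypothetical tensor, not from the original source):
// with tv[I0, I1, I2],
//   tv->merge(0, 1);    // -> [I0*I1, I2]; the merged domain lands at position 0
//   tv->merge(-2, -1);  // applied to the result -> [(I0*I1)*I2]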

TensorView* TensorView::reorder(const std::unordered_map<int, int>& old2new_) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  TORCH_INTERNAL_ASSERT(
      !(nDims() == 0 && old2new_.size() > 0),
      "Tried to reorder a 0-dim TensorView");

  for (auto entry : old2new_) {
    auto old_pos = entry.first < 0 ? entry.first + (int)nDims() : entry.first;
    auto new_pos =
        entry.second < 0 ? entry.second + (int)nDims() : entry.second;
    if (old_pos == new_pos) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        old_pos >= 0,
        "Found \"old\" position that's less than 0 even though already adjusted by nDims: ",
        old_pos);
    TORCH_INTERNAL_ASSERT(
        new_pos >= 0,
        "Found \"new\" position that's less than 0 even though already adjusted by nDims: ",
        new_pos);
    TORCH_CHECK(
        old_pos >= (int)getComputeAtPosition() &&
            new_pos >= (int)getComputeAtPosition(),
        "Cannot reorder axes within compute at position. Either axis ",
        old_pos,
        " or ",
        new_pos,
        " is within computeAtPosition = ",
        getComputeAtPosition());

    TORCH_CHECK(
        old_pos >= (int)getMaxProducerPosition() &&
            new_pos >= (int)getMaxProducerPosition(),
        "Cannot reorder axes within max producer position. Either axis ",
        old_pos,
        " or ",
        new_pos,
        " is within maxProducerPosition = ",
        getMaxProducerPosition());
  }

  domain()->reorder(old2new_);
  return this;
}
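
// Illustrative sketch (hypothetical tensor, not from the original source):
// the map is {old position -> new position}, so with tv[I0, I1, I2],
//   tv->reorder({{0, 2}, {2, 0}});  // -> [I2, I1, I0]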

TensorView* TensorView::swizzle(
    SwizzleType type,
    const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  swizzle_type_ = type;

  // Clear previously set swizzle axes if any
  if (axes_to_swizzle_.size()) {
    axes_to_swizzle_.clear();
  }

  if (swizzle_type_ == SwizzleType::Transpose) {
    TORCH_CHECK(
        axes.size() == 2,
        "Invalid axis list: ",
        axes,
        ". Number of axes must be two.");
    TORCH_CHECK(
        axes[0] != axes[1],
        "Invalid axis list: ",
        axes,
        ". Two distinct axes must be given.");
    TORCH_CHECK(
        getMemoryType() == MemoryType::Shared,
        "Transpose swizzle is meant for tensors on shared memory.");
    for (auto pos : axes) {
      if (pos < 0) {
        pos += nDims();
      }
      TORCH_CHECK(pos >= 0 && pos < (int)nDims(), "Invalid axis: ", pos);
      TORCH_CHECK(
          pos >= (int)getComputeAtPosition(),
          "Invalid axis: ",
          pos,
          ". Axis outside computeAt position is not allocated.");
      TORCH_CHECK(
          !axis(pos)->isReduction(),
          "Invalid axis: ",
          pos,
          ". Swizzling a reduction axis is not supported");
      TORCH_CHECK(
          !axis(pos)->isBroadcast(),
          "Invalid axis: ",
          pos,
          ". Swizzling a broadcast axis is not supported");
      axes_to_swizzle_.push_back(axis(pos));
    }
  }

  return this;
}

TensorView* TensorView::swizzle(
    Swizzle2DType swizzle_type,
    int x,
    int y,
    SwizzleMode swizzle_mode) {
  has_swizzle_op_ = true;
  if (x < 0) {
    x += domain()->nDims();
  }
  if (y < 0) {
    y += domain()->nDims();
  }

  TORCH_CHECK(
      x >= (int)getComputeAtPosition(),
      "Cannot swizzle axes within compute at position. Axis ",
      x,
      " is within computeAtPosition = ",
      getComputeAtPosition());

  TORCH_CHECK(
      y >= (int)getMaxProducerPosition(),
      "Cannot swizzle axes within max producer position. Axis ",
      y,
      " is within maxProducerPosition = ",
      getMaxProducerPosition());

  // Disable unsupported use cases at the current step.
  // Currently do not support reducing or broadcasting
  // swizzled dimensions.
  auto all_inputs = InputsOf::outputs(fusion(), {axis(x), axis(y)});
  for (auto id : ir_utils::filterByType<IterDomain>(all_inputs)) {
    TORCH_INTERNAL_ASSERT(
        !id->isBroadcast() && !id->isReduction(),
        "Unsupported use case for swizzle.");
  }

  // Also check that the scheduler is not trying to compose swizzles,
  // which is not yet supported either.
  auto all_exprs = DependencyCheck::getAllValsBetween(
      {all_inputs.begin(), all_inputs.end()}, {axis(x), axis(y)});
  for (auto expr : all_exprs) {
    TORCH_INTERNAL_ASSERT(
        !expr->isA<Swizzle2D>(), "Composing swizzles is not yet supported");
  }

  // Check swizzle specific constraints on the input axes:
  if (swizzle_type != Swizzle2DType::ZShape) {
    ExpressionEvaluator const_eval(fusion());

    auto x_id = axis(x);
    auto y_id = axis(y);

    TORCH_INTERNAL_ASSERT(
        x_id->extent()->isConstInt() && y_id->extent()->isConstInt(),
        "Only constant iterdomains supported on given swizzle type");

    int in_x_size = x_id->extent()->evaluateInt();
    int in_y_size = y_id->extent()->evaluateInt();

    // Check size constraints based on swizzle type
    if (swizzle_type == Swizzle2DType::Transpose ||
        swizzle_type == Swizzle2DType::XOR) {
      TORCH_INTERNAL_ASSERT(
          in_x_size == in_y_size, "Swizzle: equal dim iterdomains only");
    }

    if (swizzle_type == Swizzle2DType::Scatter) {
      TORCH_INTERNAL_ASSERT(
          in_y_size == 4,
          "Swizzle: unsupported id size, must be 4: ",
          in_y_size);
      TORCH_INTERNAL_ASSERT(
          in_x_size == 8 || in_x_size == 16 || in_x_size == 32,
          "Swizzle: unsupported id size, must be 8, 16, or 32: ",
          in_x_size);
    }
  }

  domain()->swizzle(swizzle_type, x, y, swizzle_mode);

  return this;
}

TensorView* TensorView::rFactor(const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // TODO: I think we should do this but
  // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the
  // moment.

  // TORCH_INTERNAL_ASSERT(
  //     !hasComputeAt(), "Cannot rfactor tensors after compute at has been
  //     set.");
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
  FusionGuard fg(fusion());
  TORCH_CHECK(
      definition() != nullptr &&
          (definition()->getExprType() == ExprType::ReductionOp ||
           definition()->getExprType() == ExprType::MmaOp),
      "Error rfactoring ",
      this,
      "; its definition is either a nullptr or not a reduction.");
  TORCH_CHECK(
      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");

  TORCH_CHECK(
      !definition()->isA<GroupedReductionOp>(),
      "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int>& axes, const std::vector<TensorView*>& tvs)");

  // Split tensor view into 2 parts
  auto domain_pair = domain()->rFactor(axes);

  // Producer in the pair
  auto producer_domain = domain_pair.first;
  // Consumer in the pair
  auto consumer_domain = domain_pair.second;

  // This domain will be the consumer, so create the producer
  TensorView* producer =
      IrBuilder::create<TensorView>(producer_domain, getDataType().value());

  // Set domain of consumer
  setDomain(consumer_domain);
  TensorView* consumer = this;

  if (auto this_reduction = dynamic_cast<ReductionOp*>(definition())) {
    // Setup dependency chain, inserting producer before this op.
    // Expr* producer_definition =
    IrBuilder::create<ReductionOp>(
        this_reduction->getReductionOpType(),
        this_reduction->init(),
        producer,
        this_reduction->in());

    // Expr* consumer_definition =
    IrBuilder::create<ReductionOp>(
        this_reduction->getReductionOpType(),
        this_reduction->init(),
        consumer,
        producer);
  } else if (auto this_mma = dynamic_cast<MmaOp*>(definition())) {
    // Initial reduction that still uses mma to combine
    // the input.
    IrBuilder::create<MmaOp>(
        producer,
        this_mma->inA(),
        this_mma->inB(),
        this_mma->init(),
        this_mma->options());

    // Remaining reduction that can be scheduled cross
    // warp or cta.
    IrBuilder::create<ReductionOp>(
        BinaryOpType::Add, this_mma->init(), consumer, producer);
  } else {
    TORCH_INTERNAL_ASSERT(false, "RFactor: unsupported tensor definition");
  }
  return producer;
}
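
// Illustrative usage sketch (hypothetical tensors, not from the original
// source): for tv1[I0, R1] = sum(tv0, {1}),
//   tv1->split(1, 128);           // tv1[I0, R1/128, R128]
//   auto rf = tv1->rFactor({1});  // rf[I0, R1/128, I128] = tv0
//                                 // tv1[I0, R128]        = rf
// The returned producer (rf) performs the first stage of the reduction over
// the listed axes; tv1 then reduces what remains.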

TensorView* TensorView::multiOutputRfactorHelper(
    TensorView* tv,
    const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Hack:
  // Semantically we should always keep the outputs of multi-output reduction
  // ops scheduled the same, but the user cannot be relied on to guarantee
  // that. In order to guarantee that the rFactor is defined meaningfully, the
  // scheduling of the output TV that got the rFactor call is forcibly
  // replayed onto the other outputs.

  if (!sameAs(tv)) {
    auto root = tv->getRootDomain();
    auto this_root = getRootDomain();

    // construct a trivial root domain map
    std::unordered_map<IterDomain*, IterDomain*> id_map;
    for (const auto i : c10::irange(root.size())) {
      id_map[this_root[i]] = root[i];
    }

    // replay on the target tv
    ReplayTransformations replay(domain()->domain(), id_map);

    // construct the new tensor domain
    std::vector<IterDomain*> new_id;
    for (auto id : domain()->domain()) {
      TORCH_INTERNAL_ASSERT(
          replay.getReplay().count(id), "Multi-output reduction replay failed");
      new_id.push_back(replay.getReplay().at(id));
    }

    std::vector<bool> new_contig(
        tv->domain()->contiguity().begin(), tv->domain()->contiguity().end());
    // replace tensor domain of target tv
    tv->setDomain(IrBuilder::create<TensorDomain>(
        tv->getRootDomain(), new_id, new_contig));
  }

  // Split tensor view into 2 parts
  auto domain_pair = tv->domain()->rFactor(axes);
  // Producer in the pair
  auto producer_domain = domain_pair.first;
  // Consumer in the pair
  auto consumer_domain = domain_pair.second;

  // This domain will be the consumer, so create the producer
  TensorView* producer =
      IrBuilder::create<TensorView>(producer_domain, tv->getDataType().value());

  // Set domain of consumer
  tv->setDomain(consumer_domain);

  return producer;
}

std::vector<TensorView*> TensorView::rFactor(
    const std::vector<int>& axes,
    const std::vector<TensorView*>& tvs) {
  TORCH_CHECK(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  TORCH_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
  FusionGuard fg(fusion());
  TORCH_CHECK(
      definition() != nullptr && ir_utils::isReductionOp(definition()),
      "Error rfactoring multi-output reduction op ",
      this,
      "; its definition is either a nullptr or not a multi-output reduction op.");

  TORCH_CHECK(
      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");

  TORCH_CHECK(
      definition()->outputs().size() == tvs.size(),
      "Rfactor of a multi-output reduction not used correctly");

  for (const auto i : c10::irange(tvs.size())) {
    TORCH_CHECK(
        definition()->output(i) == tvs.at(i),
        "Rfactor of a multi-output reduction not used correctly");
  }

  // Currently grouping of welford is only supported through
  // ParallelType::Group, so GroupedWelfordOp is only created during
  // lowering. As rFactor is done before lowering, there should be no
  // GroupedWelfordOp at this point.
  TORCH_INTERNAL_ASSERT(
      !definition()->isA<GroupedWelfordOp>(),
      "GroupedWelfordOp found: ",
      definition()->toString());

  std::vector<TensorView*> rf_tvs(tvs.size());

  // Make sure this gets rfactored last so everybody gets
  // replayed correctly
  for (const auto i : c10::irange(tvs.size())) {
    if (this != tvs.at(i)) {
      rf_tvs.at(i) = multiOutputRfactorHelper(tvs.at(i), axes);
    }
  }

  for (const auto i : c10::irange(tvs.size())) {
    if (this == tvs.at(i)) {
      rf_tvs.at(i) = multiOutputRfactorHelper(tvs.at(i), axes);
    }
  }

  if (auto wop = dynamic_cast<WelfordOp*>(definition())) {
    TensorView* producer_avg = rf_tvs.at(0);
    TensorView* producer_var = rf_tvs.at(1);
    TensorView* producer_n = rf_tvs.at(2);

    // Setup dependency chain, inserting producer before this op.
    // Expr* producer_definition =
    IrBuilder::create<WelfordOp>(
        producer_avg,
        producer_var,
        producer_n,
        wop->inAvg(),
        wop->inVar(),
        wop->inN(),
        wop->initAvg(),
        wop->initVar(),
        wop->initN());

    // Expr* consumer_definition =
    IrBuilder::create<WelfordOp>(
        wop->outAvg(),
        wop->outVar(),
        wop->outN(),
        producer_avg,
        producer_var,
        producer_n,
        wop->initAvg(),
        wop->initVar(),
        wop->initN());
  } else if (
      auto grouped_rop = dynamic_cast<GroupedReductionOp*>(definition())) {
    IrBuilder::create<GroupedReductionOp>(
        grouped_rop->getReductionOpTypes(),
        grouped_rop->initVals(),
        std::vector<Val*>{rf_tvs.begin(), rf_tvs.end()},
        grouped_rop->inputs());

    IrBuilder::create<GroupedReductionOp>(
        grouped_rop->getReductionOpTypes(),
        grouped_rop->initVals(),
        grouped_rop->outputs(),
        std::vector<Val*>{rf_tvs.begin(), rf_tvs.end()});
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid definition: ", definition()->toString());
  }

  return rf_tvs;
}

TensorView* TensorView::cacheBefore(c10::optional<LoadStoreOpType> cache_op) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  FusionGuard fg(fusion());

  TORCH_CHECK(
      definition() != nullptr && !isFusionInput(),
      "Error adding cacheBefore ",
      this,
      "; its definition is a nullptr, or it is a fusion input (cacheBefore is not allowed on inputs).");

  // Previously, caching computed-at tensors was allowed but was never
  // really robust. Make it an error unless it is really needed.
  TORCH_CHECK(
      !hasComputeAt(),
      "Caching computed-at tensors is not allowed. Apply caching before computeAt.");

  // It also did additional transformation when a producer tensor has
  // computeAt. Make sure we no longer rely on that behavior.
  if (definition() != nullptr) {
    for (TensorView* producer_of_producer :
         ir_utils::filterByType<TensorView>(definition()->inputs())) {
      TORCH_CHECK(
          !producer_of_producer->hasComputeAt(),
          "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
    }
  }

  // Create Producer Domain
  // This domain will be the consumer which needs a new domain, so replace the
  // producer's domain with this domain.

  TensorView* producer = IrBuilder::create<TensorView>(
      container(),
      IrBuilder::create<TensorDomain>(
          container(),
          domain()->getRootDomain(),
          domain()->getRFactorDomain(),
          domain()->domain(),
          domain()->contiguity()),
      getDataType().value());

  // Set domain of consumer
  TensorView* consumer = this;

  size_t i = 0;
  auto no_reduction_root_domain =
      TensorDomain::noReductions(getMaybeRFactorDomain());
  std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
  for (const auto& dom : no_reduction_root_domain) {
    new_root_domain[i++] = dom->cloneWithoutRFactor();
  }

  consumer->setDomain(IrBuilder::create<TensorDomain>(
      container(),
      new_root_domain,
      std::vector<bool>(new_root_domain.size(), true)));

  // Insert producer - Cache_Before (CB) - before this TV.
  // Before: Prev TV -> [Definition Op] -> This TV
  // After:  Prev TV -> [Definition Op] -> New CB TV -> [Set Op] -> This TV

  // Get inputs for origin expression
  auto expr_inputs = definition()->inputs();
  // Expr* producer_definition =
  ir_utils::replaceValInExpr(definition(), this, producer);

  // Expr* producer_uses =
  if (cache_op.has_value()) {
    IrBuilder::create<LoadStoreOp>(
        container(), cache_op.value(), consumer, producer);
  } else {
    IrBuilder::create<UnaryOp>(
        container(), UnaryOpType::Set, consumer, producer);
  }

  // definition_ is no longer valid
  // setDefinition(nullptr);

  auto replayed_consumer_pair =
      TransformReplay::replayCasP(consumer, producer, -1);
  consumer->setDomain(replayed_consumer_pair.first);

  return producer;
}
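
// Illustrative usage sketch (hypothetical tensors, not from the original
// source): if tv_out is a fusion output defined by a reduction,
//   auto tv_cache = tv_out->cacheBefore();
// re-routes the reduction to write into tv_cache, and tv_out becomes a copy
// of tv_cache; tv_cache can then be kept in registers or shared memory while
// tv_out stays in global memory.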

TensorView* TensorView::cacheFork() {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  FusionGuard fg(fusion());

  // Before: [Expr] -> This TV (Global Output) -> [Usage Expr]
  // After:  [Expr] -> This TV (Local) -> [Usage Expr] -> Next TV
  //                             (Fork) -> [Set Expr] -> New TV (Global Output)

  TORCH_CHECK(
      this->isFusionOutput() && !this->uses().empty(),
      "Error adding cacheFork ",
      this,
      "; this TensorView must be an output with subsequent uses.");

  // Previously, caching computed-at tensors was allowed but was never
  // really robust. Make it an error unless it is really needed.
  TORCH_CHECK(
      !hasComputeAt(),
      "Caching computed-at tensors is not allowed. Apply caching before computeAt.");

  // This domain will be the producer, so create the consumer
  auto root_domain = TensorDomain::noReductions(getMaybeRFactorDomain());
  TensorView* new_output = IrBuilder::create<TensorView>(
      container(),
      IrBuilder::create<TensorDomain>(
          container(),
          IterDomain::clone(root_domain),
          std::vector<bool>(root_domain.size(), true)),
      getDataType().value());

  // Create the write operation from this TV to the new output
  IrBuilder::create<UnaryOp>(container(), UnaryOpType::Set, new_output, this);

  // The new TV becomes an output.
  // The new TV has global memory type.
  // This TV has local memory type.
  fusion()->replaceOutput(this, new_output);

  // Transform the new output according to this TV
  auto replayed_output_pair = TransformReplay::replayCasP(new_output, this, -1);
  new_output->setDomain(replayed_output_pair.first);

  return new_output;
}

TensorView* TensorView::cacheAfter(c10::optional<LoadStoreOpType> cache_op) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  FusionGuard fg(fusion());

  // Get all the uses for this TensorView
  TORCH_CHECK(
      !isFusionOutput(),
      "Error adding cacheAfter ",
      this,
      "; cacheAfter is not allowed on fusion outputs.");

  // Previously, caching computed-at tensors was allowed but was never
  // really robust. Make it an error unless it is really needed.
  TORCH_CHECK(
      !hasComputeAt(),
      "Caching computed-at tensors is not allowed. Apply caching before computeAt.");

  // It also did additional transformation when this tensor is an
  // input and the outputs of its consumers have computeAt. Make sure
  // we no longer rely on that behavior.
  if (isFusionInput()) {
    for (const auto& expr : uses()) {
      for (TensorView* output :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        TORCH_CHECK(
            !output->hasComputeAt(),
            "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
      }
    }
  }

  // Create Consumer Domain
  // Keep Broadcast Axis (Permanent)
  // Remove Reduction Axis
  size_t i = 0;
  auto no_reduction_root_domain =
      TensorDomain::noReductions(getMaybeRFactorDomain());
  std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
  for (const auto& dom : no_reduction_root_domain) {
    new_root_domain[i++] = dom->cloneWithoutRFactor();
  }

  // This domain will be the producer, so create the consumer
  TensorView* consumer = IrBuilder::create<TensorView>(
      container(),
      IrBuilder::create<TensorDomain>(
          container(),
          new_root_domain,
          std::vector<bool>(new_root_domain.size(), true)),
      getDataType().value());

  // Set domain of producer - No Change
  TensorView* producer = this;

  // Insert consumer - Cache_After (CA) - after this TV.
  // Before: This TV -> [Use Op] -> Next TV
  // After:  This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV

  // Expr* consumer_uses =
  for (auto expr : fusion()->unordered_uses(this)) {
    ir_utils::replaceValInExpr(expr, this, consumer);
  }

  // Expr* consumer_definition =
  if (cache_op.has_value()) {
    IrBuilder::create<LoadStoreOp>(
        container(), cache_op.value(), consumer, producer);
  } else {
    IrBuilder::create<UnaryOp>(
        container(), UnaryOpType::Set, consumer, producer);
  }

  return consumer;
}
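
// Illustrative usage sketch (hypothetical tensors, not from the original
// source): a common pattern is staging a fusion input through shared memory,
//   auto tv_smem = tv_in->cacheAfter();
//   tv_smem->setMemoryType(MemoryType::Shared);
// All previous uses of tv_in now read from tv_smem instead.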

void TensorView::setMemoryType(MemoryType mt) {
  memory_type_ = mt;
  if (isFusionInput() || isFusionOutput()) {
    TORCH_INTERNAL_ASSERT(
        mt == MemoryType::Global,
        "Tried to set an input or output to the fusion to a non-global memory type.");
  }
}

void TensorView::clearReductionIterDomains() {
  TORCH_INTERNAL_ASSERT(
      !domain()->hasRFactor(),
      "should not call clearReductionIterDomains on rfactor tv");

  TORCH_INTERNAL_ASSERT(
      domain()->domain() == getRootDomain(),
      "should not call clearReductionIterDomains on already transformed TensorDomains");

  std::vector<IterDomain*> new_root;
  std::vector<bool> new_contig;
  for (const auto i : c10::irange(getRootDomain().size())) {
    if (!getRootDomain()[i]->isReduction()) {
      new_root.push_back(getRootDomain()[i]);
      new_contig.push_back(domain()->contiguity()[i]);
    }
  }

  setDomain(IrBuilder::create<TensorDomain>(container(), new_root, new_contig));
}

void TensorView::doubleBuffer() {
  // Early correctness checking. May miss eventual errors as the
  // checks depend on memory types and parallelization, which may not
  // be finalized until lowering.
  validateDoubleBufferedTensor(this);
  is_double_buffered_ = true;
}

void TensorView::circularBuffer(unsigned int stage) {
  // Early correctness checking. May miss eventual errors as the
  // checks depend on memory types and parallelization, which may not
  // be finalized until lowering.
  TORCH_INTERNAL_ASSERT(stage > 1, "Unsupported stage number");
  if (stage == 2) {
    // Re-direct to double buffer interface if stage is 2;
    doubleBuffer();
    return;
  }
  validateDoubleBufferedTensor(this);
  is_circular_buffered_ = true;
  circular_buffer_stage_ = stage;
}

bool TensorView::isEmptyTensor() const {
  auto& root_domain = getMaybeRFactorDomain();
  return std::all_of(
      root_domain.begin(), root_domain.end(), [](IterDomain* id) {
        return id->extent()->isZeroInt();
      });
}

void TensorView::applyMmaSwizzle(MmaOptions options) {
  switch (options.operand) {
    case MmaOptions::Operand::Accumulator:
      mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options);
      break;
    case MmaOptions::Operand::A:
    case MmaOptions::Operand::B:
      mma_util::WarpMmaSwizzler::scheduleOperandRead(this, options);
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "unknown operand flag");
      break;
  }
}

TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) {
  TORCH_CHECK(shape_.empty() || shape_.size() == ndims);
  TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims);
  ndims_ = ndims;
  return *this;
}

TensorViewBuilder& TensorViewBuilder::dtype(DataType dtype) {
  dtype_ = dtype;
  return *this;
}

TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) {
  TORCH_CHECK(contiguity_.empty(), "Attempting to reset contiguity");
  if (!contiguity.empty()) {
    TORCH_CHECK(ndims_ == 0 || ndims_ == contiguity.size());
    ndims_ = contiguity.size();
  }
  contiguity_ = std::move(contiguity);
  return *this;
}

TensorViewBuilder& TensorViewBuilder::shape(const std::vector<int64_t>& shape) {
  TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
  if (!shape.empty()) {
    TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
    ndims_ = shape.size();
  }
  shape_.clear();
  shape_.reserve(shape.size());
  for (int64_t i : shape) {
    if (i == -1) {
      shape_.emplace_back(IrBuilder::create<Int>());
    } else {
      TORCH_CHECK(
          i >= 0,
          "Invalid extent value. ",
          "For a tensor representing a single scalar use ndims = 0 with no sizes set.");
      shape_.emplace_back(IrBuilder::create<Int>(i));
    }
  }
  return *this;
}

TensorViewBuilder& TensorViewBuilder::shape(std::vector<Val*> shape) {
  TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
  if (!shape.empty()) {
    TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
    ndims_ = shape.size();
  }
  shape_ = std::move(shape);
  return *this;
}

TensorView* TensorViewBuilder::build() const {
  // Build the domain
  std::vector<IterDomain*> domain(ndims_, nullptr);
  for (const auto i : c10::irange(ndims_)) {
    if (shape_.empty()) {
      domain[i] =
          IterDomainBuilder(
              FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create<Int>())
              .build();
    } else {
      if (shape_[i]->isOneInt()) {
        // If size is known to be 1, assume it needs to be broadcasted.
        domain[i] = IterDomainBuilder(
                        FusionGuard::getCurFusion()->zeroVal(),
                        FusionGuard::getCurFusion()->oneVal())
                        .iter_type(IterType::Broadcast)
                        .build();
      } else {
        domain[i] =
            IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), shape_[i])
                .build();
      }
    }
  }

  // Create the final TensorView
  return IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(domain, contiguity_), dtype_);
}
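
// Illustrative usage sketch (hypothetical values, not from the original
// source; requires an active Fusion guarded by FusionGuard):
//   TensorView* tv = TensorViewBuilder()
//                        .ndims(2)
//                        .dtype(DataType::Float)
//                        .contiguity({true, true})
//                        .shape({-1, 128})  // -1 creates a symbolic extent
//                        .build();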

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch