1 | #include <c10/util/irange.h> |
2 | #include <arith.h> |
3 | #include <compute_at.h> |
4 | #include <expr_evaluator.h> |
5 | #include <fusion.h> |
6 | #include <inlining.h> |
7 | #include <ir_all_nodes.h> |
8 | #include <ir_builder.h> |
9 | #include <ir_cloner.h> |
10 | #include <ir_interface_nodes.h> |
11 | #include <ir_iostream.h> |
12 | #include <ir_utils.h> |
13 | #include <lower2device.h> |
14 | #include <lower_double_buffer.h> |
15 | #include <scheduler/mma_utils.h> |
16 | |
17 | // Cleanup |
18 | #include <transform_iter.h> |
19 | #include <transform_replay.h> |
20 | |
21 | namespace torch { |
22 | namespace jit { |
23 | namespace fuser { |
24 | namespace cuda { |
25 | |
26 | namespace { |
27 | DataType aten_opt_type_map(const c10::optional<at::ScalarType>& scalar_type) { |
28 | return scalar_type.has_value() ? aten_to_data_type(scalar_type.value()) |
29 | : DataType::Null; |
30 | } |
31 | } // namespace |
32 | |
// Basic constructor: wraps an already-built TensorDomain with the given
// data type and memory type. No extra validation is performed here beyond
// what the Val base class does.
TensorView::TensorView(
    IrBuilderPasskey passkey,
    TensorDomain* domain,
    DataType dtype,
    MemoryType mtype)
    : Val(passkey, ValType::TensorView, dtype),
      domain_(domain),
      memory_type_(mtype) {}
41 | |
// Construct a TensorView from a profiled ATen TensorType. Requires a
// static rank. Sizes statically known to be 1 become Broadcast
// IterDomains; every other size gets a fresh symbolic extent. Contiguity
// flags are inferred from the profiled stride properties (see the note
// below).
TensorView::TensorView(
    IrBuilderPasskey passkey,
    const std::shared_ptr<c10::TensorType>& tensor_type)
    : Val(passkey,
          ValType::TensorView,
          aten_opt_type_map(tensor_type->scalarType())) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  std::vector<IterDomain*> sizes;

  TORCH_CHECK(
      tensor_type->dim().has_value(), "Requires static rank for Tensor");

  for (const auto i : c10::irange(tensor_type->dim().value())) {
    if (tensor_type->sizes()[i].has_value() &&
        tensor_type->sizes()[i].value() == 1) {
      // If size is known to be 1, assume it needs to be broadcasted.
      sizes.push_back(
          IterDomainBuilder(
              passkey.ir_container_->zeroVal(), passkey.ir_container_->oneVal())
              .iter_type(IterType::Broadcast)
              .build());
    } else {
      // Size is unknown (or not 1): create a fresh symbolic extent.
      sizes.push_back(
          IterDomainBuilder(
              passkey.ir_container_->zeroVal(), IrBuilder::create<Int>())
              .build());
    }
  }
  // [ Note -- stride_properties in tensor type ]
  //
  // `stride_properties()` returns a vector<optional<Stride>>, while
  //   Stride {
  //     optional<size_t> stride_index_;
  //     optional<bool> contiguous_;
  //     optional<size_t> stride_;
  //   };
  // To keep things simple, we ignore all the optional wrapper, as in reality,
  // they would always be available unless we start doing multiple profiling
  // runs.
  //
  // `stride_properties()` returns the vector of Stride, where it is ordered
  // from the fastest to slowest dimensions. i.e. stride_properties()[i] would
  // give us the i-th fastest dimension. where:
  //   1. `Stride::stride_index_` gives the index to the dimension;
  //   2. `Stride::contiguous_` indicates whether this dimension is
  //      memory-dense*;
  //   3. `Stride::stride_` is the actual stride for the given dimension.
  // * note that memory-dense means different things depending on the order of
  // the dimension. checkout `TensorType::computeStrideProps` for details

  // default to non_contiguous;
  std::vector<bool> contig_info(tensor_type->dim().value(), false);

  // we iterate through stride_index_, which goes from fastest changing
  // dimension to slowest, instead of iterating through sizes. This allows
  // easier contiguity check;
  for (const auto i : c10::irange(tensor_type->dim().value())) {
    // if we don't have contiguous dimension at current stride index, don't
    // bother;
    const auto& stride_property_i = tensor_type->stride_properties()[i];
    if (stride_property_i.has_value() &&
        stride_property_i->stride_index_.has_value() &&
        stride_property_i->contiguous_.has_value() &&
        stride_property_i->contiguous_.value() == true) {
      const size_t index = stride_property_i->stride_index_.value();
      if (i == 0) {
        // mark fastest changing dimension collapsible only when it's the last
        // dim;
        contig_info[index] = (index == tensor_type->dim().value() - 1);
      } else {
        // check the neighboring faster dimension, collapse if it is considered
        // as inner dimension per stride_index
        auto inner_index_opt =
            tensor_type->stride_properties()[static_cast<int>(i) - 1]
                ->stride_index_;
        if (inner_index_opt.has_value() &&
            inner_index_opt.value() == (index + 1)) {
          // collapse if inner dimension has non-broadcasted strides
          auto inner_stride_opt =
              tensor_type->stride_properties()[static_cast<int>(i) - 1]
                  ->stride_;
          contig_info[index] =
              inner_stride_opt.has_value() && inner_stride_opt.value() != 0;
        }
      }
    }
  }

  domain_ = IrBuilder::create<TensorDomain>(sizes, contig_info);
}
134 | |
// Construct a TensorView from a JIT Value by casting its type to
// c10::TensorType and delegating to the TensorType constructor above.
TensorView::TensorView(
    IrBuilderPasskey passkey,
    const std::shared_ptr<Value>& jit_value)
    : TensorView(passkey, jit_value->type()->cast<c10::TensorType>()) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
}
143 | |
// Replace this TensorView's rfactor domain with a fresh root domain.
// RFactor IterDomains receive new extents (symbolic unless every extent
// in the rfactor domain is a constant), and the old extents are replaced
// by the new ones throughout the fusion so producers/consumers remain
// consistent.
void TensorView::convertRfactorToRootDomain() {
  // For a given TensorView, does its domain (root / rfactor) contain any
  // concrete sized extents?
  auto is_concrete_tensor = [](TensorView* tv) {
    for (auto id : tv->getMaybeRFactorDomain()) {
      if (!id->extent()->isConstScalar()) {
        return false;
      }
    }
    return true;
  };

  // Create a new root domain and replacement TensorDomain.
  // Given an rfactor domain, create a new IterDomain.
  // Otherwise, clone the previous IterDomain
  auto createReplacementDomain =
      [this](const std::vector<Val*>& replacement_extents) {
        TORCH_INTERNAL_ASSERT(
            !replacement_extents.empty() &&
            getMaybeRFactorDomain().size() == replacement_extents.size());
        size_t idx = 0;
        std::vector<IterDomain*> new_root_domain(
            getMaybeRFactorDomain().size());
        for (const auto& id : getMaybeRFactorDomain()) {
          if (replacement_extents[idx] != nullptr) {
            // Rebuild the axis with the replacement extent, resetting any
            // scheduling-related parameters.
            new_root_domain[idx] = IterDomainBuilder(id)
                                       .extent(replacement_extents[idx])
                                       .resetSchedulingParams()
                                       .build();
            ++idx;
          } else {
            // Non-rfactor axes are cloned as-is (minus rfactor flags).
            TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
            new_root_domain[idx++] = id->cloneWithoutRFactor();
          }
        }

        TORCH_INTERNAL_ASSERT(
            new_root_domain.size() == domain()->contiguity().size());
        setDomain(IrBuilder::create<TensorDomain>(
            container(), new_root_domain, domain()->contiguity()));
      };

  std::vector<Val*> rfactor_extents;
  std::unordered_map<Val*, Val*> replacement_map;
  const auto kThisIsConcreteTensor = is_concrete_tensor(this);
  for (const auto& id : getMaybeRFactorDomain()) {
    if (id->isRFactorProduct()) {
      // Create new symbolic extents for rfactor iterDomains
      auto domain_extent = (!kThisIsConcreteTensor)
          ? IrBuilder::create<Int>(container())
          : id->extent();
      rfactor_extents.push_back(domain_extent);
      replacement_map.emplace(id->extent(), domain_extent);
    } else {
      // nullptr marks "clone the original axis" for this position.
      rfactor_extents.push_back(nullptr);
    }
  }
  createReplacementDomain(rfactor_extents);

  // Propagate new extent throughout fusion using ValReplacementMutator
  ir_utils::replaceValue(fusion(), replacement_map);
}
206 | |
// Deep-copy constructor used by IrCloner: clones the domain and the
// swizzle axes into the cloner's target container and copies all
// scheduling-related scalar state verbatim.
TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner),
      domain_(ir_cloner->clone(src->domain_)),
      compute_at_pos_(src->compute_at_pos_),
      max_producer_pos_(src->max_producer_pos_),
      memory_type_(src->memory_type_),
      swizzle_type_(src->swizzle_type_),
      is_double_buffered_(src->is_double_buffered_),
      is_circular_buffered_(src->is_circular_buffered_),
      circular_buffer_stage_(src->circular_buffer_stage_),
      cpu_scalar_(src->cpu_scalar_),
      has_swizzle_op_(src->has_swizzle_op_) {
  // Swizzle axes are IterDomains and must be cloned, not aliased.
  for (const auto id : src->axesToSwizzle()) {
    axes_to_swizzle_.push_back(ir_cloner->clone(id));
  }
}
223 | |
224 | bool TensorView::hasAnyReduction() const { |
225 | return domain()->noReductions().size() != domain()->domain().size(); |
226 | } |
227 | |
228 | bool TensorView::hasReduction() const { |
229 | return domain()->hasReduction(); |
230 | } |
231 | |
232 | bool TensorView::hasBlockReduction() const { |
233 | return domain()->hasBlockReduction(); |
234 | } |
235 | |
236 | bool TensorView::hasGridReduction() const { |
237 | return domain()->hasGridReduction(); |
238 | } |
239 | |
240 | bool TensorView::hasBroadcast() const { |
241 | return domain()->hasBroadcast(); |
242 | } |
243 | |
244 | bool TensorView::hasRFactor() const { |
245 | return domain()->hasRFactor(); |
246 | } |
247 | |
248 | c10::optional<unsigned int> TensorView::getReductionAxis() const { |
249 | return domain()->getReductionAxis(); |
250 | } |
251 | |
252 | const std::vector<IterDomain*>& TensorView::getRootDomain() const { |
253 | return domain()->getRootDomain(); |
254 | }; |
255 | |
256 | const std::vector<IterDomain*>& TensorView::getRFactorDomain() const { |
257 | return domain()->getRFactorDomain(); |
258 | }; |
259 | |
260 | const std::vector<IterDomain*>& TensorView::getMaybeRFactorDomain() const { |
261 | return domain()->getMaybeRFactorDomain(); |
262 | }; |
263 | |
264 | std::vector<IterDomain*>::size_type TensorView::nDims() const { |
265 | return domain()->nDims(); |
266 | } |
267 | |
// sets cpu_scalar_ value, which is special handling for CPU based zero-dim
// tensors (i.e. CPU Tensors that only have one value). This is only used if
// on an input value, otherwise ignored. This is important as special handling
// because these "scalars" should be type promoted as a tensor, but we want to
// avoid explicit copying of the data, so we want to pass the data value as a
// standard kernel argument value.
//
// Asserts that the tensor is 0-dimensional; the flag is meaningless (and
// rejected) otherwise.
void TensorView::setCpuScalar(bool is_cpu_scalar) {
  TORCH_INTERNAL_ASSERT(
      nDims() == 0, "Only 0-dim tensors can be marked as a cpu scalar.");
  cpu_scalar_ = is_cpu_scalar;
}
279 | |
280 | IterDomain* TensorView::axis(int pos) const { |
281 | TORCH_INTERNAL_ASSERT( |
282 | nDims() > 0, "Tried to access an axis in a 0-dim TensorView" ); |
283 | if (pos < 0) |
284 | pos += domain()->nDims(); |
285 | TORCH_CHECK( |
286 | pos >= 0 && (unsigned int)pos < domain()->nDims(), |
287 | "Tried to access position " , |
288 | pos, |
289 | " in domain: " , |
290 | domain()); |
291 | return domain()->axis(pos); |
292 | } |
293 | |
// Set this tensor's compute-at position to `pos`. Negative positions
// count from the end, with nDims() meaning fully inlined. When
// `best_effort` is true the requested position is clamped to the maximum
// legal inlining position instead of asserting. `calc` lets callers reuse
// a MaxPosCalculator across many inlineAt calls; if null, a temporary one
// is created locally.
void TensorView::inlineAt(
    int64_t pos,
    bool best_effort,
    MaxPosCalculator* calc) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");

  // Create a calculator on demand when the caller did not provide one.
  std::unique_ptr<MaxPosCalculator> calc_owner;
  if (calc == nullptr) {
    calc_owner = std::make_unique<MaxPosCalculator>();
    calc = calc_owner.get();
  }

  // Normalize negative positions; the +1 makes -1 map to nDims()
  // (i.e. fully inlined).
  if (pos < 0) {
    pos += int64_t(nDims()) + 1;
  }

  TORCH_INTERNAL_ASSERT(
      pos >= 0 && pos <= (int64_t)nDims(),
      "Invalid inline position for T",
      name(),
      ": ",
      pos);

  auto max_inline_pos = calc->getMaxPosAll(this, best_effort);

  if (best_effort) {
    pos = std::min<int64_t>(max_inline_pos, pos);
  }

  // hoist inner most broadcast
  while (pos > 0 && axis(pos - 1)->isBroadcast()) {
    pos--;
  }

  TORCH_INTERNAL_ASSERT(
      pos <= (int64_t)max_inline_pos,
      "Invalid inline position for T",
      name(),
      ": ",
      pos,
      ". Maximum allowed value:",
      max_inline_pos);

  // Fusion inputs are never computed, so inlining is a no-op for them.
  if (isFusionInput()) {
    return;
  }

  // Only ever grow the compute-at position, and keep all consumers'
  // max-producer positions in sync with the new value.
  if (pos > compute_at_pos_) {
    compute_at_pos_ = pos;
    for (auto consumer : ir_utils::consumerTvsOf(this)) {
      consumer->updateMaxProducerPosition();
    }
  }
}
350 | |
namespace {

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain. No checking on actual
// producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
    TensorView* consumer,
    TensorView* producer) {
  // Locate consumer's position that aligns with
  // the producer's new compute at axis. We need broadcast axes forwarded so we
  // need to replay PasC as CasP will not forward broadcast dims. For example
  // if we have:
  // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
  // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
  // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
  // NVFuserTest.FusionComplexBCast1_CUDA

  auto disjoint_sets =
      BestEffortReplay::replayPasC(
          producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
          .getDisjointSets();

  // Find the innermost position of consumer that has
  // been mapped within the producer ca axis.
  unsigned int consumer_pos = consumer->nDims();
  while (consumer_pos > 0) {
    auto consumer_id = consumer->axis((int)consumer_pos - 1);
    auto p_dom = producer->domain()->domain();
    // Stop at the first (innermost) consumer axis that maps to any
    // producer axis inside the producer's compute-at region.
    if (std::any_of(
            p_dom.begin(),
            p_dom.begin() + producer->getComputeAtPosition(),
            [&consumer_id, &disjoint_sets](IterDomain* p_id) {
              return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
            })) {
      break;
    }
    consumer_pos--;
  }

  return consumer_pos;
}

} // namespace
394 | |
395 | void TensorView::updateMaxProducerPosition() { |
396 | TORCH_INTERNAL_ASSERT( |
397 | !container()->isA<kir::Kernel>(), |
398 | "Function invalid for kernel container." ); |
399 | for (auto producer : ir_utils::producerTvsOf(this)) { |
400 | max_producer_pos_ = std::max( |
401 | max_producer_pos_, getConsumerPosAlignedToProducerCA(this, producer)); |
402 | } |
403 | } |
404 | |
// Compute this tensor at `position` within `consumer`'s loop nest.
// Negative positions count from the end (with consumer->nDims() meaning
// fully inlined). In BestEffort mode the position is clamped into range
// instead of asserting. Returns `this` for chaining.
TensorView* TensorView::computeAt(
    TensorView* consumer,
    int position,
    ComputeAtMode mode) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Make sure this and consumer are not the same tensor, that's illegal
  TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");

  // We support negative axes, so increment it by consumer->nDims() + 1 and make
  // sure the result is within consumer->nDims() + 1. being at consumer->nDims()
  // means producer will be computed inline with consumer, hence the +1.
  if (position < 0)
    position += int(consumer->nDims()) + 1;

  TORCH_CHECK(
      (position >= 0 && (unsigned int)position < consumer->nDims() + 1) ||
          mode == ComputeAtMode::BestEffort,
      "Compute at called on an position outside valid range.");

  // In best-effort mode clamp rather than reject out-of-range positions.
  if (mode == ComputeAtMode::BestEffort) {
    position = std::max(-1, position);
    position = std::min((int)consumer->nDims(), position);
  }

  ComputeAt::runAt(this, consumer, (unsigned int)position, mode);

  return this;
}
435 | |
436 | TensorView* TensorView::computeWith( |
437 | TensorView* consumer, |
438 | int position, |
439 | ComputeAtMode mode) { |
440 | TORCH_INTERNAL_ASSERT( |
441 | !container()->isA<kir::Kernel>(), |
442 | "Function invalid for kernel container." ); |
443 | // Make sure this and consumer are not the same tensor, that's illegal |
444 | TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)" ); |
445 | |
446 | // We support negative axes, so increment it by this->nDims() + 1 and make |
447 | // sure the result is within this->nDims() + 1. being at this->nDims() |
448 | // means producer will be computed inline with this, hence the +1. |
449 | if (position < 0) |
450 | position += int(this->nDims()) + 1; |
451 | TORCH_CHECK( |
452 | position >= 0 && (unsigned int)position < this->nDims() + 1, |
453 | "Compute at called on an position outside valid range." ); |
454 | |
455 | ComputeAt::runWith(this, consumer, (unsigned int)position, mode); |
456 | |
457 | return this; |
458 | } |
459 | |
// Split axis `axis_` into two axes by `factor`. `inner_split` selects
// whether the factor applies to the inner or outer output axis;
// `trim_out_of_bounds` drops the out-of-range remainder. The axis must be
// outside the compute-at and max-producer positions and must be Serial.
// Returns `this` for chaining.
TensorView* TensorView::split(
    int axis_,
    Val* factor,
    bool inner_split,
    bool trim_out_of_bounds) {
  // Only check things associated with axis, factor will be validated in
  // IterDomain
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");

  // Normalize negative axis indices.
  if (axis_ < 0)
    axis_ += domain()->nDims();

  TORCH_INTERNAL_ASSERT(
      axis_ >= 0,
      "Split axis is less than 0 even after adjusting for nDims: ",
      axis_);

  TORCH_CHECK(
      axis_ >= (int)getComputeAtPosition(),
      "Cannot split axis within compute at position. Axis = ",
      axis_,
      " computeAtPosition = ",
      getComputeAtPosition());

  TORCH_CHECK(
      axis_ >= (int)getMaxProducerPosition(),
      "Cannot split axis within max producer position. Axis = ",
      axis_,
      " maxProducerPosition = ",
      getMaxProducerPosition());

  TORCH_CHECK(
      axis(axis_)->getParallelType() == ParallelType::Serial,
      "Splitting an axis of non-Serial parallel type is not supported at this time."
      " Parallelization strategy must be set after calling split.");

  domain()->split(axis_, factor, inner_split, trim_out_of_bounds);
  return this;
}
499 | |
500 | TensorView* TensorView::split( |
501 | int axis, |
502 | unsigned int factor, |
503 | bool inner_split, |
504 | bool trim_out_of_bounds) { |
505 | split(axis, IrBuilder::create<Int>(factor), inner_split, trim_out_of_bounds); |
506 | return this; |
507 | } |
508 | |
509 | // Merge "axis_o" and "axis_i" into 1 dimension |
510 | TensorView* TensorView::merge(int axis_o, int axis_i) { |
511 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView" ); |
512 | |
513 | if (axis_o < 0) |
514 | axis_o += domain()->nDims(); |
515 | |
516 | if (axis_i < 0) |
517 | axis_i += domain()->nDims(); |
518 | |
519 | TORCH_CHECK( |
520 | axis_o >= (int)getComputeAtPosition() && |
521 | axis_i >= (int)getComputeAtPosition(), |
522 | false, |
523 | "Cannot merge axes within compute at position. Either axis " , |
524 | axis_o, |
525 | " or " , |
526 | axis_i, |
527 | " are within computeAtPosition = " , |
528 | getComputeAtPosition()); |
529 | |
530 | TORCH_CHECK( |
531 | axis_o >= (int)getMaxProducerPosition() && |
532 | axis_i >= (int)getMaxProducerPosition(), |
533 | "Cannot merge axes within max producer position. Either axis " , |
534 | axis_o, |
535 | " or " , |
536 | axis_i, |
537 | " are within maxProducerPosition = " , |
538 | getMaxProducerPosition()); |
539 | |
540 | TORCH_CHECK( |
541 | axis(axis_o)->getParallelType() == ParallelType::Serial || |
542 | axis(axis_i)->getParallelType() == ParallelType::Serial, |
543 | "Merging axes of non-Serial parallel type is not supported at this time." |
544 | " Parallelization strategy must be set after calling split." ); |
545 | |
546 | domain()->merge(axis_o, axis_i); |
547 | return this; |
548 | } |
549 | |
// Permute the leaf domain according to the old-position -> new-position
// map. Negative positions count from the end. Every moved axis must be
// outside the compute-at and max-producer positions. Returns `this` for
// chaining.
TensorView* TensorView::reorder(const std::unordered_map<int, int>& old2new_) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  TORCH_INTERNAL_ASSERT(
      !(nDims() == 0 && old2new_.size() > 0),
      "Tried to reorder a 0-dim TensorView");

  // Validate each mapping entry before delegating to TensorDomain.
  for (auto entry : old2new_) {
    auto old_pos = entry.first < 0 ? entry.first + (int)nDims() : entry.first;
    auto new_pos =
        entry.second < 0 ? entry.second + (int)nDims() : entry.second;
    if (old_pos == new_pos) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        old_pos >= 0,
        "Found \"old\" position that's less than 0 even though already adjusted by nDims: ",
        old_pos);
    TORCH_INTERNAL_ASSERT(
        new_pos >= 0,
        "Found \"new\" position that's less than 0 even though already adjusted by nDims: ",
        new_pos);
    TORCH_CHECK(
        old_pos >= (int)getComputeAtPosition() &&
            new_pos >= (int)getComputeAtPosition(),
        "Cannot reorder axes within compute at position. Either axis ",
        old_pos,
        " or ",
        new_pos,
        " are within computeAtPosition = ",
        getComputeAtPosition());

    TORCH_CHECK(
        old_pos >= (int)getMaxProducerPosition() &&
            new_pos >= (int)getMaxProducerPosition(),
        "Cannot reorder axes within max producer position. Either axis ",
        old_pos,
        " or ",
        new_pos,
        " are within maxProducerPosition = ",
        getMaxProducerPosition());
  }

  domain()->reorder(old2new_);
  return this;
}
597 | |
// Record a (legacy) swizzle of the given type over `axes`. For
// SwizzleType::Transpose: exactly two distinct axes are required, the
// tensor must live in shared memory, and the axes must be outside the
// compute-at position and be neither reduction nor broadcast. Any
// previously recorded swizzle axes are discarded. Returns `this`.
TensorView* TensorView::swizzle(
    SwizzleType type,
    const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  swizzle_type_ = type;

  // Clear previously set swizzle axes if any
  if (axes_to_swizzle_.size()) {
    axes_to_swizzle_.clear();
  }

  if (swizzle_type_ == SwizzleType::Transpose) {
    TORCH_CHECK(
        axes.size() == 2,
        "Invalid axis list: ",
        axes,
        ". Number of axes must be two.");
    TORCH_CHECK(
        axes[0] != axes[1],
        "Invalid axis list: ",
        axes,
        ". Two distinctive axes must be given.");
    TORCH_CHECK(
        getMemoryType() == MemoryType::Shared,
        "Transpose swizzle is meant for tensors on shared memory.");
    for (auto pos : axes) {
      // Normalize negative positions before validation.
      if (pos < 0) {
        pos += nDims();
      }
      TORCH_CHECK(pos >= 0 && pos < (int)nDims(), "Invalid axis: ", pos);
      TORCH_CHECK(
          pos >= (int)getComputeAtPosition(),
          "Invalid axis: ",
          pos,
          ". Axis outside computeAt position is not allocated.");
      TORCH_CHECK(
          !axis(pos)->isReduction(),
          "Invalid axis: ",
          pos,
          ". Swizzling a reduction axis is not supported");
      TORCH_CHECK(
          !axis(pos)->isBroadcast(),
          "Invalid axis: ",
          pos,
          ". Swizzling a broadcast axis is not supported");
      axes_to_swizzle_.push_back(axis(pos));
    }
  }

  return this;
}
651 | |
652 | TensorView* TensorView::swizzle( |
653 | Swizzle2DType swizzle_type, |
654 | int x, |
655 | int y, |
656 | SwizzleMode swizzle_mode) { |
657 | has_swizzle_op_ = true; |
658 | if (x < 0) { |
659 | x += domain()->nDims(); |
660 | } |
661 | if (y < 0) { |
662 | y += domain()->nDims(); |
663 | } |
664 | |
665 | TORCH_CHECK( |
666 | x >= (int)getComputeAtPosition(), |
667 | false, |
668 | "Cannot swizzle axes within compute at position. Axis " , |
669 | x, |
670 | " is within computeAtPosition = " , |
671 | getComputeAtPosition()); |
672 | |
673 | TORCH_CHECK( |
674 | y >= (int)getMaxProducerPosition(), |
675 | "Cannot swizzle axes within max producer position. Axis " , |
676 | y, |
677 | " is within maxProducerPosition = " , |
678 | getMaxProducerPosition()); |
679 | |
680 | // Disable unsupported use cases at the current step. |
681 | // Currently do not support reducing or broadcasting |
682 | // swizzled dimensions. |
683 | auto all_inputs = InputsOf::outputs(fusion(), {axis(x), axis(y)}); |
684 | for (auto id : ir_utils::filterByType<IterDomain>(all_inputs)) { |
685 | TORCH_INTERNAL_ASSERT( |
686 | !id->isBroadcast() && !id->isReduction(), |
687 | "Unsupported use case for swizzle." ); |
688 | } |
689 | |
690 | // Also checking that the scheduler is not trying to |
691 | // compose swizzles, which is not yet supported either. |
692 | auto all_exprs = DependencyCheck::getAllValsBetween( |
693 | {all_inputs.begin(), all_inputs.end()}, {axis(x), axis(y)}); |
694 | for (auto expr : all_exprs) { |
695 | TORCH_INTERNAL_ASSERT( |
696 | !expr->isA<Swizzle2D>(), "Composing swizzles is not yet supported" ); |
697 | } |
698 | |
699 | // Check swizzle specific constraints on the input axes: |
700 | if (swizzle_type != Swizzle2DType::ZShape) { |
701 | ExpressionEvaluator const_eval(fusion()); |
702 | |
703 | auto x_id = axis(x); |
704 | auto y_id = axis(y); |
705 | |
706 | TORCH_INTERNAL_ASSERT( |
707 | x_id->extent()->isConstInt() && y_id->extent()->isConstInt(), |
708 | "Only constant iterdomains supported on given swizzle type" ); |
709 | |
710 | int in_x_size = x_id->extent()->evaluateInt(); |
711 | int in_y_size = y_id->extent()->evaluateInt(); |
712 | |
713 | // Check size constraints based on swizzle type |
714 | if (swizzle_type == Swizzle2DType::Transpose || |
715 | swizzle_type == Swizzle2DType::XOR) { |
716 | TORCH_INTERNAL_ASSERT( |
717 | in_x_size == in_y_size, "Swizzle: equal dim iterdomains only" ); |
718 | } |
719 | |
720 | if (swizzle_type == Swizzle2DType::Scatter) { |
721 | TORCH_INTERNAL_ASSERT( |
722 | in_y_size == 4, "Swizzle: unsupported id size must be 4 " , in_y_size); |
723 | TORCH_INTERNAL_ASSERT( |
724 | in_x_size == 8 || in_x_size == 16 || in_x_size == 32, |
725 | "Swizzle: unsupported id size must be 8, 16, or 32 " , |
726 | in_x_size); |
727 | } |
728 | } |
729 | |
730 | domain()->swizzle(swizzle_type, x, y, swizzle_mode); |
731 | |
732 | return this; |
733 | } |
734 | |
// Factor the reduction axes listed in `axes` out of this tensor's
// reduction into a new intermediate "producer" TensorView. The definition
// must be a ReductionOp or MmaOp, and rFactor may only be applied once
// per view. After the call, `this` becomes the consumer performing the
// remaining reduction; the newly created producer is returned.
TensorView* TensorView::rFactor(const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // TODO: I think we should do this but
  // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the
  // moment.

  // TORCH_INTERNAL_ASSERT(
  //     !hasComputeAt(), "Cannot rfactor tensors after compute at has been
  //     set.");
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
  FusionGuard fg(fusion());
  TORCH_CHECK(
      definition() != nullptr &&
          (definition()->getExprType() == ExprType::ReductionOp ||
           definition()->getExprType() == ExprType::MmaOp),
      "Error rfactoring ",
      this,
      " its definition is either a nullptr or not a reduction.");
  TORCH_CHECK(
      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");

  TORCH_CHECK(
      !definition()->isA<GroupedReductionOp>(),
      "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int>& axes, const std::vector<TensorView*>& tvs)");

  // Split tensor view into 2 parts
  auto domain_pair = domain()->rFactor(axes);

  // Producer in the pair
  auto producer_domain = domain_pair.first;
  // Consumer in the pair
  auto consumer_domain = domain_pair.second;

  // This domain will be the consumer, so create the producer
  TensorView* producer =
      IrBuilder::create<TensorView>(producer_domain, getDataType().value());

  // Set domain of consumer
  setDomain(consumer_domain);
  TensorView* consumer = this;

  if (auto this_reduction = dynamic_cast<ReductionOp*>(definition())) {
    // Setup dependency chain, inserting producer before this op.
    // Expr* producer_definition =
    IrBuilder::create<ReductionOp>(
        this_reduction->getReductionOpType(),
        this_reduction->init(),
        producer,
        this_reduction->in());

    // Expr* consumer_definition =
    IrBuilder::create<ReductionOp>(
        this_reduction->getReductionOpType(),
        this_reduction->init(),
        consumer,
        producer);
  } else if (auto this_mma = dynamic_cast<MmaOp*>(definition())) {
    // Initial reduction that still uses mma to combine
    // the input.
    IrBuilder::create<MmaOp>(
        producer,
        this_mma->inA(),
        this_mma->inB(),
        this_mma->init(),
        this_mma->options());

    // Remaining reduction that can be scheduled cross
    // warp or cta.
    IrBuilder::create<ReductionOp>(
        BinaryOpType::Add, this_mma->init(), consumer, producer);
  } else {
    TORCH_INTERNAL_ASSERT(false, "RFactor: unsupported tensor definition");
  }
  return producer;
}
812 | |
// Helper for rFactor on multi-output reduction ops (e.g. Welford): apply
// the same rFactor split to sibling output `tv`. If `tv` is not `this`,
// its scheduling is first force-replayed from `this` so all sibling
// outputs share the same transformations before the rfactor. Returns the
// new producer TensorView for `tv`.
TensorView* TensorView::multiOutputRfactorHelper(
    TensorView* tv,
    const std::vector<int>& axes) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Hack:
  // Semantically we should always keep the outputs of multi reduction ops
  // scheduled the same but the user end cannot guarantee that. In order to
  // guarantee that the rFactor is defined meaningfully the scheduling of the
  // output TV that got the rfactor call is force replayed towards the other two

  if (!sameAs(tv)) {
    auto root = tv->getRootDomain();
    auto this_root = getRootDomain();

    // construct a trivial root domain map
    std::unordered_map<IterDomain*, IterDomain*> id_map;
    for (const auto i : c10::irange(root.size())) {
      id_map[this_root[i]] = root[i];
    }

    // replay on the target tv
    ReplayTransformations replay(domain()->domain(), id_map);

    // construct the new tensor domain
    std::vector<IterDomain*> new_id;
    for (auto id : domain()->domain()) {
      TORCH_INTERNAL_ASSERT(
          replay.getReplay().count(id), "Multi-output reduction replay failed");
      new_id.push_back(replay.getReplay().at(id));
    }

    std::vector<bool> new_contig(
        tv->domain()->contiguity().begin(), tv->domain()->contiguity().end());
    // replace tensor domain of target tv
    tv->setDomain(IrBuilder::create<TensorDomain>(
        tv->getRootDomain(), new_id, new_contig));
  }

  // Split tensor view into 2 parts
  auto domain_pair = tv->domain()->rFactor(axes);
  // Producer in the pair
  auto producer_domain = domain_pair.first;
  // Consumer in the pair
  auto consumer_domain = domain_pair.second;

  // This domain will be the consumer, so create the producer
  TensorView* producer =
      IrBuilder::create<TensorView>(producer_domain, tv->getDataType().value());

  // Set domain of consumer
  tv->setDomain(consumer_domain);

  return producer;
}
869 | |
// Rfactors a multi-output reduction op (WelfordOp or GroupedReductionOp).
// `axes` are the reduction axes to factor out; `tvs` must be exactly the
// outputs of this TV's definition, in definition order. Returns one rfactor
// producer TV per output, and rewrites the definition into a producer op
// (reducing over `axes`) feeding a consumer op (finishing the reduction).
std::vector<TensorView*> TensorView::rFactor(
    const std::vector<int>& axes,
    const std::vector<TensorView*>& tvs) {
  TORCH_CHECK(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container." );
  TORCH_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView" );
  FusionGuard fg(fusion());
  TORCH_CHECK(
      definition() != nullptr && ir_utils::isReductionOp(definition()),
      "Error rfactoring multi-output reduction op " ,
      this,
      " its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op." );

  // rFactor may only be applied once per view.
  TORCH_CHECK(
      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice." );

  TORCH_CHECK(
      definition()->outputs().size() == tvs.size(),
      "Rfactor of a multi-output reduction not used correctly" );

  // `tvs` must match the definition's outputs one-to-one, in order.
  for (const auto i : c10::irange(tvs.size())) {
    TORCH_CHECK(
        definition()->output(i) == tvs.at(i),
        "Rfactor of a multi-output reduction not used correctly" );
  }

  // Currently grouping of welford is only supported through
  // ParallelType::Group, so GroupedWelfordOp is only created during
  // the lowering time. As rFactor is done before lowering, there
  // should be no GroupedWelfordOp at this point.
  TORCH_INTERNAL_ASSERT(
      !definition()->isA<GroupedWelfordOp>(),
      "GroupedWelfordOp found: " ,
      definition()->toString());

  std::vector<TensorView*> rf_tvs(tvs.size());

  // Make sure this gets rfactored last so everybody gets
  // replayed correctly
  for (const auto i : c10::irange(tvs.size())) {
    if (this != tvs.at(i)) {
      rf_tvs.at(i) = multiOutputRfactorHelper(tvs.at(i), axes);
    }
  }

  // Now rfactor `this`; the helper replays its scheduling onto the others.
  for (const auto i : c10::irange(tvs.size())) {
    if (this == tvs.at(i)) {
      rf_tvs.at(i) = multiOutputRfactorHelper(tvs.at(i), axes);
    }
  }

  // Rebuild the definition: a producer op writing the rfactor TVs, then a
  // consumer op reducing the rfactor TVs into the original outputs.
  if (auto wop = dynamic_cast<WelfordOp*>(definition())) {
    TensorView* producer_avg = rf_tvs.at(0);
    TensorView* producer_var = rf_tvs.at(1);
    TensorView* producer_n = rf_tvs.at(2);

    // Setup dependency chain, inserting producer before this op.
    // Expr* producer_definition =
    IrBuilder::create<WelfordOp>(
        producer_avg,
        producer_var,
        producer_n,
        wop->inAvg(),
        wop->inVar(),
        wop->inN(),
        wop->initAvg(),
        wop->initVar(),
        wop->initN());

    // Expr* consumer_definition =
    IrBuilder::create<WelfordOp>(
        wop->outAvg(),
        wop->outVar(),
        wop->outN(),
        producer_avg,
        producer_var,
        producer_n,
        wop->initAvg(),
        wop->initVar(),
        wop->initN());
  } else if (
      auto grouped_rop = dynamic_cast<GroupedReductionOp*>(definition())) {
    IrBuilder::create<GroupedReductionOp>(
        grouped_rop->getReductionOpTypes(),
        grouped_rop->initVals(),
        std::vector<Val*>{rf_tvs.begin(), rf_tvs.end()},
        grouped_rop->inputs());

    IrBuilder::create<GroupedReductionOp>(
        grouped_rop->getReductionOpTypes(),
        grouped_rop->initVals(),
        grouped_rop->outputs(),
        std::vector<Val*>{rf_tvs.begin(), rf_tvs.end()});
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid definition: " , definition()->toString());
  }

  return rf_tvs;
}
971 | |
// Inserts a cache tensor between this TV's definition and this TV:
//   Before: Prev TV -> [Definition Op] -> This TV
//   After:  Prev TV -> [Definition Op] -> New CB TV -> [Set Op] -> This TV
// Returns the newly created producer (the cache). If `cache_op` is given,
// the copy is a LoadStoreOp of that type; otherwise a UnaryOp Set.
TensorView* TensorView::cacheBefore(c10::optional<LoadStoreOpType> cache_op) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container." );
  FusionGuard fg(fusion());

  TORCH_CHECK(
      definition() != nullptr && !isFusionInput(),
      "Error adding cacheBefore " ,
      this,
      " its definition is a nullptr and we restrict using cacheBefore on an input." );

  // Previously, caching computed-at tensors was allowed but was never
  // really robust. Make it an error unless it is really needed.
  TORCH_CHECK(
      !hasComputeAt(),
      "Caching computed-at tensors is not allowed. Apply caching before computeAt" );

  // It also did additional transformation when a producer tensor has computeAt.
  // Make sure we no longer rely on that behavior.
  if (definition() != nullptr) {
    for (TensorView* producer_of_producer :
         ir_utils::filterByType<TensorView>(definition()->inputs())) {
      TORCH_CHECK(
          !producer_of_producer->hasComputeAt(),
          "Potentially invalid computeAt and caching detected. Apply caching before computeAt." );
    }
  }

  // Create Producer Domain
  // This domain will be the consumer which needs a new domain, so replace the
  // producers domain with this domain.

  // The cache inherits this TV's full domain (root, rfactor, leaf,
  // contiguity), taking over its role as the definition's output.
  TensorView* producer = IrBuilder::create<TensorView>(
      container(),
      IrBuilder::create<TensorDomain>(
          container(),
          domain()->getRootDomain(),
          domain()->getRFactorDomain(),
          domain()->domain(),
          domain()->contiguity()),
      getDataType().value());

  // Set domain of consumer
  TensorView* consumer = this;

  // This TV gets a fresh root domain: clones of its non-reduction
  // rfactor/root axes, all marked contiguous.
  size_t i = 0;
  auto no_reduction_root_domain =
      TensorDomain::noReductions(getMaybeRFactorDomain());
  std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
  for (const auto& dom : no_reduction_root_domain) {
    new_root_domain[i++] = dom->cloneWithoutRFactor();
  }

  consumer->setDomain(IrBuilder::create<TensorDomain>(
      container(),
      new_root_domain,
      std::vector<bool>(new_root_domain.size(), true)));

  // Insert producer - Cache_Before (CB) - before this TV.
  // Before: Prev TV -> [Definition Op] -> This TV
  // After: Prev TV -> [Definition Op] -> New CB TV -> [Set Op] -> This TV

  // Get inputs for origin expression
  // NOTE(review): expr_inputs appears unused below — candidate for removal.
  auto expr_inputs = definition()->inputs();
  // Redirect the definition to write into the cache instead of this TV.
  // Expr* producer_definition =
  ir_utils::replaceValInExpr(definition(), this, producer);

  // Expr* producer_uses =
  if (cache_op.has_value()) {
    IrBuilder::create<LoadStoreOp>(
        container(), cache_op.value(), consumer, producer);
  } else {
    IrBuilder::create<UnaryOp>(
        container(), UnaryOpType::Set, consumer, producer);
  }

  // definition_ is no longer valid
  // setDefinition(nullptr);

  // Replay the producer's scheduling onto the (freshly reset) consumer.
  auto replayed_consumer_pair =
      TransformReplay::replayCasP(consumer, producer, -1);
  consumer->setDomain(replayed_consumer_pair.first);

  return producer;
}
1058 | |
1059 | TensorView* TensorView::cacheFork() { |
1060 | TORCH_INTERNAL_ASSERT( |
1061 | !container()->isA<kir::Kernel>(), |
1062 | "Function invalid for kernel container." ); |
1063 | FusionGuard fg(fusion()); |
1064 | |
1065 | // Before: [Expr] -> This TV (Global Output) -> [Usage Expr] |
1066 | // After: [Expr] -> This TV (Local) -> [Usage Expr] > Next TV |
1067 | // (Fork) -> [Set Expr] -> New TV (Global Output) |
1068 | |
1069 | TORCH_CHECK( |
1070 | this->isFusionOutput() && !this->uses().empty(), |
1071 | "Error adding cacheFork " , |
1072 | this, |
1073 | " this TensorView must be an output with subsequent uses" ); |
1074 | |
1075 | // Previously, caching computed-at tensors was allowed but was never |
1076 | // really robust. Make it an error unless it is really needed. |
1077 | TORCH_CHECK( |
1078 | !hasComputeAt(), |
1079 | "Caching computed-at tensors is not allowed. Apply caching before computeAt" ); |
1080 | |
1081 | // This domain will be the producer, so create the consumer |
1082 | auto root_domain = TensorDomain::noReductions(getMaybeRFactorDomain()); |
1083 | TensorView* new_output = IrBuilder::create<TensorView>( |
1084 | container(), |
1085 | IrBuilder::create<TensorDomain>( |
1086 | container(), |
1087 | IterDomain::clone(root_domain), |
1088 | std::vector<bool>(root_domain.size(), true)), |
1089 | getDataType().value()); |
1090 | |
1091 | // Create write operation from this TV to new output |
1092 | IrBuilder::create<UnaryOp>(container(), UnaryOpType::Set, new_output, this); |
1093 | |
1094 | // The new TV becomes an output. |
1095 | // New TV has global memory type. |
1096 | // This TV has local memory type. |
1097 | fusion()->replaceOutput(this, new_output); |
1098 | |
1099 | // Transform new output according to this TV |
1100 | auto replayed_output_pair = TransformReplay::replayCasP(new_output, this, -1); |
1101 | new_output->setDomain(replayed_output_pair.first); |
1102 | |
1103 | return new_output; |
1104 | } |
1105 | |
// Inserts a cache tensor between this TV and all of its uses:
//   Before: This TV -> [Use Op] -> Next TV
//   After:  This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV
// Returns the newly created consumer (the cache). If `cache_op` is given,
// the copy is a LoadStoreOp of that type; otherwise a UnaryOp Set.
TensorView* TensorView::cacheAfter(c10::optional<LoadStoreOpType> cache_op) {
  TORCH_INTERNAL_ASSERT(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container." );
  FusionGuard fg(fusion());

  // Get all the uses for this Tensorview
  TORCH_CHECK(
      !isFusionOutput(),
      "Error adding cacheAfter " ,
      this,
      " we restrict using cacheAfter on an output." );

  // Previously, caching computed-at tensors was allowed but was never
  // really robust. Make it an error unless it is really needed.
  TORCH_CHECK(
      !hasComputeAt(),
      "Caching computed-at tensors is not allowed. Apply caching before computeAt." );

  // It also did additional transformation when this tensor is an
  // input and the outputs of its consumers have computeAt. Make sure
  // we no longer rely on that behavior.
  if (isFusionInput()) {
    for (const auto& expr : uses()) {
      for (TensorView* output :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        TORCH_CHECK(
            !output->hasComputeAt(),
            "Potentially invalid computeAt and caching detected. Apply caching before computeAt." );
      }
    }
  }

  // Create Consumer Domain
  // Keep Broadcast Axis (Permanent)
  // Remove Reduction Axis
  size_t i = 0;
  auto no_reduction_root_domain =
      TensorDomain::noReductions(getMaybeRFactorDomain());
  std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
  for (const auto& dom : no_reduction_root_domain) {
    new_root_domain[i++] = dom->cloneWithoutRFactor();
  }

  // This domain will be the producer, so create the consumer
  TensorView* consumer = IrBuilder::create<TensorView>(
      container(),
      IrBuilder::create<TensorDomain>(
          container(),
          new_root_domain,
          std::vector<bool>(new_root_domain.size(), true)),
      getDataType().value());

  // Set domain of producer - No Change
  TensorView* producer = this;

  // Insert consumer - Cache_After (CA) - after this TV.
  // Before: This TV -> [Use Op] -> Next TV
  // After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV

  // Redirect every existing use of this TV to read from the cache instead.
  // Expr* consumer_uses =
  for (auto expr : fusion()->unordered_uses(this)) {
    ir_utils::replaceValInExpr(expr, this, consumer);
  }

  // Create the copy op that defines the cache.
  // Expr* consumer_definition =
  if (cache_op.has_value()) {
    IrBuilder::create<LoadStoreOp>(
        container(), cache_op.value(), consumer, producer);
  } else {
    IrBuilder::create<UnaryOp>(
        container(), UnaryOpType::Set, consumer, producer);
  }

  return consumer;
}
1182 | |
1183 | void TensorView::setMemoryType(MemoryType mt) { |
1184 | memory_type_ = mt; |
1185 | if (isFusionInput() || isFusionOutput()) { |
1186 | TORCH_INTERNAL_ASSERT( |
1187 | mt == MemoryType::Global, |
1188 | "Tried to set an input or output to the fusion to a non-global memory type." ); |
1189 | } |
1190 | } |
1191 | |
1192 | void TensorView::clearReductionIterDomains() { |
1193 | TORCH_INTERNAL_ASSERT( |
1194 | !domain()->hasRFactor(), |
1195 | "should not call clearReductionIterDomains on rfactor tv" ); |
1196 | |
1197 | TORCH_INTERNAL_ASSERT( |
1198 | domain()->domain() == getRootDomain(), |
1199 | "should not call clearReductionIterDomains on already transformed TensorDomains" ); |
1200 | |
1201 | std::vector<IterDomain*> new_root; |
1202 | std::vector<bool> new_contig; |
1203 | for (const auto i : c10::irange(getRootDomain().size())) { |
1204 | if (!getRootDomain()[i]->isReduction()) { |
1205 | new_root.push_back(getRootDomain()[i]); |
1206 | new_contig.push_back(domain()->contiguity()[i]); |
1207 | } |
1208 | } |
1209 | |
1210 | setDomain(IrBuilder::create<TensorDomain>(container(), new_root, new_contig)); |
1211 | } |
1212 | |
// Marks this tensor for double buffering (two-stage pipelining of its
// load). Validation here is best-effort; see the note below.
void TensorView::doubleBuffer() {
  // Early correctness checking. May miss eventual errors as the
  // checks depend on memory types and parallelization, which may not
  // be finalized until lowering.
  validateDoubleBufferedTensor(this);
  is_double_buffered_ = true;
}
1220 | |
// Marks this tensor for circular buffering with `stage` buffers (stage
// must be > 1). A stage count of 2 is exactly double buffering and is
// forwarded to doubleBuffer().
void TensorView::circularBuffer(unsigned int stage) {
  // Early correctness checking. May miss eventual errors as the
  // checks depend on memory types and parallelization, which may not
  // be finalized until lowering.
  TORCH_INTERNAL_ASSERT(stage > 1, "Unsupported stage number" );
  if (stage == 2) {
    // Re-direct to double buffer interface if stage is 2;
    doubleBuffer();
    return;
  }
  validateDoubleBufferedTensor(this);
  is_circular_buffered_ = true;
  circular_buffer_stage_ = stage;
}
1235 | |
// Returns true when every rfactor (or root) axis has an extent statically
// known to be 0.
// NOTE(review): a tensor has zero elements when ANY extent is 0, but this
// checks that ALL extents are 0 — confirm callers rely on the stricter
// all-zero semantics before changing.
bool TensorView::isEmptyTensor() const {
  auto& root_domain = getMaybeRFactorDomain();
  return std::all_of(
      root_domain.begin(), root_domain.end(), [](IterDomain* id) {
        return id->extent()->isZeroInt();
      });
}
1243 | |
1244 | void TensorView::applyMmaSwizzle(MmaOptions options) { |
1245 | switch (options.operand) { |
1246 | case MmaOptions::Operand::Accumulator: |
1247 | mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options); |
1248 | break; |
1249 | case MmaOptions::Operand::A: |
1250 | case MmaOptions::Operand::B: |
1251 | mma_util::WarpMmaSwizzler::scheduleOperandRead(this, options); |
1252 | break; |
1253 | default: |
1254 | TORCH_INTERNAL_ASSERT(false, "unknown operand flag" ); |
1255 | break; |
1256 | } |
1257 | } |
1258 | |
1259 | TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) { |
1260 | TORCH_CHECK(shape_.empty() || shape_.size() == ndims); |
1261 | TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims); |
1262 | ndims_ = ndims; |
1263 | return *this; |
1264 | } |
1265 | |
// Sets the scalar data type of the TensorView to be built.
TensorViewBuilder& TensorViewBuilder::dtype(DataType dtype) {
  dtype_ = dtype;
  return *this;
}
1270 | |
1271 | TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) { |
1272 | TORCH_CHECK(contiguity_.empty(), "Attempting to reset contiguity" ); |
1273 | if (!contiguity.empty()) { |
1274 | TORCH_CHECK(ndims_ == 0 || ndims_ == contiguity.size()); |
1275 | ndims_ = contiguity.size(); |
1276 | } |
1277 | contiguity_ = std::move(contiguity); |
1278 | return *this; |
1279 | } |
1280 | |
1281 | TensorViewBuilder& TensorViewBuilder::shape(const std::vector<int64_t>& shape) { |
1282 | TORCH_CHECK(shape_.empty(), "Attempting to reset shape" ); |
1283 | if (!shape.empty()) { |
1284 | TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size()); |
1285 | ndims_ = shape.size(); |
1286 | } |
1287 | shape_.clear(); |
1288 | shape_.reserve(shape.size()); |
1289 | for (int64_t i : shape) { |
1290 | if (i == -1) { |
1291 | shape_.emplace_back(IrBuilder::create<Int>()); |
1292 | } else { |
1293 | TORCH_CHECK( |
1294 | i >= 0, |
1295 | "Invalid extent value. " , |
1296 | "For a tensor representing a single scalar use ndims = 0 with no sizes set." ); |
1297 | shape_.emplace_back(IrBuilder::create<Int>(i)); |
1298 | } |
1299 | } |
1300 | return *this; |
1301 | } |
1302 | |
1303 | TensorViewBuilder& TensorViewBuilder::shape(std::vector<Val*> shape) { |
1304 | TORCH_CHECK(shape_.empty(), "Attempting to reset shape" ); |
1305 | if (!shape.empty()) { |
1306 | TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size()); |
1307 | ndims_ = shape.size(); |
1308 | } |
1309 | shape_ = std::move(shape); |
1310 | return *this; |
1311 | } |
1312 | |
// Materializes the TensorView described by this builder. One IterDomain
// is created per dimension: a symbolic extent when no shape was given, a
// Broadcast axis when the extent is statically 1, and a concrete-extent
// axis otherwise.
TensorView* TensorViewBuilder::build() const {
  // Build the domain
  std::vector<IterDomain*> domain(ndims_, nullptr);
  for (const auto i : c10::irange(ndims_)) {
    if (shape_.empty()) {
      // No shape provided: extent is a fresh symbolic Int.
      domain[i] =
          IterDomainBuilder(
              FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create<Int>())
              .build();
    } else {
      if (shape_[i]->isOneInt()) {
        // If size is known to be 1, assume it needs to be broadcasted.
        domain[i] = IterDomainBuilder(
                        FusionGuard::getCurFusion()->zeroVal(),
                        FusionGuard::getCurFusion()->oneVal())
                        .iter_type(IterType::Broadcast)
                        .build();
      } else {
        domain[i] =
            IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), shape_[i])
                .build();
      }
    }
  }

  // Create the final TensorView
  return IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(domain, contiguity_), dtype_);
}
1342 | |
1343 | } // namespace cuda |
1344 | } // namespace fuser |
1345 | } // namespace jit |
1346 | } // namespace torch |
1347 | |