#include <transform_view.h>

#include <arith.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_builder.h>
#include <ir_internal_nodes.h>
#include <ir_iostream.h>
#include <iter_visitor.h>
#include <transform_iter.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! There are four domains associated with performing a view operation:
//! 1) Original Domain:
//! This domain is the original input to the view operation. It has no
//! transforms on it; it is, however, passed in without its reduction domains
//! (as is expected, since we're trying to generate the output of the
//! operations).
//!
//! 2) Trivially reduced domain:
//! Predicting which domains need to be trivially reduced is itself not
//! trivial. If a broadcast is between two iter domains in the original
//! domain that must be merged for the view transform:
//! - If the broadcast domain lines up with a broadcast domain in the final
//!   tensor domain, keep it.
//! - If the domain is size-1 but not marked as a broadcast domain (runtime
//!   size==1), it is merged into a neighboring domain instead (see the
//!   runtime size-1 handling in findTransformation).
//!   Note: This isn't something we generally support consistently.
//! - If the broadcast domain is marked as a compile-time broadcast domain,
//!   and doesn't line up with a broadcast domain in the final result,
//!   trivially reduce it.
//! The index for these transformations is marked as the index in the
//! original domain, as that's the input for the trivial reduction. This
//! produces the trivially reduced domain.
//!
//! 3) Post-view Domain:
//! This domain is the original domain after the trivial reductions and all
//! view transformations. It holds the rfactor domains determined by the
//! merge/split operations of the find-transformation pass. It is the final
//! domain without all the broadcast operations (it can have some that were
//! preserved through the transformations).
//! For example: {1, 2, 1, 4} -> {1, 2, 1, 2, 2} doesn't have any
//! conflicts between the view transformation and the broadcast
//! dimensions, so they won't be trivially reduced; they will simply be
//! propagated through the view.
//! {1, 2, 1, 4} -> {1, 8, 1} does have the second size-1 dimension in
//! between the 2 and the 4 that have to be merged. The first broadcast
//! axis will be propagated through the domains unaffected, yet the second
//! broadcast axis will be trivially reduced, then rebroadcasted.
//! The transformation indices recorded for the splits/merges that produce
//! this domain are based on an "in progress" tensor domain (called
//! transform_view_index in the find-transformation pass). This allows us to
//! simply apply these transformations serially to produce this domain.
//!
//! 4) Post-broadcast Domain:
//! This domain finally matches the output of the view operation fully and
//! can be used in further computations.
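//! Continuing the {1, 2, 1, 4} -> {1, 8, 1} example above: the post-view
//! domain is {1, 8} (the leading broadcast preserved, the interfering
//! broadcast trivially reduced), and the post-broadcast domain is
//! {1, 8, 1}, which matches the requested output shape.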
//!
//! View process at compute time:
//! 1) View takes in the input TensorView x, the original runtime sizes as a
//! std::vector<int64_t>, and the viewed runtime sizes as a
//! std::vector<int64_t>.
//! 2) AnalyzeView is called, which will figure out what series of
//! transformations is required from the input tensor to the output tensor.
//! These transformations are recorded.
//! 3) A sum operation is called on the trivial reduction axes from the
//! analysis.
//! 4) applyViewTransforms will generate the output domain of the view
//! operation.
//!    Calls TensorDomain::view(view_analysis), which returns the rfactored
//!    domain.
//!    Gets forwarded to transformView(TensorDomain, view_analysis).
//!    Gets forwarded to createViewDomain(TensorDomain, view_analysis).
//!    createViewDomain creates the new root domain and calls
//!    createRfactorDomain on view_analysis.transforms().
//! 5) broadcast will be called with view_analysis.broadcast_axes.
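//!
//! A minimal sketch of this flow from the op level (a hand-written example;
//! it assumes the `view` op declared in arith.h and the TensorViewBuilder
//! helper), with steps 1-5 all happening inside the `view` call:
//!
//!   Fusion fusion;
//!   FusionGuard fg(&fusion);
//!   TensorView* tv0 = TensorViewBuilder().ndims(2).build();
//!   fusion.addInput(tv0);
//!   TensorView* tv1 = view(tv0, {2, 12}, {2, 3, 4});
//!   fusion.addOutput(tv1);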
//!
//! TODO: Caching assumes that all size-1 inputs are correctly marked as a
//! broadcast dimension. We should probably remove the runtime size-1 merge
//! support in find transformation.
//!
//! Simple abstract class to record a transformation and the index required
//! to apply it.
class Transform : public PolymorphicBase {
 public:
  virtual std::string toString() const = 0;

  int64_t index() const {
    return index_;
  }

 protected:
  // Relevant location information for the transformation. The stored index
  // records where we have to apply that transformation (see the long comment
  // at the top of this file).
  Transform(int64_t index) : index_(index) {}

  const int64_t index_ = 0;
};
class ViewTransform : public Transform {
 public:
  // Function to apply the transformation. The transformation is applied on
  // current_transformed_domain. root_domain is required here to replace
  // IterDomains, so we can flip the rfactor flag on the root domain if it's
  // involved in merge/split transforms used to produce the rfactor domain.
  virtual void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) = 0;

  // Convenience function to replace id in root_domain with an id that has
  // its expanded extent materialized and the rfactor flag turned on.
  static IterDomain* replaceRootIdWithRFactor(
      std::vector<IterDomain*>& root_domain,
      IterDomain* id) {
    auto root_domain_it = std::find(root_domain.begin(), root_domain.end(), id);

    TORCH_INTERNAL_ASSERT(
        root_domain_it != root_domain.end(),
        "Wanted to replace ",
        id->toString(),
        " in root with an rfactor dimension, but IterDomain was not found in root.");

    auto root_domain_pos = std::distance(root_domain.begin(), root_domain_it);

    bool is_expanded_dim = id->hasExpandedExtent();

    auto extent = is_expanded_dim ? id->expandedExtent() : id->extent();

    auto cloned_id =
        IterDomainBuilder(id)
            .iter_type(
                is_expanded_dim ? IterType::Iteration : id->getIterType())
            .extent(extent)
            .expanded_extent(nullptr)
            .is_rfactor_domain(true)
            .build();

    root_domain.erase(root_domain.begin() + root_domain_pos);
    root_domain.insert(root_domain.begin() + root_domain_pos, cloned_id);
    return cloned_id;
  }

  // Debugging utility to convert the transformation into a string.
  std::string toString() const override = 0;

 protected:
  ViewTransform(int64_t index) : Transform(index) {}
};

namespace {
//! The merge transformation either combines two root iterDomains together OR
//! the last rfactor iterDomain with a root iterDomain. Unlike the general
//! TensorView merge, there's no merging across axes not placed in
//! consecutive positions for View.
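//!
//! For example (a hand-traced sketch of createRfactorDomain below):
//! applying MergeTransform(1) to a current transformed domain
//! [iS0{2}, iS1{2}, iS2{4}] replaces iS1 and iS2 with a single merged
//! rfactor iter domain of extent 2 * 4 = 8, giving [iS0{2}, iS{8}rf].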
class MergeTransform final : public ViewTransform {
 public:
  MergeTransform(int64_t index) : ViewTransform(index) {}

  std::string toString() const override {
    std::stringstream ss;
    ss << "Merge at index: " << index_;
    return ss.str();
  }

  void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) override {
    TORCH_INTERNAL_ASSERT(
        (index_ + 1) < (int64_t)current_transformed_domain.size(),
        "Tried to apply: ",
        toString(),
        "\t To domain: \t",
        current_transformed_domain);

    // Assumed to never merge over non-contiguous dimensions.
    IterDomain* outer_id = current_transformed_domain[index_];
    if (!outer_id->isRFactorProduct()) {
      outer_id = replaceRootIdWithRFactor(root_domain, outer_id);
    }

    IterDomain* inner_id = current_transformed_domain[index_ + 1];
    if (!inner_id->isRFactorProduct()) {
      inner_id = replaceRootIdWithRFactor(root_domain, inner_id);
    }

    TORCH_INTERNAL_ASSERT(
        outer_id->start()->isZeroInt() && inner_id->start()->isZeroInt(),
        "Didn't expect to apply view transformations on an iter domain",
        " starting at a non-zero position.");

    auto merged_extent = mul(outer_id->extent(), inner_id->extent());

    auto new_merged_id =
        IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), merged_extent)
            .is_rfactor_domain(true)
            .build();

    IrBuilder::create<Merge>(new_merged_id, outer_id, inner_id);

    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, new_merged_id);
  }
};

//! The split transformation creates two new iterDomains via an outer split.
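//! For example (hand-traced): SplitTransform(0, 3) applied to [iS0{12}]
//! produces an outer factor domain of extent 3 and an inner remainder
//! domain of extent ceilDiv(12, 3) = 4, giving [iS{3}rf, iS{4}rf].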
class SplitTransform final : public ViewTransform {
 public:
  SplitTransform(int64_t index, int64_t split_factor)
      : ViewTransform(index), split_factor_(split_factor) {
    TORCH_INTERNAL_ASSERT(
        split_factor > 0,
        "Split factors must be greater than 0, but found ",
        split_factor,
        " during view transformation.");
  }

  std::string toString() const override {
    std::stringstream ss;
    ss << "Split Index at: " << index_ << " by: " << split_factor_ << std::endl;
    return ss.str();
  }

  void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) override {
    TORCH_INTERNAL_ASSERT(
        index_ < (int64_t)current_transformed_domain.size(),
        "Index: \t",
        index_,
        "\t Domain Size:\t",
        current_transformed_domain.size());

    auto factor = IrBuilder::create<Int>(split_factor_);

    IterDomain* id = current_transformed_domain[index_];
    if (!id->isRFactorProduct()) {
      id = replaceRootIdWithRFactor(root_domain, id);
    }

    TORCH_INTERNAL_ASSERT(
        id->start()->isZeroInt(),
        "Didn't expect to apply view transformations on an iter domain",
        " starting at a non-zero position.");

    Val* remainder = ceilDiv(id->extent(), factor);

    // Outer loop IterDomain
    IterDomain* factor_id =
        IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), factor)
            .parallel_type(id->getParallelType())
            .iter_type(id->getIterType())
            .is_rfactor_domain(true)
            .build();

    // Inner loop IterDomain
    IterDomain* remainder_id =
        IterDomainBuilder(
            FusionGuard::getCurFusion()->zeroVal(), remainder->as<Int>())
            .is_rfactor_domain(true)
            .build();

    IrBuilder::create<Split>(factor_id, remainder_id, id, factor, false);

    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, remainder_id);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, factor_id);
  }

  int64_t split_factor() const {
    return split_factor_;
  }

 private:
  const int64_t split_factor_ = 0;
};

//! For any singleton dimensions in the new view, we create an implicit
//! broadcast dimension. We apply these transforms after the trivial
//! reduction and view transformation steps.
class BroadcastTransform final : public Transform {
 public:
  BroadcastTransform(int64_t index) : Transform(index) {}

  std::string toString() const override {
    std::stringstream ss;
    ss << "Broadcast at: " << index_ << std::endl;
    return ss.str();
  }
};

//! For any implicit broadcast dimensions in the original view, we remove
//! them using a trivial reduction.
class TrivialReductionTransform final : public Transform {
 public:
  TrivialReductionTransform(int64_t index) : Transform(index) {}

  std::string toString() const override {
    std::stringstream ss;
    ss << "Trivial reduction at: " << index_ << std::endl;
    return ss.str();
  }
};

//! The primary class that generates the transformations to go from
//! the original view to the new view.
class AnalyzeViewTransformation {
 public:
  AnalyzeViewTransformation(
      const std::vector<int64_t>& original_view,
      const std::vector<int64_t>& new_view,
      std::vector<IterDomain*> root_domain = {})
      : root_domain_not_provided_(root_domain.empty()),
        root_domain_(root_domain),
        root_is_transformed_(original_view.size(), false),
        original_view_(original_view),
        new_view_(new_view) {
    TORCH_INTERNAL_ASSERT(
        root_domain.empty() || original_view.size() == root_domain.size(),
        "Incoming domain must match the original view sizes for view.");
    // Check that the total element counts of the original and new view
    // sizes are equal.
    const int64_t kOriginalNumElements = std::accumulate(
        original_view_.begin(), original_view_.end(), 1, std::multiplies<>());
    const int64_t kNewNumElements = std::accumulate(
        new_view_.begin(), new_view_.end(), 1, std::multiplies<>());
    TORCH_INTERNAL_ASSERT(
        kOriginalNumElements == kNewNumElements,
        "Total element counts across view operation must match.");
  }

  AnalyzeViewConstraint constraint() {
    findTransformation();

    AnalyzeViewConstraint constraint;
    constraint.original_constraint =
        std::vector<int64_t>(original_view_.begin(), original_view_.end());
    for (auto i : c10::irange(constraint.original_constraint.size())) {
      if (constraint.original_constraint[i] != 1) {
        constraint.original_constraint[i] = 0;
      }
    }

    constraint.new_constraint =
        std::vector<int64_t>(new_view_.begin(), new_view_.end());
    for (auto i : c10::irange(constraint.new_constraint.size())) {
      if (constraint.new_constraint[i] != 1) {
        constraint.new_constraint[i] = 0;
      }
    }

    for (auto trivial_reduce : trivial_reduction_transforms_) {
      constraint.trivial_reduction_string.push_back(trivial_reduce->index());
    }

    for (auto broadcast : broadcast_transforms_) {
      constraint.broadcast_string.push_back(broadcast->index());
    }

    // The delimiter for split/merge transforms is -2.
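    // For example (hand-traced): a SplitTransform(1, 3) followed by a
    // MergeTransform(0) encodes as {1, 3, -2, 0, -2}.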
    for (auto split_merge : view_transforms_) {
      if (split_merge->isA<SplitTransform>()) {
        constraint.split_merge_string.push_back(split_merge->index());
        constraint.split_merge_string.push_back(
            split_merge->as<SplitTransform>()->split_factor());
        constraint.split_merge_string.push_back(-2);
      } else {
        TORCH_INTERNAL_ASSERT(
            split_merge->isA<MergeTransform>(),
            "Unrecognized transformation found.");
        constraint.split_merge_string.push_back(split_merge->index());
        constraint.split_merge_string.push_back(-2);
      }
    }

    return constraint;
  }

  // Fill out all the information needed in AnalyzeViewResult; it should
  // contain everything required to perform the view operation.
  AnalyzeViewResult run() {
    // Find all the transformations to go from the original tensor domain to
    // the final output of the view operation.
    findTransformation();

    auto trivial_reduction_axes = generateTrivialReductionAxes();
    auto broadcast_axes = generateBroadcastAxes();

    // Move data to AnalyzeViewResult and return it.
    return {broadcast_axes, trivial_reduction_axes, view_transforms_};
  }

 private:
  // Returns the bool flags that should be used to broadcast the output view
  // tensor.
  std::vector<bool> generateBroadcastAxes() {
    std::vector<bool> broadcast_axes(new_view_.size(), false);
    for (auto& bcast : broadcast_transforms_) {
      broadcast_axes.at(bcast->index()) = true;
    }
    return broadcast_axes;
  }

  // Returns the positions of the trivial reductions to be performed before
  // the view operation.
  std::vector<int> generateTrivialReductionAxes() {
    std::vector<int> reduction_axes;
    for (auto& tred : trivial_reduction_transforms_) {
      reduction_axes.push_back(tred->index());
    }
    return reduction_axes;
  }

  std::string toString() {
    std::stringstream output;
    output << "===============================" << std::endl;
    output << "old:";
    for (auto s : original_view_) {
      output << " " << s;
    }
    output << std::endl;

    output << "===============================" << std::endl;
    output << "new:";
    for (auto s : new_view_) {
      output << " " << s;
    }
    output << std::endl;

    output << "===============================" << std::endl;
    for (auto& trivial_reduction : trivial_reduction_transforms_) {
      output << trivial_reduction->toString() << "\n";
    }
    for (auto& split_or_merge : view_transforms_) {
      output << split_or_merge->toString() << "\n";
    }
    for (auto& broadcast : broadcast_transforms_) {
      output << broadcast->toString() << "\n";
    }
    output << "===============================" << std::endl;
    return output.str();
  }

  // Returns true if the dimension at original_view_index is an implicit
  // (compile-time) broadcast. When no root domain is provided, any size-1
  // dimension is treated as one.
  bool isImplicitBroadcast(int64_t original_view_index) const {
    if (root_domain_not_provided_) {
      return original_view_[original_view_index] == 1;
    } else {
      TORCH_INTERNAL_ASSERT(original_view_index < (int64_t)root_domain_.size());
      return root_domain_[original_view_index]->isImplicitBroadcast() &&
          !root_domain_[original_view_index]->hasExpandedExtent();
    }
  }

  //! Find the broadcast, merge and split operations necessary
  //! to transform the original view into the new view.
  void findTransformation() {
    // There are three particularly important state indices we're working
    // with:
    // 1) original_view_index, which indexes into the original tensor
    //    domain after all reductions are removed. This lines up with the
    //    last domain in original view that we added to current_size.
    // 2) transform_view_index, which is the index of the transformations as
    //    we're virtually "developing" the output tensor domain (split/merge
    //    transformations post trivial reductions).
    // 3) new_view_index, which is directly associated with new_view and
    //    the dimension in new_view we're currently trying to create.
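    //
    // For example (a hand-traced run of the loop below, matching the
    // {1, 2, 1, 4} -> {1, 8, 1} example in the file header, with all
    // size-1 dims treated as implicit broadcasts):
    //   - original axis 0 (size 1) matches new axis 0: keep it, no
    //     transform recorded.
    //   - original axis 2 (size 1) sits between the 2 and the 4 that must
    //     be merged: record TrivialReductionTransform(2).
    //   - record MergeTransform(1) to combine the 2 and the 4 into 8,
    //     which matches new axis 1.
    //   - new axis 2 (size 1) has no original axis left to consume: record
    //     BroadcastTransform(2).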

    int64_t original_view_index = 0;
    int64_t transform_view_index = 0;
    int64_t new_view_index = 0;
    int64_t current_size = original_view_[0];

    // Safety counters to make sure we don't end up in an infinite loop.
    int64_t prev_original_view_index = std::numeric_limits<int64_t>::max();
    int64_t prev_new_view_index = std::numeric_limits<int64_t>::max();

    TORCH_INTERNAL_ASSERT(
        view_transforms_.empty(),
        "Already ran find transformation pass for View op, cannot run a second time.");

    // Iterate until the original view is completely consumed and the new
    // view is completely generated.
    while (original_view_index < (int64_t)original_view_.size() ||
           new_view_index < (int64_t)new_view_.size()) {
      TORCH_INTERNAL_ASSERT(
          !(prev_new_view_index == new_view_index &&
            prev_original_view_index == original_view_index),
          "Infinite loop detected in AnalyzeViewTransformation::findTransformation(). Bailing.");

      prev_new_view_index = new_view_index;
      prev_original_view_index = original_view_index;

      if (new_view_index >= (int64_t)new_view_.size()) {
        TORCH_INTERNAL_ASSERT(
            current_size == 1,
            "View is complete, but there are still elements to distribute.");
      }

      if ((new_view_index + 1 >= (int64_t)new_view_.size() ||
           (new_view_[new_view_index + 1] != 1)) &&
          original_view_index + 1 < (int64_t)original_view_.size() &&
          original_view_[original_view_index + 1] == 1 &&
          !isImplicitBroadcast(original_view_index + 1)) {
        // The next index in original_view is runtime size 1 and the next
        // new view dimension is not, so merge the size 1 into the current
        // view before moving on. Even if the current size and new view size
        // match, we could have a trailing size-1 dimension on the input
        // that needs to be merged in.
        view_transforms_.push_back(
            std::make_shared<MergeTransform>(transform_view_index));
        ++original_view_index;
        continue;
      }

      if (new_view_index < (int64_t)new_view_.size() &&
          // Still new dimensions to resolve and current size does resolve it.
          current_size == new_view_[new_view_index]) {
        // Keep this dimension: we hit a boundary where a running product of
        // original dims matches a running product of new view dims.
        // Increment state and keep going.

        ++transform_view_index;
        ++new_view_index;
        ++original_view_index;

        // Update current_size with the next size in original view.
        if (original_view_index < (int64_t)original_view_.size()) {
          current_size = original_view_[original_view_index];
        } else {
          current_size = 0;
        }
        continue;
      }

      // Compile-time broadcast in the new view, but no matching one in the
      // original view. Insert a broadcast and increment new_view_index.
      // Size-1 dimensions in new_view that don't match up with runtime
      // size 1's in the original view are assumed to be broadcast (not a
      // split from a runtime domain).
      if (new_view_index < (int64_t)new_view_.size() &&
          new_view_[new_view_index] == 1) {
        broadcast_transforms_.push_back(
            std::make_shared<BroadcastTransform>(new_view_index));
        ++new_view_index;
        continue;
      }

      // If we run out of original_view dimensions we could still have
      // broadcast dimensions for new_view, but that should be hit before
      // this point.
      TORCH_INTERNAL_ASSERT(
          current_size != 0,
          "View analysis failed, should never process an empty size unless we ",
          "simply need to add broadcasts to the post-view domain.");

      if (current_size == 1 && isImplicitBroadcast(original_view_index)) {
        // The original view has a compile-time size-1 dimension, and it's
        // not found in new_view_ (otherwise it would have been caught in a
        // branch above). Do a trivial reduction.
        trivial_reduction_transforms_.push_back(
            std::make_shared<TrivialReductionTransform>(original_view_index));
        ++original_view_index;

        // Update original position and current size.
        if (original_view_index < (int64_t)original_view_.size()) {
          current_size = original_view_[original_view_index];
        } else {
          current_size = 0;
        }

        continue;
      }

      if (original_view_index + 1 < (int64_t)original_view_.size() &&
          isImplicitBroadcast(original_view_index + 1)) {
        // The original view has a compile-time size-1 dimension that's
        // interfering with necessary transformations. Do a trivial
        // reduction.
        ++original_view_index;
        trivial_reduction_transforms_.push_back(
            std::make_shared<TrivialReductionTransform>(original_view_index));

        continue;
      }

      // We're only left with performing transformations to match a new_view
      // dimension, so there must be an active new_view dimension.
      TORCH_INTERNAL_ASSERT(
          new_view_index < (int64_t)new_view_.size(),
          "Expecting to still have new dimensions to work on in view, but none left.");

      if (new_view_index < (int64_t)new_view_.size() &&
          current_size % new_view_[new_view_index] == 0) {
        // Insert a split to generate the next new_view domain.
        view_transforms_.push_back(std::make_shared<SplitTransform>(
            transform_view_index, new_view_[new_view_index]));
        current_size /= new_view_[new_view_index];
        TORCH_INTERNAL_ASSERT(current_size > 1, "This should be unreachable.");
        // Update transform and new view indices, since a split doesn't
        // consume the original domain we're working on.
        ++transform_view_index;
        ++new_view_index;
        continue;
      }

      // Need more of the original_view dimension to resolve the new_view
      // dimension, so merge the next dimension in.
      TORCH_INTERNAL_ASSERT(
          original_view_index + 1 < (int64_t)original_view_.size(),
          "Expecting to still have original dimensions to work on in view, but none left.");

      view_transforms_.push_back(
          std::make_shared<MergeTransform>(transform_view_index));
      current_size *= original_view_[++original_view_index];
    }
  }

 private:
  std::vector<std::shared_ptr<ViewTransform>> view_transforms_;
  std::vector<std::shared_ptr<BroadcastTransform>> broadcast_transforms_;
  std::vector<std::shared_ptr<TrivialReductionTransform>>
      trivial_reduction_transforms_;

  // If the root domain isn't provided, always assume size-1 dimensions are
  // compile-time dimensions. TODO: Remove runtime size-1 dimension support.
  // This should be cached higher in the stack.
  const bool root_domain_not_provided_ = true;

  const std::vector<IterDomain*> root_domain_;
  // Track if each root ID was transformed or kept as-is.
  std::vector<bool> root_is_transformed_;
  const std::vector<int64_t>& original_view_;
  const std::vector<int64_t>& new_view_;
};

//! Create a new TensorDomain with a new root domain and modified rfactor
//! domains using the specified view transformations. The original domain
//! should already be without reduction axes.
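//!
//! For example (hand-traced): for a {2, 12} -> {2, 3, 4} view, the new
//! root domain is [iS{2}, iS{12}rf] (the 12 is cloned with the rfactor
//! flag set since it participates in a split), and the resulting rfactor
//! domain is [iS{2}, iS{3}rf, iS{4}rf].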
TensorDomain* createViewDomain(
    TensorDomain* original_domain,
    const AnalyzeViewResult& view_analysis) {
  FUSER_PERF_SCOPE("createViewDomain");
  TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty());

  std::vector<IterDomain*> new_root_domain;
  auto orig_root_domain = original_domain->getMaybeRFactorDomain();

  // Apply the trivial reductions by skipping reduction axes and trivially
  // reduced axes when building the new root domain.
  for (auto id_i : c10::irange(orig_root_domain.size())) {
    auto id = orig_root_domain[id_i];
    if (id->isReduction()) {
      continue;
    }
    if (std::find(
            view_analysis.trivial_reduction_axes.begin(),
            view_analysis.trivial_reduction_axes.end(),
            (int)id_i) != view_analysis.trivial_reduction_axes.end()) {
      continue;
    }

    new_root_domain.push_back(id->cloneWithoutRFactor());
  }

  std::vector<IterDomain*> new_rfactor_domain(
      new_root_domain.begin(), new_root_domain.end());

  // Apply the rfactor transformations.
  for (auto& t : view_analysis.transforms) {
    t->createRfactorDomain(new_root_domain, new_rfactor_domain);
  }

  return IrBuilder::create<TensorDomain>(
      new_root_domain,
      new_rfactor_domain,
      new_rfactor_domain,
      std::vector<bool>(new_rfactor_domain.size(), true));
}

} // namespace

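//! Infer the size of a single -1 dimension in new_sizes from the total
//! element count. For example (hand-traced): inferViewShapes({2, 12},
//! {2, -1, 4}) returns {{2, 12}, {2, 3, 4}}, since 2 * 12 / (2 * 4) = 3.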
std::pair<std::vector<int64_t>, std::vector<int64_t>> inferViewShapes(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  bool valid_original_sizes = std::all_of(
      original_sizes.begin(), original_sizes.end(), [](int64_t dim) {
        return dim > 0;
      });
  TORCH_INTERNAL_ASSERT(valid_original_sizes);

  std::vector<int64_t> original_view(
      original_sizes.begin(), original_sizes.end());
  std::vector<int64_t> new_view(new_sizes.size());

  // TODO: refactor
  int64_t dynamic_index = -1;
  int64_t new_size_num_elements = 1;
  for (int64_t idx = 0; idx < (int64_t)new_sizes.size(); ++idx) {
    if (new_sizes[idx] == -1) {
      TORCH_INTERNAL_ASSERT(
          dynamic_index == -1, "Only one dimension can be inferred.")
      dynamic_index = idx;
    } else {
      TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0);
      new_size_num_elements *= new_sizes[idx];
      new_view[idx] = new_sizes[idx];
    }
  }

  const int64_t kNumElements = std::accumulate(
      original_view.begin(), original_view.end(), 1, std::multiplies<>());
  if (dynamic_index != -1) {
    new_view[dynamic_index] = kNumElements / new_size_num_elements;
  }

  return {original_view, new_view};
}

//! Generates the transformations necessary to convert
//! from the original view into the new view.
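//!
//! For example (hand-traced, assuming all size-1 input dims are implicit
//! broadcasts): analyzeView(tv, {1, 2, 1, 4}, {1, 8, 1}) records a trivial
//! reduction on input axis 2, a single merge producing the 8, and a
//! broadcast on output axis 2 (broadcast_axes = {false, false, true}).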
AnalyzeViewResult analyzeView(
    const TensorView* original_view_tv,
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  FUSER_PERF_SCOPE("analyzeView");
  TORCH_INTERNAL_ASSERT(
      original_sizes.size() > 0,
      "Empty original size not supported for view operation.");

  TORCH_INTERNAL_ASSERT(
      TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain())
          .size() == original_sizes.size());

  // Fill the -1 dimension in new_sizes (if any) with the size inferred from
  // all other values.
  auto sizes = inferViewShapes(original_sizes, new_sizes);

  // Analyze the transformations required to go from original_sizes to
  // new_sizes.
  AnalyzeViewTransformation analyzer(
      sizes.first /* original_view */,
      sizes.second /* new_view */,
      TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain()));
  return analyzer.run();
}

AnalyzeViewConstraint analyzeViewConstraint(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  FUSER_PERF_SCOPE("analyzeViewConstraint");
  auto sizes = inferViewShapes(original_sizes, new_sizes);
  AnalyzeViewTransformation analyzer(
      sizes.first /* original_view */, sizes.second /* new_view */);
  return analyzer.constraint();
}

//! Create a new TensorDomain with a modified rfactor domain using the
//! specified view transformations.
TensorDomain* transformView(
    TensorDomain* original_domain,
    const AnalyzeViewResult& view_analysis) {
  FUSER_PERF_SCOPE("transformView");
  return createViewDomain(original_domain, view_analysis);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch