1 | #include <transform_view.h> |
2 | |
3 | #include <arith.h> |
4 | #include <fusion.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_builder.h> |
7 | #include <ir_internal_nodes.h> |
8 | #include <ir_iostream.h> |
9 | #include <iter_visitor.h> |
10 | #include <transform_iter.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
//! There are three domains associated with performing a view operation:
18 | //! 1) Original Domain: |
19 | //! This view is the original input to the view operation. It has no |
20 | //! transforms on it, it is however passed in without its reduction domains |
21 | //! (as is expected since we're trying to generate the output of the |
22 | //! operations). |
23 | //! |
24 | //! Trivially reduced domain: |
//!  Predicting which operations are trivially reduced is not trivial. If a
26 | //! broadcast is between two iter domains in the original domain that must be |
27 | //! merged for the view transform: |
28 | //! - If the broadcast domain lines up with a broadcast domain in the final |
29 | //! tensor domain keep it. |
30 | //! - If the domain is size-1 but not marked as a broadcast domain (runtime |
31 | //! size==1) |
32 | //! Note: This isn't something we generally support consistently |
33 | //! - If the broadcast domain is marked as a compile time broadcast domain, |
34 | //! and doesn't line up with a broadcast domain in the final result. |
35 | //! Trivially reduce it. |
36 | //! The index for these transformations is marked as the index of the original |
37 | //! domain, as that's the input for the trivial reduction. This produces the |
38 | //! trivially reduced domain. |
39 | //! |
40 | //! Post-view Domain: |
41 | //! This domain is the original domain after the trivial reductions and all |
42 | //! transformations. This domain holds the rfactor domains determined by |
43 | //! merge/split operations of the find transformations pass. It is the final |
44 | //! domain without all the broadcast operations (can have some that were |
45 | //! preserved through the transformations). |
46 | //! For example: {1, 2, 1, 4} -> {1, 2, 1, 2, 2} doesn't have any |
47 | //! conflicts of the view transformation and the broadcast dimensions, |
//!      so they won't be trivially reduced, they will simply be propagated
49 | //! through the view. |
50 | //! {1, 2, 1, 4} -> {1, 8, 1} does have the second 1 dimension in |
51 | //! between the 2 and 8 that have to be merged. The first broadcast axis |
//!      will be propagated through the domains unaffected, yet the second
//!      broadcast axis will be trivially reduced, then rebroadcasted.
54 | //! The transformation index marked for the splits/merges to produce this |
55 | //! domain are done based on an "in progress" tensor view (called transform |
56 | //! view index in the find transformation pass). This allows us to simply apply |
57 | //! these transformations serially to produce this domain. |
58 | //! |
59 | //! Post-broadcast Domain: |
60 | //! This domain finally matches the output of the view operation fully and |
61 | //! can be used in further computations. |
62 | //! |
63 | //! View process at compute time: |
64 | //! 1) View takes in the input TensorView x, original runtime |
65 | //! std::vector<int64_t>, and viewed runtime std::vector<int64_t>. |
//! 2) AnalyzeView is called, which will figure out what series of
67 | //! transformations is required from the input tensor to the output tensor. |
68 | //! These transformations are recorded. |
69 | //! 3) Sum operation is called on the trivial reduction axes from the |
70 | //! analysis. |
71 | //! 4) applyViewTransforms will generate the output domain of the view |
72 | //! operation. |
73 | //! Calls TensorDomain::view(view_analysis) which returns the rfactored |
74 | //! domain. |
75 | //! Gets forwarded to transformView(TensorDomain, view_analysis) |
76 | //! Gets forwarded to createViewDomain(TensorDomain, view_analysis) |
77 | //! createViewDomain creates the new root domain, and calls |
78 | //! createRfactorDomain on view_analysis.transforms(). |
//! 5) broadcast will be called with view_analysis.broadcast_axes
80 | //! |
81 | //! TODO: Caching assumes that all size-1 inputs are correctly marked as a |
82 | //! broadcast dimension. We should probably remove the runtime size-1 merge |
83 | //! support in find transformation. |
84 | //! |
85 | //! Simple abstract class to record transformation and the indices required to |
86 | //! apply it. |
class Transform : public PolymorphicBase {
 public:
  // Debugging utility: human-readable description of this transformation.
  virtual std::string toString() const = 0;

  // Position this transformation applies at. Which domain the index refers to
  // (original, in-progress transform view, or new view) depends on the
  // concrete subclass; see the long comment at the top of this file.
  int64_t index() const {
    return index_;
  }

 protected:
  // Relevant location information for the transformation. Stored information is
  // related to when we have to apply that transformation (see long comment at
  // top of this file).
  Transform(int64_t index) : index_(index) {}

  const int64_t index_ = 0;
};
103 | |
104 | class ViewTransform : public Transform { |
105 | public: |
106 | // Function to apply the transformation. Transformation is applied on |
107 | // current_transformed_domain. root_domain is required here to replace |
108 | // IterDomains so we can flip the rfactor flag on the root domain if it's |
109 | // involved in merge/split trasnforms to produce the rfactor domain. |
110 | virtual void createRfactorDomain( |
111 | std::vector<IterDomain*>& root_domain, |
112 | std::vector<IterDomain*>& current_transformed_domain) = 0; |
113 | |
114 | // Convenience function to replace id in root_domain with an id that has |
115 | // expand expanded, and rfactor flag turned on. |
116 | static IterDomain* replaceRootIdWithRFactor( |
117 | std::vector<IterDomain*>& root_domain, |
118 | IterDomain* id) { |
119 | auto root_domain_it = std::find(root_domain.begin(), root_domain.end(), id); |
120 | |
121 | TORCH_INTERNAL_ASSERT( |
122 | root_domain_it != root_domain.end(), |
123 | "Wanted to replace " , |
124 | id->toString(), |
125 | " in root with an rfactor dimension, but IterDomain was not found in root." ); |
126 | |
127 | auto root_domain_pos = std::distance(root_domain.begin(), root_domain_it); |
128 | |
129 | bool is_expanded_dim = id->hasExpandedExtent(); |
130 | |
131 | auto extent = is_expanded_dim ? id->expandedExtent() : id->extent(); |
132 | |
133 | auto cloned_id = |
134 | IterDomainBuilder(id) |
135 | .iter_type( |
136 | is_expanded_dim ? IterType::Iteration : id->getIterType()) |
137 | .extent(extent) |
138 | .expanded_extent(nullptr) |
139 | .is_rfactor_domain(true) |
140 | .build(); |
141 | |
142 | root_domain.erase(root_domain.begin() + root_domain_pos); |
143 | root_domain.insert(root_domain.begin() + root_domain_pos, cloned_id); |
144 | return cloned_id; |
145 | } |
146 | |
147 | // Debugging utility to convert the transformation into a string. |
148 | virtual std::string toString() const override = 0; |
149 | |
150 | protected: |
151 | ViewTransform(const int64_t& index) : Transform(index) {} |
152 | }; |
153 | |
154 | namespace { |
155 | //! The merge tranformation either combines two root iterDomains together OR |
156 | //! the last rfactor iterDomain with a root iterDomain. Unlike the general |
157 | //! TensorView merge there's no merging across axes not placed in consecutive |
158 | //! positions for View. |
159 | class MergeTransform final : public ViewTransform { |
160 | public: |
161 | MergeTransform(int64_t index) : ViewTransform(index) {} |
162 | |
163 | virtual std::string toString() const override { |
164 | std::stringstream ss; |
165 | ss << "Merge at index: " << index_; |
166 | return ss.str(); |
167 | } |
168 | |
169 | void createRfactorDomain( |
170 | std::vector<IterDomain*>& root_domain, |
171 | std::vector<IterDomain*>& current_transformed_domain) override { |
172 | TORCH_INTERNAL_ASSERT( |
173 | (index_ + 1) < (int64_t)current_transformed_domain.size(), |
174 | "Tried to apply: " , |
175 | toString(), |
176 | "\t To domain: \t" , |
177 | current_transformed_domain); |
178 | |
179 | // Assumed to never merge over non-contiguous dimensions. |
180 | IterDomain* outer_id = current_transformed_domain[index_]; |
181 | if (!outer_id->isRFactorProduct()) { |
182 | outer_id = replaceRootIdWithRFactor(root_domain, outer_id); |
183 | } |
184 | |
185 | IterDomain* inner_id = current_transformed_domain[index_ + 1]; |
186 | if (!inner_id->isRFactorProduct()) { |
187 | inner_id = replaceRootIdWithRFactor(root_domain, inner_id); |
188 | } |
189 | |
190 | TORCH_INTERNAL_ASSERT( |
191 | outer_id->start()->isZeroInt() && inner_id->start()->isZeroInt(), |
192 | "Didn't expect to apply view transformations on an iter domain" , |
193 | " starting at a non-zero position." ); |
194 | |
195 | auto merged_extent = mul(outer_id->extent(), inner_id->extent()); |
196 | |
197 | auto new_merged_id = |
198 | IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), merged_extent) |
199 | .is_rfactor_domain(true) |
200 | .build(); |
201 | |
202 | IrBuilder::create<Merge>(new_merged_id, outer_id, inner_id); |
203 | |
204 | current_transformed_domain.erase( |
205 | current_transformed_domain.begin() + index_); |
206 | current_transformed_domain.erase( |
207 | current_transformed_domain.begin() + index_); |
208 | current_transformed_domain.insert( |
209 | current_transformed_domain.begin() + index_, new_merged_id); |
210 | } |
211 | }; |
212 | |
213 | //! The split tranformation creates two new iterDomains via an outer split. |
214 | class SplitTransform final : public ViewTransform { |
215 | public: |
216 | SplitTransform(const int64_t index, int64_t split_factor) |
217 | : ViewTransform(index), split_factor_(split_factor) { |
218 | TORCH_INTERNAL_ASSERT( |
219 | split_factor > 0, |
220 | "Split factors must be greater than 0, but found " , |
221 | split_factor, |
222 | " during view transformation." ); |
223 | } |
224 | |
225 | virtual std::string toString() const override { |
226 | std::stringstream ss; |
227 | ss << "Split Index at: " << index_ << " by: " << split_factor_ << std::endl; |
228 | return ss.str(); |
229 | } |
230 | |
231 | void createRfactorDomain( |
232 | std::vector<IterDomain*>& root_domain, |
233 | std::vector<IterDomain*>& current_transformed_domain) override { |
234 | TORCH_INTERNAL_ASSERT( |
235 | index_ < (int64_t)current_transformed_domain.size(), |
236 | "Index: \t" , |
237 | index_, |
238 | "\t Domain Size:\t" , |
239 | current_transformed_domain.size()); |
240 | |
241 | auto factor = IrBuilder::create<Int>(split_factor_); |
242 | |
243 | IterDomain* id = current_transformed_domain[index_]; |
244 | if (!id->isRFactorProduct()) { |
245 | id = replaceRootIdWithRFactor(root_domain, id); |
246 | } |
247 | |
248 | TORCH_INTERNAL_ASSERT( |
249 | id->start()->isZeroInt(), |
250 | "Didn't expect to apply view transformations on an iter domain" , |
251 | " starting at a non-zero position." ); |
252 | |
253 | Val* remainder = ceilDiv(id->extent(), factor); |
254 | |
255 | // outer loop IterDomain |
256 | IterDomain* factor_id = |
257 | IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), factor) |
258 | .parallel_type(id->getParallelType()) |
259 | .iter_type(id->getIterType()) |
260 | .is_rfactor_domain(true) |
261 | .build(); |
262 | |
263 | // inner loop IterDomain |
264 | IterDomain* remainder_id = |
265 | IterDomainBuilder( |
266 | FusionGuard::getCurFusion()->zeroVal(), remainder->as<Int>()) |
267 | .is_rfactor_domain(true) |
268 | .build(); |
269 | |
270 | IrBuilder::create<Split>(factor_id, remainder_id, id, factor, false); |
271 | |
272 | current_transformed_domain.erase( |
273 | current_transformed_domain.begin() + index_); |
274 | current_transformed_domain.insert( |
275 | current_transformed_domain.begin() + index_, remainder_id); |
276 | current_transformed_domain.insert( |
277 | current_transformed_domain.begin() + index_, factor_id); |
278 | } |
279 | |
280 | int64_t split_factor() const { |
281 | return split_factor_; |
282 | } |
283 | |
284 | private: |
285 | const int64_t split_factor_ = 0; |
286 | }; |
287 | |
288 | //! For any singleton dimensions in the new view, we create an implicit |
289 | //! broadcast dimension. We apply these transforms after the trivial reduction |
290 | //! and view transformation steps. |
291 | class BroadcastTransform final : public Transform { |
292 | public: |
293 | BroadcastTransform(int64_t index) : Transform(index) {} |
294 | |
295 | virtual std::string toString() const override { |
296 | std::stringstream ss; |
297 | ss << "Broadcast at: " << index_ << std::endl; |
298 | return ss.str(); |
299 | } |
300 | }; |
301 | |
302 | //! For any implicit broadcast dimensions in the original view, we remove |
303 | //! them using a trivial reduction. |
304 | class TrivialReductionTransform final : public Transform { |
305 | public: |
306 | TrivialReductionTransform(int64_t index) : Transform(index) {} |
307 | |
308 | virtual std::string toString() const override { |
309 | std::stringstream ss; |
310 | ss << "Trivial reduction at: " << index_ << std::endl; |
311 | return ss.str(); |
312 | } |
313 | }; |
314 | |
315 | //! The primary class that generates the transformations to go from |
316 | //! the original view to the new view. |
317 | class AnalyzeViewTransformation { |
318 | public: |
319 | AnalyzeViewTransformation( |
320 | const std::vector<int64_t>& original_view, |
321 | const std::vector<int64_t>& new_view, |
322 | std::vector<IterDomain*> root_domain = {}) |
323 | : root_domain_not_provided_(root_domain.empty()), |
324 | root_domain_(root_domain), |
325 | root_is_transformed_(original_view.size(), false), |
326 | original_view_(original_view), |
327 | new_view_(new_view) { |
328 | TORCH_INTERNAL_ASSERT( |
329 | root_domain.empty() || original_view.size() == root_domain.size(), |
330 | "Incoming domain must match the original view sizes for view." ); |
331 | // Check that the product of original and new view std::vector<int64_t> are |
332 | // equal. |
333 | const int64_t kOriginalNumElements = std::accumulate( |
334 | original_view_.begin(), original_view_.end(), 1, std::multiplies<>()); |
335 | const int64_t kNewNumElements = std::accumulate( |
336 | new_view_.begin(), new_view.end(), 1, std::multiplies<>()); |
337 | TORCH_INTERNAL_ASSERT( |
338 | kOriginalNumElements == kNewNumElements, |
339 | "Total element counts across view operation must match." ); |
340 | } |
341 | |
342 | AnalyzeViewConstraint constraint() { |
343 | findTransformation(); |
344 | |
345 | AnalyzeViewConstraint constraint; |
346 | constraint.original_constraint = |
347 | std::vector<int64_t>(original_view_.begin(), original_view_.end()); |
348 | for (auto i : c10::irange(constraint.original_constraint.size())) { |
349 | if (constraint.original_constraint[i] != 1) { |
350 | constraint.original_constraint[i] = 0; |
351 | } |
352 | } |
353 | |
354 | constraint.new_constraint = |
355 | std::vector<int64_t>(new_view_.begin(), new_view_.end()); |
356 | for (auto i : c10::irange(constraint.new_constraint.size())) { |
357 | if (constraint.new_constraint[i] != 1) { |
358 | constraint.new_constraint[i] = 0; |
359 | } |
360 | } |
361 | |
362 | for (auto trivial_reduce : trivial_reduction_transforms_) { |
363 | constraint.trivial_reduction_string.push_back(trivial_reduce->index()); |
364 | } |
365 | |
366 | for (auto broadcast : broadcast_transforms_) { |
367 | constraint.broadcast_string.push_back(broadcast->index()); |
368 | } |
369 | |
370 | // Dilimeter for split/merge transforms is -2 |
371 | for (auto split_merge : view_transforms_) { |
372 | if (split_merge->isA<SplitTransform>()) { |
373 | constraint.split_merge_string.push_back(split_merge->index()); |
374 | constraint.split_merge_string.push_back( |
375 | split_merge->as<SplitTransform>()->split_factor()); |
376 | constraint.split_merge_string.push_back(-2); |
377 | } else { |
378 | TORCH_INTERNAL_ASSERT( |
379 | split_merge->isA<MergeTransform>(), |
380 | "Unrecognized transformation found." ); |
381 | constraint.split_merge_string.push_back(split_merge->index()); |
382 | constraint.split_merge_string.push_back(-2); |
383 | } |
384 | } |
385 | |
386 | return constraint; |
387 | } |
388 | |
389 | // Fill out all the information needed in AnalyzeViewResult, this should |
390 | // contain all the information of what's required to perform the view |
391 | // operation. |
392 | AnalyzeViewResult run() { |
393 | // Find all the transformations to go from the original tensor domain to the |
394 | // final output of the view operations. |
395 | findTransformation(); |
396 | |
397 | auto trivial_reduction_axes = generateTrivialReductionAxes(); |
398 | auto broadcast_axes = generateBroadcastAxes(); |
399 | |
400 | // Move data to AnalyzeViewResult and return it. |
401 | return {broadcast_axes, trivial_reduction_axes, view_transforms_}; |
402 | } |
403 | |
404 | private: |
405 | // Returns the bool flags that should be used to broadcast the output view |
406 | // tensor |
407 | std::vector<bool> generateBroadcastAxes() { |
408 | std::vector<bool> broadcast_axes(new_view_.size(), false); |
409 | for (auto& bcast : broadcast_transforms_) { |
410 | broadcast_axes.at(bcast->index()) = true; |
411 | } |
412 | return broadcast_axes; |
413 | } |
414 | |
415 | // Returns the positions for the trivial reductions to be performed before the |
416 | // view operation |
417 | std::vector<int> generateTrivialReductionAxes() { |
418 | std::vector<int> reduction_axes; |
419 | for (auto& tred : trivial_reduction_transforms_) { |
420 | reduction_axes.push_back(tred->index()); |
421 | } |
422 | return reduction_axes; |
423 | } |
424 | |
425 | std::string toString() { |
426 | std::stringstream output; |
427 | output << "===============================" << std::endl; |
428 | output << "old:" ; |
429 | for (auto s : original_view_) { |
430 | output << " " << s; |
431 | } |
432 | output << std::endl; |
433 | |
434 | output << "===============================" << std::endl; |
435 | output << "new:" ; |
436 | for (auto s : new_view_) { |
437 | output << " " << s; |
438 | } |
439 | output << std::endl; |
440 | |
441 | output << "===============================" << std::endl; |
442 | for (auto& trivial_reduction : trivial_reduction_transforms_) { |
443 | output << trivial_reduction->toString() << "\n" ; |
444 | } |
445 | for (auto& split_or_merge : view_transforms_) { |
446 | output << split_or_merge->toString() << "\n" ; |
447 | } |
448 | for (auto& broadcast : broadcast_transforms_) { |
449 | output << broadcast->toString() << "\n" ; |
450 | } |
451 | output << "===============================" << std::endl; |
452 | return output.str(); |
453 | } |
454 | |
455 | // Validation check after transformations are all found |
456 | |
457 | bool isImplicitBroadcast(int64_t original_view_index) const { |
458 | if (root_domain_not_provided_) { |
459 | return original_view_[original_view_index] == 1; |
460 | } else { |
461 | TORCH_INTERNAL_ASSERT(original_view_index < (int64_t)root_domain_.size()); |
462 | return root_domain_[original_view_index]->isImplicitBroadcast() && |
463 | !root_domain_[original_view_index]->hasExpandedExtent(); |
464 | } |
465 | } |
466 | |
467 | //! Find the broadcast, merge and split operations necessary |
468 | //! to transform the original view into the new view |
469 | void findTransformation() { |
470 | // There are three particularly important state indices we're working with. |
471 | // There is: |
472 | // 1) original_view_index which is indexing into the original tensor |
473 | // domain after all reductions are removed. This lines up with the last |
474 | // domain in original view that we added to current_size. |
475 | // 2) transform_view_index which is the index of the transformations as |
476 | // we're virtually "developing" the output tensor domain (split/merge |
477 | // transformations post trivial reductions). |
478 | // 3) The new_view_index which is directly associated with the new_view |
479 | // and the dimension in new_view we're currently trying to create. |
480 | |
481 | int64_t original_view_index = 0; |
482 | int64_t transform_view_index = 0; |
483 | int64_t new_view_index = 0; |
484 | int64_t current_size = original_view_[0]; |
485 | |
486 | // Safety counters to make sure we don't end up in an infinite loop. |
487 | int64_t prev_original_view_index = std::numeric_limits<int64_t>::max(); |
488 | int64_t prev_new_view_index = std::numeric_limits<int64_t>::max(); |
489 | |
490 | TORCH_INTERNAL_ASSERT( |
491 | view_transforms_.empty(), |
492 | "Already ran find transformation pass for View op, cannot run a second time." ); |
493 | |
494 | // Iterate until original view is completely consumed and new view is |
495 | // completely generated. |
496 | while (original_view_index < (int64_t)original_view_.size() || |
497 | new_view_index < (int64_t)new_view_.size()) { |
498 | TORCH_INTERNAL_ASSERT( |
499 | !(prev_new_view_index == new_view_index && |
500 | prev_original_view_index == original_view_index), |
501 | "Infinite loop detected in AnalyzeViewTransformation::findTransformation(). Bailing." ); |
502 | |
503 | prev_new_view_index = new_view_index; |
504 | prev_original_view_index = original_view_index; |
505 | |
506 | if (new_view_index >= (int64_t)new_view_.size()) { |
507 | TORCH_INTERNAL_ASSERT( |
508 | current_size == 1, |
509 | "View is complete, but there's still some elements to distribute." ); |
510 | } |
511 | |
512 | if ((new_view_index + 1 >= (int64_t)new_view_.size() || |
513 | (new_view_[new_view_index + 1] != 1)) && |
514 | original_view_index + 1 < (int64_t)original_view_.size() && |
515 | original_view_[original_view_index + 1] == 1 && |
516 | !isImplicitBroadcast(original_view_index + 1)) { |
517 | // Next index in original_view is runtime size 1 and next new view is |
518 | // not, merge the size 1 into the current view before moving on. Even if |
519 | // the current size and new view size match we could have a trailing |
520 | // size 1 dimension on the input that needs to be merged in. |
521 | view_transforms_.push_back( |
522 | std::make_shared<MergeTransform>(transform_view_index)); |
523 | ++original_view_index; |
524 | continue; |
525 | } |
526 | |
527 | if (new_view_index < (int64_t)new_view_.size() && |
528 | // Still new dimensions to resolve and current size does resolve it. |
529 | current_size == new_view_[new_view_index]) { |
530 | // Keep this dimension, it's good to go, we hit a boundary where there's |
531 | // a multiple of original dims, that matches a multiple of view dims. |
532 | // Increment state and keep going. |
533 | |
534 | ++transform_view_index; |
535 | ++new_view_index; |
536 | ++original_view_index; |
537 | |
538 | // Update current_size with the next size in original view |
539 | if (original_view_index < (int64_t)original_view_.size()) { |
540 | current_size = original_view_[original_view_index]; |
541 | } else { |
542 | current_size = 0; |
543 | } |
544 | continue; |
545 | } |
546 | |
547 | // Compile time broadcast in new view, but not a matching one in original |
548 | // view. Insert broadcast and increment new_view. Size 1 dimensions in |
549 | // new_view that don't match up with runtime size 1's in original view are |
550 | // assumed to be broadcast (not a split from a runtime domain). |
551 | if (new_view_index < (int64_t)new_view_.size() && |
552 | new_view_[new_view_index] == 1) { |
553 | broadcast_transforms_.push_back( |
554 | std::make_shared<BroadcastTransform>(new_view_index)); |
555 | ++new_view_index; |
556 | continue; |
557 | } |
558 | |
559 | // If we run out of original_view dimensions we could still have broadcast |
560 | // dimensions for new_view, but that should be hit before this point. |
561 | TORCH_INTERNAL_ASSERT( |
562 | current_size != 0, |
563 | "View analysis failed, should never process an empty size unless we " , |
564 | "simply need to add broadcasts to the post-view domain." ); |
565 | |
566 | if (current_size == 1 && isImplicitBroadcast(original_view_index)) { |
567 | // Original view has a compile time size 1 dimension, and it's not found |
568 | // in the new_view_ (otherwise would have been caught in a branch |
569 | // above). Do a trivial reduction. |
570 | trivial_reduction_transforms_.push_back( |
571 | std::make_shared<TrivialReductionTransform>(original_view_index)); |
572 | ++original_view_index; |
573 | |
574 | // Update original position and current size. |
575 | if (original_view_index < (int64_t)original_view_.size()) { |
576 | current_size = original_view_[original_view_index]; |
577 | } else { |
578 | current_size = 0; |
579 | } |
580 | |
581 | continue; |
582 | } |
583 | |
584 | if (original_view_index + 1 < (int64_t)original_view_.size() && |
585 | isImplicitBroadcast(original_view_index + 1)) { |
586 | // Original view has a compile time size 1 dimension, and it's |
587 | // interfering with necessary transformations. Do a trivial reduction. |
588 | ++original_view_index; |
589 | trivial_reduction_transforms_.push_back( |
590 | std::make_shared<TrivialReductionTransform>(original_view_index)); |
591 | |
592 | continue; |
593 | } |
594 | |
595 | // We're only left with performing transformations to match a new_view |
596 | // dimension, there must be an activew new_view. |
597 | TORCH_INTERNAL_ASSERT( |
598 | new_view_index < (int64_t)new_view_.size(), |
599 | "Expecting to still have new dimensions to work on in view, but none left." ); |
600 | |
601 | if (new_view_index < (int64_t)new_view_.size() && |
602 | current_size % new_view_[new_view_index] == 0) { |
603 | // Insert split to generate the next new_view domain. |
604 | view_transforms_.push_back(std::make_shared<SplitTransform>( |
605 | transform_view_index, new_view_[new_view_index])); |
606 | current_size /= new_view_[new_view_index]; |
607 | TORCH_INTERNAL_ASSERT(current_size > 1, "This should be unreachable." ); |
608 | // Update transform and new since a split doesn't increment from the |
609 | // original domain we're working on. |
610 | ++transform_view_index; |
611 | ++new_view_index; |
612 | continue; |
613 | } |
614 | |
615 | // Need more of the original_view dimension to resolve the new_view |
616 | // dimension, merge the next dimension in. |
617 | TORCH_INTERNAL_ASSERT( |
618 | original_view_index + 1 < (int64_t)original_view_.size(), |
619 | "Expecting to still have original dimensions to work on in view, but none left." ); |
620 | |
621 | view_transforms_.push_back( |
622 | std::make_shared<MergeTransform>(transform_view_index)); |
623 | current_size *= original_view_[++original_view_index]; |
624 | } |
625 | } |
626 | |
627 | private: |
628 | std::vector<std::shared_ptr<ViewTransform>> view_transforms_; |
629 | std::vector<std::shared_ptr<BroadcastTransform>> broadcast_transforms_; |
630 | std::vector<std::shared_ptr<TrivialReductionTransform>> |
631 | trivial_reduction_transforms_; |
632 | |
633 | // If root domain isn't provided always assume size-1 dimensions are |
634 | // compile-time dimensions. TODO: Remove runtime size-1 dimension support. |
635 | // This should be cached higher in the stack. |
636 | const bool root_domain_not_provided_ = true; |
637 | |
638 | const std::vector<IterDomain*> root_domain_; |
639 | // Track if the root ID was transformed or kept () |
640 | std::vector<bool> root_is_transformed_; |
641 | const std::vector<int64_t>& original_view_; |
642 | const std::vector<int64_t>& new_view_; |
643 | }; |
644 | |
645 | //! Create new TensorDomain with a new root domain and modified rfactor domains |
646 | //! using the specified view transformations. Original domain should already be |
647 | //! without reduction axes. |
648 | TensorDomain* createViewDomain( |
649 | TensorDomain* original_domain, |
650 | const AnalyzeViewResult& view_analysis) { |
651 | FUSER_PERF_SCOPE("createViewDomain" ); |
652 | TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty()); |
653 | |
654 | std::vector<IterDomain*> new_root_domain; |
655 | auto orig_root_domain = original_domain->getMaybeRFactorDomain(); |
656 | |
657 | // Apply trivial reductions. |
658 | for (auto id_i : c10::irange(orig_root_domain.size())) { |
659 | auto id = orig_root_domain[id_i]; |
660 | if (id->isReduction()) { |
661 | continue; |
662 | } |
663 | if (std::find( |
664 | view_analysis.trivial_reduction_axes.begin(), |
665 | view_analysis.trivial_reduction_axes.end(), |
666 | (int)id_i) != view_analysis.trivial_reduction_axes.end()) { |
667 | continue; |
668 | } |
669 | |
670 | new_root_domain.push_back(id->cloneWithoutRFactor()); |
671 | } |
672 | |
673 | std::vector<IterDomain*> new_rfactor_domain( |
674 | new_root_domain.begin(), new_root_domain.end()); |
675 | |
676 | // Apply rfactor transformations. |
677 | for (auto& t : view_analysis.transforms) { |
678 | t->createRfactorDomain(new_root_domain, new_rfactor_domain); |
679 | } |
680 | |
681 | return IrBuilder::create<TensorDomain>( |
682 | new_root_domain, |
683 | new_rfactor_domain, |
684 | new_rfactor_domain, |
685 | std::vector<bool>(new_rfactor_domain.size(), true)); |
686 | } |
687 | |
688 | } // namespace |
689 | |
690 | std::pair<std::vector<int64_t>, std::vector<int64_t>> inferViewShapes( |
691 | const std::vector<int64_t>& original_sizes, |
692 | const std::vector<int64_t>& new_sizes) { |
693 | bool valid_original_sizes = std::all_of( |
694 | original_sizes.begin(), original_sizes.end(), [](int64_t dim) { |
695 | return dim > 0; |
696 | }); |
697 | TORCH_INTERNAL_ASSERT(valid_original_sizes); |
698 | |
699 | std::vector<int64_t> original_view( |
700 | original_sizes.begin(), original_sizes.end()); |
701 | std::vector<int64_t> new_view(new_sizes.size()); |
702 | |
703 | // TODO: refactor |
704 | int64_t dynamic_index = -1; |
705 | int64_t new_size_num_elements = 1; |
706 | for (int64_t idx = 0; idx < (int64_t)new_sizes.size(); ++idx) { |
707 | if (new_sizes[idx] == -1) { |
708 | TORCH_INTERNAL_ASSERT( |
709 | dynamic_index == -1, "Only one dimension can by inferred." ) |
710 | dynamic_index = idx; |
711 | } else { |
712 | TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0); |
713 | new_size_num_elements *= new_sizes[idx]; |
714 | new_view[idx] = new_sizes[idx]; |
715 | } |
716 | } |
717 | |
718 | const int64_t kNumElements = std::accumulate( |
719 | original_view.begin(), original_view.end(), 1, std::multiplies<>()); |
720 | if (dynamic_index != -1) { |
721 | new_view[dynamic_index] = kNumElements / new_size_num_elements; |
722 | } |
723 | |
724 | return {original_view, new_view}; |
725 | } |
726 | |
727 | //! Generates the transformations necessary to convert |
728 | //! from the original view into the new view. |
729 | AnalyzeViewResult analyzeView( |
730 | const TensorView* original_view_tv, |
731 | const std::vector<int64_t>& original_sizes, |
732 | const std::vector<int64_t>& new_sizes) { |
733 | FUSER_PERF_SCOPE("analyzeView" ); |
734 | TORCH_INTERNAL_ASSERT( |
735 | original_sizes.size() > 0, |
736 | "Empty original size not supported for view operation." ); |
737 | |
738 | TORCH_INTERNAL_ASSERT( |
739 | TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain()) |
740 | .size() == original_sizes.size()); |
741 | |
742 | // Fill -1 dimension in new_std::vector<int64_t> with size infered from all |
743 | // other values |
744 | auto sizes = inferViewShapes(original_sizes, new_sizes); |
745 | |
746 | // Analysize the transformations required to go from original_sizes to |
747 | // new_sizes |
748 | AnalyzeViewTransformation analyzer( |
749 | sizes.first /* original_view */, |
750 | sizes.second /* new_view */, |
751 | TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain())); |
752 | return analyzer.run(); |
753 | } |
754 | |
755 | AnalyzeViewConstraint analyzeViewConstraint( |
756 | const std::vector<int64_t>& original_sizes, |
757 | const std::vector<int64_t>& new_sizes) { |
758 | FUSER_PERF_SCOPE("analyzeViewConstraint" ); |
759 | auto sizes = inferViewShapes(original_sizes, new_sizes); |
760 | AnalyzeViewTransformation analyzer( |
761 | sizes.first /* original_view */, sizes.second /* new_view */); |
762 | return analyzer.constraint(); |
763 | } |
764 | |
765 | //! Create new TensorDomain with a modified rfactor domain using the specified |
766 | //! view transformations |
767 | TensorDomain* transformView( |
768 | TensorDomain* original_domain, |
769 | const AnalyzeViewResult& view_analysis) { |
770 | FUSER_PERF_SCOPE("transformView" ); |
771 | return createViewDomain(original_domain, view_analysis); |
772 | } |
773 | |
774 | } // namespace cuda |
775 | } // namespace fuser |
776 | } // namespace jit |
777 | } // namespace torch |
778 | |