1#pragma once
2
#include <c10/macros/Export.h>

#include <ir_all_nodes.h>

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15class ViewTransform;
16
//!
//! The goal of analyzeView is to find the minimum number of transformations
//! to convert from the original size to the new size. A naive view algorithm
//! would merge all axes together and then split according to the new sizes.
//!
//! This implementation keeps an original domain whenever that domain has the
//! same size in both the original and new shapes. If an original domain is not
//! evenly divisible by the new domain, we merge the minimum number of adjacent
//! original domains.
//!
//! The view transformations are processed in the following order:
//! 1. Trivial Reductions - Removes size-1 broadcast dimensions
//! 2. Keep, Merge, Split - Used to create new rfactor domain
//! 3. Broadcast - Inserts size-1 dimensions
//!
//! Broadcast is handled last because size-1 dimensions can be inserted
//! anywhere in the new shape.
//!
35
36struct AnalyzeViewResult {
37 std::vector<bool> broadcast_axes;
38 std::vector<int> trivial_reduction_axes;
39 std::vector<std::shared_ptr<ViewTransform>> transforms;
40};
41
42struct TORCH_CUDA_CU_API AnalyzeViewConstraint {
43 // 1 if size 1 dimension, otherwise 0;
44 std::vector<int64_t> original_constraint;
45 std::vector<int64_t> new_constraint;
46 // Just the positions of true in AnalyzeViewResult::trivial_reduction_axes
47 std::vector<int64_t> trivial_reduction_string;
48 // Just the positions of true in AnalyzeViewResult:broadcast_axes
49 std::vector<int64_t> broadcast_string;
50 // A stringified version of the transformations:
51 std::vector<int64_t> split_merge_string;
52
53 std::vector<int64_t> conglomerateString() const {
54 // Don't think this is necessary but just being safe. Using
55 // -3 as a dilimeter between value groups.
56 std::vector<int64_t> conglomerate = {
57 (int64_t)original_constraint.size(),
58 (int64_t)new_constraint.size(),
59 -3};
60 auto add_vec = [&conglomerate](const std::vector<int64_t>& vec) {
61 for (auto element : vec) {
62 conglomerate.push_back(element);
63 }
64 // TODO: Why doesn't this work?
65 // conglomerate.insert(conglomerate.back(), vec.begin(), vec.end());
66 conglomerate.push_back(-3);
67 };
68 add_vec(original_constraint);
69 add_vec(new_constraint);
70 add_vec(trivial_reduction_string);
71 add_vec(broadcast_string);
72 add_vec(split_merge_string);
73 return conglomerate;
74 }
75
76 bool operator==(const AnalyzeViewConstraint& other) const {
77 return other.conglomerateString() == this->conglomerateString();
78 }
79
80 // Naive hashing function, likely has a lot of collisions, but may not matter
81 // too much if we don't expact many types of views.
82 size_t hash() {
83 size_t hash_value = 0;
84 for (auto val : conglomerateString()) {
85 if (val == std::numeric_limits<int64_t>::max()) {
86 continue;
87 }
88 hash_value += val;
89 }
90 return hash_value;
91 }
92};
93
94//! Infer -1 value in new view std::vector<int64_t> based on original view
95//! std::vector<int64_t>. This shouldn't generally be used directly but is
96//! useful for testing.
97TORCH_CUDA_CU_API std::pair<std::vector<int64_t>, std::vector<int64_t>>
98inferViewShapes(
99 const std::vector<int64_t>& original_sizes,
100 const std::vector<int64_t>& new_sizes);
101
102// Find the transformations necessary to convert TensorView
103// from original size to new size.
104AnalyzeViewResult analyzeView(
105 const TensorView* tv,
106 const std::vector<int64_t>& original_sizes,
107 const std::vector<int64_t>& new_sizes);
108
109// Find the constraints derived from the view transformations
110TORCH_CUDA_CU_API AnalyzeViewConstraint analyzeViewConstraint(
111 const std::vector<int64_t>& original_sizes,
112 const std::vector<int64_t>& new_sizes);
113
114// Generate a new TensorDomain from the given view transformations.
115// The original root domain is kept in the new TensorDomain,
116// but a new rfactor domain is created from the view transformations.
117TensorDomain* transformView(
118 TensorDomain* original_domain,
119 const AnalyzeViewResult& view_analysis);
120
121} // namespace cuda
122} // namespace fuser
123} // namespace jit
124} // namespace torch
125