1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_all_nodes.h> |
6 | |
7 | #include <memory> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | class ViewTransform; |
16 | |
17 | //! |
18 | //! The goal of analyzeView is to find the minimum number of transformations |
19 | //! to convert from the original size to the new size. A naive view algorithm |
//! would merge all axes together and then split according to the new sizes.
21 | //! |
22 | //! This implementation will keep the original domains, if the domains are the |
23 | //! same size in the original and new shapes. If an original domain is not |
24 | //! evenly divisible by the new domain, we will merge the minimum number of |
25 | //! adjacent original domains. |
26 | //! |
27 | //! The view transformations are processed in the following order: |
28 | //! 1. Trivial Reductions - Removes size-1 broadcast dimensions |
29 | //! 2. Keep, Merge, Split - Used to create new rfactor domain |
30 | //! 3. Broadcast - Inserts size-1 dimensions |
31 | //! |
//! Broadcast is handled last because size-1 dimensions can be inserted
//! anywhere in the new shape.
33 | //! in the new shape. |
34 | //! |
35 | |
//! Result of analyzeView: the minimal set of transformations needed to
//! convert a TensorView from its original shape to the requested new shape.
struct AnalyzeViewResult {
  // For each axis of the new shape, true if it is an inserted size-1
  // broadcast dimension (broadcasts are applied last; see header comment).
  std::vector<bool> broadcast_axes;
  // Positions of size-1 dimensions removed as trivial reductions before the
  // keep/merge/split transformations are applied.
  std::vector<int> trivial_reduction_axes;
  // Ordered keep/merge/split transformations used to build the new rfactor
  // domain. ViewTransform is forward-declared; see its definition for details.
  std::vector<std::shared_ptr<ViewTransform>> transforms;
};
41 | |
42 | struct TORCH_CUDA_CU_API AnalyzeViewConstraint { |
43 | // 1 if size 1 dimension, otherwise 0; |
44 | std::vector<int64_t> original_constraint; |
45 | std::vector<int64_t> new_constraint; |
46 | // Just the positions of true in AnalyzeViewResult::trivial_reduction_axes |
47 | std::vector<int64_t> trivial_reduction_string; |
48 | // Just the positions of true in AnalyzeViewResult:broadcast_axes |
49 | std::vector<int64_t> broadcast_string; |
50 | // A stringified version of the transformations: |
51 | std::vector<int64_t> split_merge_string; |
52 | |
53 | std::vector<int64_t> conglomerateString() const { |
54 | // Don't think this is necessary but just being safe. Using |
55 | // -3 as a dilimeter between value groups. |
56 | std::vector<int64_t> conglomerate = { |
57 | (int64_t)original_constraint.size(), |
58 | (int64_t)new_constraint.size(), |
59 | -3}; |
60 | auto add_vec = [&conglomerate](const std::vector<int64_t>& vec) { |
61 | for (auto element : vec) { |
62 | conglomerate.push_back(element); |
63 | } |
64 | // TODO: Why doesn't this work? |
65 | // conglomerate.insert(conglomerate.back(), vec.begin(), vec.end()); |
66 | conglomerate.push_back(-3); |
67 | }; |
68 | add_vec(original_constraint); |
69 | add_vec(new_constraint); |
70 | add_vec(trivial_reduction_string); |
71 | add_vec(broadcast_string); |
72 | add_vec(split_merge_string); |
73 | return conglomerate; |
74 | } |
75 | |
76 | bool operator==(const AnalyzeViewConstraint& other) const { |
77 | return other.conglomerateString() == this->conglomerateString(); |
78 | } |
79 | |
80 | // Naive hashing function, likely has a lot of collisions, but may not matter |
81 | // too much if we don't expact many types of views. |
82 | size_t hash() { |
83 | size_t hash_value = 0; |
84 | for (auto val : conglomerateString()) { |
85 | if (val == std::numeric_limits<int64_t>::max()) { |
86 | continue; |
87 | } |
88 | hash_value += val; |
89 | } |
90 | return hash_value; |
91 | } |
92 | }; |
93 | |
//! Infer -1 value in new view std::vector<int64_t> based on original view
//! std::vector<int64_t>. This shouldn't generally be used directly but is
//! useful for testing.
//! Returns the (original_sizes, new_sizes) pair with any -1 entry resolved
//! to a concrete extent. Declaration only; defined in the corresponding .cpp.
TORCH_CUDA_CU_API std::pair<std::vector<int64_t>, std::vector<int64_t>>
inferViewShapes(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes);
101 | |
// Find the transformations necessary to convert TensorView
// from original size to new size.
// tv is inspected but not modified here (taken by const pointer); the
// returned AnalyzeViewResult describes the required transformations.
AnalyzeViewResult analyzeView(
    const TensorView* tv,
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes);
108 | |
// Find the constraints derived from the view transformations.
// Unlike analyzeView, this needs only the shapes, not a TensorView, so it
// can be used as a cache key before any IR is constructed.
TORCH_CUDA_CU_API AnalyzeViewConstraint analyzeViewConstraint(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes);
113 | |
// Generate a new TensorDomain from the given view transformations.
// The original root domain is kept in the new TensorDomain,
// but a new rfactor domain is created from the view transformations.
// NOTE(review): original_domain is taken by non-const pointer — presumably
// it is only read to seed the new domain; confirm against the definition.
TensorDomain* transformView(
    TensorDomain* original_domain,
    const AnalyzeViewResult& view_analysis);
120 | |
121 | } // namespace cuda |
122 | } // namespace fuser |
123 | } // namespace jit |
124 | } // namespace torch |
125 | |