1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <ir_internal_nodes.h> |
6 | #include <maxinfo_propagator.h> |
7 | |
8 | #include <algorithm> |
9 | #include <unordered_map> |
10 | #include <unordered_set> |
11 | #include <vector> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | /* |
19 | * compute_at is a relative property between two TensorViews which marks at what |
20 | * iteration domain we're going to generate a tensor to be consumed by another. |
21 | * For example if we have: T2[I, J, K] = T1[I, J, K] * 2.0 and then we call |
22 | * T2.split(axis = 0, factor = ...): T2[Io, Ii, J, K] = T1[I, J, K] * 2.0 where |
 * Io is the outer axis from the split, and Ii is the inner axis from the
 * split.
24 | * then we call T1.compute_at(T2, axis=1) we would expect to have: |
25 | * T2[Io, Ii, J, K] = T1[Io, Ii, J, K] * 2.0 |
26 | * which would produce the following loop nest structure: |
27 | * |
28 | * for(io : Io) |
29 | * for(ii : Ii) |
30 | * for(j : J) |
31 | * for(k : K) |
32 | * //produce T1: |
33 | * T1[io, ii, j, k] = ... |
34 | * for(ii : Ii) |
35 | * for(j : J) |
36 | * for(k : K) |
37 | * //consume T1, produce T2 |
38 | * T2[io, ii, j, k] = T1[io, ii, j, k] * 2.0 |
39 | * |
40 | * This file provides the replay function that allows us to construct T1's |
41 | * domain from T2 at a desired level (compute_at_axis) without modifying any |
42 | * unnecessary parts of the domain. |
43 | * |
44 | * EXAMPLES: |
45 | * |
46 | * ANOTHER ITER EXAMPLE: |
47 | * T2[I, J, K] = T1[I, J, K] * 2.0 |
48 | * T2.split(axis = 0, factor = ...) |
49 | * T2[Io, Ii, J, K] = T1[I, J, K] * 2.0 |
50 | * T2.split(axis = 2, factor = ...) |
51 | * T2[Io, Ii, Jo, Ji, K] = T1[I, J, K] * 2.0 |
52 | * T1.compute_at(T2, axis=1) |
53 | * T2[Io, Ii, Jo, Ji, K] = T1[Io, Ii, J, K] * 2.0 |
54 | * |
55 | * Note: compute_at axis: |
56 | * T2[ 0 Io, 1 Ii, 2 Jo, 3 Ji, 4 K 5 ] //5 is inline, 0 is at "root" which means |
57 | * completely separate loop nests. |
58 | * |
59 | * for(io : Io) |
60 | * for(ii : Ii) |
61 | * for(j : J) |
62 | * for(k : K) |
63 | * //produce T1, this is the view that replay generates: |
64 | * T1[io, ii, j, k] = ... |
65 | * for(ii : Ii) |
66 | * for(jo : Jo) |
67 | * for(ji : Ji) |
68 | * for(k : K) |
69 | * //consume T1, produce T2 |
70 | * T2[io, ii, jo, ji, k] = T1[io, ii, jo, ji, k] * 2.0 |
71 | * //consumer view on T1 will be produced at a later stage. |
72 | * |
73 | * |
74 | * SIMPLE REDUCTION EXAMPLE: |
75 | * T1[I, J, K] = ... |
 * T2[I, R, K] = T1[I, J, K] //.sum(axis = 1), we reduce on R/J to produce
 * T2[I, K]
 * T2.split(axis = 0, factor = ...)
 * T2[Io, Ii, R, K] = T1[I, J, K]
78 | * T1.compute_at(T2, axis=3) |
79 | * T2[Io, Ii, R, K] = T1[Io, Ii, J, K] |
80 | * |
81 | * for(io : Io) |
82 | * for(ii : Ii) |
83 | * for(k : K) |
84 | * T2[io, ii, k] = init |
85 | * for(r : R) |
86 | * for(k : K) |
87 | * //produce T1: |
88 | * T1[io, ii, r, k] = ... |
89 | * //consume T1 produce T2: |
90 | * T2[io, ii, k] += T1[io, ii, r, k] |
91 | * |
92 | * |
 * REDUCTION EXAMPLE RESULTING IN AN ERROR:
 * T1[I, R, K] = ... // R is a reduction domain, we reduce on R to produce
 * T1[I, K]
 * T2[I, K] = T1[I, K]
96 | * |
97 | * for(i : I) |
98 | * for(k : K) |
99 | * T1[i, k] = init |
100 | * for(r : R) |
101 | * for(k : K) |
102 | * T1[i, k] += ...[i, r, k] |
103 | * for(i : I) |
104 | * for(k : K) |
105 | * T2[i, k] = T1[i, k] |
106 | * |
107 | * T1.compute_at(T2, axis=2) |
108 | * This should be an error, or a warning and changed to: |
109 | * T1.compute_at(T2, axis=1) |
110 | * The error is because the kernel would have to be: |
111 | * |
112 | * for(i : I) |
113 | * T1[i, k] = init |
114 | * for(r : R) |
115 | * for(k : K) |
116 | * T1[i, k] += ...[i, r, k] |
117 | * for(k : K) |
118 | * T2[i, k] = T1[i, k] |
119 | * |
120 | * Otherwise we would produce incorrect results. |
121 | * |
122 | */ |
123 | |
// Forward declarations: only pointers/references to these types appear in
// this header, so the full definitions (IR / root-domain-map headers) are
// not needed here.
class TensorDomain;
class TensorView;
class RootDomainMap;
127 | |
// Static-only collection of replay routines that reconstruct one tensor's
// domain from another's at a given compute_at position (see the file comment
// above for the compute_at semantics and examples).
class TORCH_CUDA_CU_API TransformReplay {
 public:
  // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
  // NOTE(review): `replay_swizzle` presumably controls whether swizzle
  // transformations are also replayed onto the producer — confirm in the
  // implementation file.
  static std::pair<TensorDomain*, unsigned int> replayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_compute_at_axis,
      bool replay_swizzle = false);
  // Overload taking an explicit root-domain mapping instead of deriving one.
  static std::pair<TensorDomain*, unsigned int> replayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_compute_at_axis,
      const RootDomainMap& root_map,
      bool replay_swizzle = false);

  // Replay consumer as producer, returns {replayed_consumer_domain,
  // consumer_compute_at_axis}.
  static std::pair<TensorDomain*, unsigned int> replayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_compute_at_axis);
  // Overload taking an explicit root-domain mapping instead of deriving one.
  static std::pair<TensorDomain*, unsigned int> replayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_compute_at_axis,
      const RootDomainMap& root_map);

  // Self replay: produces a new TensorDomain rooted at `new_self_root` that
  // mirrors the transformations of `self`.
  static TensorDomain* fullSelfReplay(
      const TensorDomain* new_self_root,
      const TensorDomain* self);

  // Returns the leaf position in producer that matches with `consumer_pos` in
  // consumer. Returns -1 if matching is impossible. This function can be used
  // to test if replay is needed for getting matching outer dims. This function
  // should be consistent with `replayPasC`: if you pass the tensors just
  // replayed by replayPasC as inputs, you should return exactly the same
  // position as `replayPasC`. However, this function is more tolerant than
  // fully matching `replayPasC`: if in the consumer, there are unmappable
  // dimensions, these dimensions are just ignored.
  static int getMatchedLeafPosWithoutReplayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_pos);

  // Returns the leaf position in consumer that matches with `producer_pos` in
  // producer. Behavior similar to getMatchedLeafPosWithoutReplayPasC, except
  // that we are also ignoring reductions in the producer.
  static int getMatchedLeafPosWithoutReplayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_pos);

  // Tests if the two tensors have fully matching transformations.
  static bool fullSelfMatching(
      const TensorView* replay,
      const TensorView* target);
};
186 | |
187 | class TORCH_CUDA_CU_API TransformPropagator |
188 | : public MaxRootDomainInfoSpanningTree::Propagator { |
189 | protected: |
190 | std::unordered_map<TensorView*, size_t> replayed_pos_; |
191 | |
192 | public: |
193 | virtual void propagateC2P(TensorView* from, TensorView* to) override; |
194 | virtual void propagateP2C(TensorView* from, TensorView* to) override; |
195 | virtual void propagateSibling(TensorView* from, TensorView* to) override; |
196 | TransformPropagator(TensorView* from, int64_t pos = -1); |
197 | }; |
198 | |
199 | struct TORCH_CUDA_CU_API MostInlinedTransformPropagator |
200 | : public MaxRootDomainInfoSpanningTree::Propagator { |
201 | virtual void propagateC2P(TensorView* from, TensorView* to) override; |
202 | virtual void propagateP2C(TensorView* from, TensorView* to) override; |
203 | virtual void propagateSibling(TensorView* from, TensorView* to) override; |
204 | }; |
205 | |
206 | } // namespace cuda |
207 | } // namespace fuser |
208 | } // namespace jit |
209 | } // namespace torch |
210 | |