#pragma once

#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <ir_internal_nodes.h>
#include <maxinfo_propagator.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

/*
 * compute_at is a relative property between two TensorViews which marks at what
 * iteration domain we're going to generate a tensor to be consumed by another.
 * For example, if we have T2[I, J, K] = T1[I, J, K] * 2.0 and we call
 * T2.split(axis = 0, factor = ...), we get T2[Io, Ii, J, K] = T1[I, J, K] * 2.0,
 * where Io is the outer axis from the split and Ii is the inner axis from the
 * split. If we then call T1.compute_at(T2, axis=1), we would expect to have:
 * T2[Io, Ii, J, K] = T1[Io, Ii, J, K] * 2.0
 * which would produce the following loop nest structure:
 *
 * for(io : Io)
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //produce T1:
 *     T1[io, ii, j, k] = ...
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //consume T1, produce T2
 *     T2[io, ii, j, k] = T1[io, ii, j, k] * 2.0
 *
 * This file provides the replay function that allows us to construct T1's
 * domain from T2 at a desired level (compute_at_axis) without modifying any
 * unnecessary parts of the domain.
 *
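 * In terms of the C++ scheduling API, the example above corresponds roughly
 * to the following sketch (hedged; tv1/tv2 are hypothetical TensorView*
 * handles, and the exact split/computeAt signatures live in the TensorView
 * header):
 *
 *   tv2->split(0, factor);   // T2[I, J, K] -> T2[Io, Ii, J, K]
 *   tv1->computeAt(tv2, 1);  // generate T1 at T2's axis 1, as shown above
 *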
 * EXAMPLES:
 *
 * ANOTHER ITER EXAMPLE:
 * T2[I, J, K] = T1[I, J, K] * 2.0
 * T2.split(axis = 0, factor = ...)
 * T2[Io, Ii, J, K] = T1[I, J, K] * 2.0
 * T2.split(axis = 2, factor = ...)
 * T2[Io, Ii, Jo, Ji, K] = T1[I, J, K] * 2.0
 * T1.compute_at(T2, axis=1)
 * T2[Io, Ii, Jo, Ji, K] = T1[Io, Ii, J, K] * 2.0
 *
 * Note: compute_at axis:
 * T2[ 0 Io, 1 Ii, 2 Jo, 3 Ji, 4 K 5 ] // 5 is inline, 0 is at "root" which
 * means completely separate loop nests.
 *
 * for(io : Io)
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //produce T1, this is the view that replay generates:
 *     T1[io, ii, j, k] = ...
 *  for(ii : Ii)
 *   for(jo : Jo)
 *    for(ji : Ji)
 *     for(k : K)
 *      //consume T1, produce T2
 *      T2[io, ii, jo, ji, k] = T1[io, ii, jo, ji, k] * 2.0
 *      //consumer view on T1 will be produced at a later stage.
 *
 *
 * SIMPLE REDUCTION EXAMPLE:
 * T1[I, J, K] = ...
 * T2[I, R, K] = T1[I, J, K] // .sum(axis = 1), we reduce on R/J to produce T2[I, K]
 * T2.split(axis = 0, factor = ...)
 * T2[Io, Ii, R, K] = T1[I, J, K]
 * T1.compute_at(T2, axis=3)
 * T2[Io, Ii, R, K] = T1[Io, Ii, J, K]
 *
 * for(io : Io)
 *  for(ii : Ii)
 *   for(k : K)
 *    T2[io, ii, k] = init
 *   for(r : R)
 *    for(k : K)
 *     //produce T1:
 *     T1[io, ii, r, k] = ...
 *     //consume T1 produce T2:
 *     T2[io, ii, k] += T1[io, ii, r, k]
 *
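 * A hedged sketch of the same reduction example through the scheduling API
 * (assuming the usual sum() arithmetic op creates the reduction; tv names are
 * hypothetical):
 *
 *   auto tv2 = sum(tv1, {1});  // T2[I, R, K], R reduces over J
 *   tv2->split(0, factor);     // T2[Io, Ii, R, K]
 *   tv1->computeAt(tv2, 3);
 *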
 *
 * REDUCTION EXAMPLE RESULTING IN AN ERROR:
 * T1[I, R, K] = ... // R is a reduction domain, we reduce on R to produce T1[I, K]
 * T2[I, K] = T1[I, K]
 *
 * for(i : I)
 *  for(k : K)
 *   T1[i, k] = init
 *  for(r : R)
 *   for(k : K)
 *    T1[i, k] += ...[i, r, k]
 * for(i : I)
 *  for(k : K)
 *   T2[i, k] = T1[i, k]
 *
 * T1.compute_at(T2, axis=2)
 * This should be an error, or a warning and changed to:
 * T1.compute_at(T2, axis=1)
 * The error is because the kernel would have to be:
 *
 * for(i : I)
 *  T1[i, k] = init
 *  for(r : R)
 *   for(k : K)
 *    T1[i, k] += ...[i, r, k]
 *  for(k : K)
 *   T2[i, k] = T1[i, k]
 *
 * Otherwise we would produce incorrect results.
 *
 */

class TensorDomain;
class TensorView;
class RootDomainMap;

class TORCH_CUDA_CU_API TransformReplay {
 public:
  // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
  static std::pair<TensorDomain*, unsigned int> replayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_compute_at_axis,
      bool replay_swizzle = false);
  static std::pair<TensorDomain*, unsigned int> replayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_compute_at_axis,
      const RootDomainMap& root_map,
      bool replay_swizzle = false);

  // Replay consumer as producer, returns {replayed_consumer_domain,
  // consumer_compute_at_axis}.
  static std::pair<TensorDomain*, unsigned int> replayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_compute_at_axis);
  static std::pair<TensorDomain*, unsigned int> replayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_compute_at_axis,
      const RootDomainMap& root_map);
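
  // A hedged usage sketch for the two replay directions above
  // (producer_tv/consumer_tv are hypothetical TensorView* handles; installing
  // the replayed domain via TensorView::setDomain is an assumption, see the
  // TensorView header for the actual interface):
  //
  //   // Make producer_tv's loop structure match consumer_tv up to axis 2.
  //   auto pasc = TransformReplay::replayPasC(producer_tv, consumer_tv, 2);
  //   producer_tv->setDomain(pasc.first); // pasc.second: matched position
  //
  //   // Or the reverse: replay consumer_tv to match producer_tv.
  //   auto casp = TransformReplay::replayCasP(consumer_tv, producer_tv, 2);
  //   consumer_tv->setDomain(casp.first);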

  // Self replay.
  static TensorDomain* fullSelfReplay(
      const TensorDomain* new_self_root,
      const TensorDomain* self);

  // Returns the leaf position in producer that matches with `consumer_pos` in
  // consumer. Returns -1 if matching is impossible. This function can be used
  // to test if replay is needed to get matching outer dims. This function
  // should be consistent with `replayPasC`: if you pass the tensors just
  // replayed by replayPasC as inputs, it should return exactly the same
  // position as `replayPasC`. However, this function is more tolerant than
  // fully matching `replayPasC`: if there are unmappable dimensions in the
  // consumer, these dimensions are simply ignored.
  static int getMatchedLeafPosWithoutReplayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_pos);

  // Returns the leaf position in consumer that matches with `producer_pos` in
  // producer. Behavior is similar to getMatchedLeafPosWithoutReplayPasC,
  // except that reductions in the producer are also ignored.
  static int getMatchedLeafPosWithoutReplayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_pos);
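
  // A hedged sketch of the check-before-replay pattern these helpers enable
  // (producer_tv/consumer_tv and consumer_pos are hypothetical):
  //
  //   int pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
  //       producer_tv, consumer_tv, consumer_pos);
  //   if (pos < 0) {
  //     // No matching producer position exists; an actual replayPasC call
  //     // would be needed to rebuild producer_tv's domain.
  //   }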

  // Tests whether two tensors have fully matching transformations.
  static bool fullSelfMatching(
      const TensorView* replay,
      const TensorView* target);
};

class TORCH_CUDA_CU_API TransformPropagator
    : public MaxRootDomainInfoSpanningTree::Propagator {
 protected:
  std::unordered_map<TensorView*, size_t> replayed_pos_;

 public:
  virtual void propagateC2P(TensorView* from, TensorView* to) override;
  virtual void propagateP2C(TensorView* from, TensorView* to) override;
  virtual void propagateSibling(TensorView* from, TensorView* to) override;
  TransformPropagator(TensorView* from, int64_t pos = -1);
};
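
// A typical usage sketch (hedged; reference_tv is a hypothetical TensorView*,
// and the traverse(Propagator*) entry point is assumed from
// maxinfo_propagator.h): propagate reference_tv's transformations to every
// other tensor reachable in the fusion.
//
//   TransformPropagator propagator(reference_tv);
//   MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
//   spanning_tree.traverse(&propagator);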

struct TORCH_CUDA_CU_API MostInlinedTransformPropagator
    : public MaxRootDomainInfoSpanningTree::Propagator {
  virtual void propagateC2P(TensorView* from, TensorView* to) override;
  virtual void propagateP2C(TensorView* from, TensorView* to) override;
  virtual void propagateSibling(TensorView* from, TensorView* to) override;
};
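
// Usage mirrors TransformPropagator above (a hedged sketch; traverse() is
// again assumed from maxinfo_propagator.h), except that no reference position
// is recorded:
//
//   MostInlinedTransformPropagator propagator;
//   MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);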

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch