1#pragma once
2
#include <c10/macros/Export.h>

#include <disjoint_set.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <iter_visitor.h>
#include <root_domain_map.h>

#include <algorithm>
#include <iterator>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
12
13namespace torch {
14namespace jit {
15namespace fuser {
16namespace cuda {
17
18namespace {
19
20// Enable pair<IterDomain*, size_t> in a set, size_t must be unique in set
21struct id_int_lt {
22 bool operator()(
23 const std::pair<IterDomain*, size_t>& first,
24 const std::pair<IterDomain*, size_t>& second) const {
25 return first.second < second.second;
26 }
27};
28
29} // namespace
30
31// Uses the history of _target_domain, and replays that history using the
32// provided map.
33//
34// target_domain contains the history we want replayed.
35//
36// id_map maps IterDomains in that history to the IterDomains we want it
37// replayed on.
38//
39// error_on_failure = true will cause the replay to error if we can't replay any
40// operation in target_domain's history due to missing IDs in the id_map.
41//
42// If error_on_failure = false, replay will replay everything it can, and ignore
43// operations it can't.
44class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor {
45 protected:
46 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
47 const std::vector<IterDomain*>& target_domain_;
48 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
49 std::unordered_map<IterDomain*, IterDomain*> id_map_;
50 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
51 std::unordered_map<IterDomain*, size_t> leaf_ids_;
52 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
53 std::vector<IterDomain*> leaf_vec_;
54 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
55 size_t counter = 0;
56 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
57 bool error_on_failure_ = true;
58 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
59 bool ran_replay = false; // Mark if replay has been run
60 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
61 bool replay_swizzle_ = false;
62 using IterVisitor::handle;
63
64 // Transform dispatch
65 void handle(Expr* e) override;
66
67 // We're going to replay this split operation on the corresponding ID
68 void handle(Split* s) override;
69
70 // We're going to replay this merge operation on the corresponding IDs
71 void handle(Merge* m) override;
72
73 // We're going to replay this swizzle operation on the corresponding IDs
74 // if replaying swizzle is enabled.
75 void handle(Swizzle2D* m) override;
76
77 public:
78 ReplayTransformations(
79 const std::vector<IterDomain*>& _target_domain,
80 std::unordered_map<IterDomain*, IterDomain*> _id_map,
81 bool _error_on_failure = true,
82
83 // Indicates if we want to replay swizzle ops on the replayed
84 // tensor.
85 // The swizzle op will be replayed if true,
86 // The swizzle inputs will be directly forwarded, and therefore skipping
87 // the swizzle op if false.
88 // Currently this options should always be off but
89 // later we may have cases in scheduling large fusions where
90 // this functionality could be useful.
91 bool replay_swizzle = false);
92
93 // Replays outputs that were generated from ids.first on ids.second
94 void runReplay();
95
96 // Returns map from provided target domain to their corresponding IDs
97 const std::unordered_map<IterDomain*, IterDomain*>& getReplay() {
98 if (!ran_replay)
99 runReplay();
100 return id_map_;
101 }
102
103 // Returns leaf_ids_ the size_t marks the order in which they were put into
104 // the map, this is part of the structure because it's used to generate the
105 // order from 'getLeafIDs'
106 const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() {
107 if (!ran_replay)
108 runReplay();
109 return leaf_ids_;
110 }
111
112 // Returns all terminating IDs that resulted from the replay. Leaf IDs are run
113 // to run deterministic, but otherwise in no specific order.
114 const std::vector<IterDomain*>& getLeafIDs() {
115 if (!ran_replay)
116 runReplay();
117 return leaf_vec_;
118 }
119};
120
121/*
122 * Motivation:
123 *
124 * Consider the following program:
125 *
126 * T1[I0, R1] = T0[I0, I1]
127 * T2[I0] = T1[I0, R1i]
128 *
129 * T1->split(1, factor)
130 * T1->rFactor(2)
131 *
132 * T4[I0, R1orf, I1irf] = T0[I0, I1]
133 * T1[I0, R1i] = T4[I0, R1orf, I1irf]
134 * T2[I0] = T1[I0, R1i]
135 *
136 * There's an issue when we call replayCasP on
137 * T4[I0, R1o, I1i] = T0[I0, I1]
138 *
139 * This would try to replay T4 as T0, and it could include the rfactor domains.
140 * For example we compute T0 inline with T4. The way computeAt is setup this
 * would call replayPasC(T0, T4, -1) then replayCasP(T4, T0, -1)
142 *
143 * We might assume that the only way we will hit this is if we call
144 * T4->computeAt(T0...) so it might be safe to assume that the right
145 * transformations would be replayed. However, we want to preserve the rfactor
146 * domain, so since it would replay T4 at root, it would produce iterdomains
 * that wouldn't correspond to those in rfactor. Also, I don't know if this
148 * assumption is correct.
149 *
150 * Therefore, we will assume it is not correct, and we will validate here that
151 * if we replay a domain that it would transform it in a way consistent with
152 * any defined RFactor domains, then we will update the replay map so that
153 * RFactor roots are mapped to intermediate IterDomains in the target and start
154 * replay from there.
155 *
156 *
157 * SHORT DESCRIPTION:
158 *
159 * This class will validate/do the above. It will also run through
160 * transformations in target according to replay_map. If equal transformations
161 * already exist in replay_domain history, we will not redo those
162 * transformations, but instead update replay_map to reflect forwarding the
163 * existing transformations. This later part is the "best effort" replay. Though
164 * we include rfactor replay and validation here.
165 *
 * Given an Expr in target_domain, check if its inputs are in replay_map. If
 * so, check if the mapped domains in replay_map are recorded to be transformed
 * by an equivalent operation in replay_domain's history. If so, "forward" the
 * operation: update replay_map from the outputs of the target_domain expr to
 * the corresponding outputs of the equivalent expr in replay_domain's history.
171 *
172 * replay_map maps root IDs in the history of target_domain to root IDs in the
173 * history replay_domain
174 */
175
176class TORCH_CUDA_CU_API BestEffortReplay {
177 private:
178 std::unordered_map<IterDomain*, IterDomain*> target2replay_id_map_;
179 std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map_;
180 std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map_;
181 std::unordered_map<IterDomain*, size_t> leaf_ids_;
182 std::vector<IterDomain*> forwarded_ids_;
183
184 // Need to track which id's have been forwarded. Later need to make sure leaf
185 // nodes to produce compliment axes are properly tracked. i.e.
186 // T[i0, b1, b2, i3]
187 // -> T[i0, b1o, b1i, b2o, b2i, i3]
188 // -> T[i0*b1i*b2o, b1o, b2i, i3]
189 // -> T[i0*b1i*b2o*i3, b1o, b2i]
190 // If we forwarded i0 -> i0*b1i*b2o*i3, we need to know that b1o and b2i
191 // are leaf nodes even though their split wasn't part of targets replay.
192
193 // Counter to make sure best effort replay leaf_ids can be grabbed
194 // deterministicly
195 size_t counter = 0;
196
197 // Determine if current replay will ignore swizzle ops.
198 // When not skipping swizzles, swizzle ops will have to be matched
199 // same way as split and merge to progress forward on the mapping.
200 //
201 // When skipping swizzles, mismatched swizzle ops will not stop matching
202 // further down the tensor domains but only the swizzle outputs will be on
203 // the target to replay map, since we only generate one-to-one maps in
204 // BestEffortReplay and the swizzle outputs is just picked as a convention
205 // for simpler and uniform mapping behavior. The swizzle op inputs will be
206 // added by the disjoint set passes when building the iterdomain graph.
207 //
208 // Example:
209 // Target:
210 // I0o, I0i = split I0
211 // Ix0o, Ix0i = swizzle I0o, I0i
212 // I02 = merge Ix0o, Ix0i
213 // Replay:
214 // I1o, I1i = split I1
215 // I12 = merge I1o, I1i
216 //
217 // BestEffortReplay **no** skip swizzle gives:
218 // {
219 // I0->I1,
220 // I0o->I1o,
221 // I0i->I1i,
222 // }
223 //
224 // BestEffortReplay skip swizzle gives:
225 // {
226 // I0->I1,
227 // Ix0o->I1o,
228 // Ix0i->I1i,
229 // I02->I12
230 // }
231 //
232 bool skip_swizzle_ = true;
233
234 bool inReplayForwardMap(IterDomain* id) const {
235 return replay_forward_id_map_.find(id) != replay_forward_id_map_.end();
236 }
237
238 bool inTargetForwardMap(IterDomain* id) const {
239 return target_forward_id_map_.find(id) != target_forward_id_map_.end();
240 }
241
242 IterDomain* getReplayForwardedId(IterDomain* id) const {
243 auto forwarded_id_it = replay_forward_id_map_.find(id);
244 if (forwarded_id_it == replay_forward_id_map_.end()) {
245 return id;
246 } else {
247 return getReplayForwardedId(forwarded_id_it->second);
248 }
249 }
250
251 IterDomain* getTargetForwardedId(IterDomain* id) const {
252 auto forwarded_id_it = target_forward_id_map_.find(id);
253 if (forwarded_id_it == target_forward_id_map_.end()) {
254 return id;
255 } else {
256 return getTargetForwardedId(forwarded_id_it->second);
257 }
258 }
259
260 //! Adds complimenting IDs of forwarded IDs to the leaf map
261 void addComplimentLeafIDs(
262 const std::unordered_map<IterDomain*, IterDomain*>& forwarding_map,
263 const std::unordered_map<IterDomain*, std::vector<IterDomain*>>&
264 compliment_map);
265
266 // Skip swizzle step to make sure both target and
267 // replay swizzles are skipped while the mapping
268 // makes progress. This makes sure that, for example
269 // different tensors can still be inlined despite
270 // different local swizzle patterns.
271 void skipSwizzles(
272 const std::unordered_map<IterDomain*, Expr*>& target_id2expr,
273 const std::unordered_map<IterDomain*, Expr*>& replay_id2expr);
274
275 public:
276 BestEffortReplay(
277 const std::vector<IterDomain*>& replay_domain,
278 const std::vector<IterDomain*>& target_domain,
279 std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
280 std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map = {},
281 std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map = {},
282 bool skip_swizzle = true);
283
284 // Return iter domain map from target_domain IDs to their "replayed"
285 // replay_domain IDs. If not in map, was not replayed.
286 const std::unordered_map<IterDomain*, IterDomain*>& getReplay() const {
287 return target2replay_id_map_;
288 }
289
290 // ids in replay that did not have matching transforms in target_domain
291 const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() {
292 return leaf_ids_;
293 }
294
295 // Returned ordered set of IDs in getUnorderedLeafIDs
296 std::vector<IterDomain*> getLeafIDs() {
297 std::set<std::pair<IterDomain*, size_t>, id_int_lt> ordered_set;
298 for (auto entry : leaf_ids_)
299 ordered_set.emplace(entry);
300
301 std::vector<IterDomain*> leaf_vec_;
302 leaf_vec_.resize(ordered_set.size());
303 std::transform(
304 ordered_set.begin(),
305 ordered_set.end(),
306 leaf_vec_.begin(),
307 [](std::pair<IterDomain*, size_t> entry) { return entry.first; });
308 return leaf_vec_;
309 }
310
311 DisjointSets<IterDomain*> getDisjointSets();
312
313 // Runs a best effort replay that ignores broadcast axes that appear in
314 // consumer that are not mapped to producer in root_map.
315 static BestEffortReplay replayCasP(
316 const TensorView* consumer,
317 const TensorView* producer,
318 int producer_compute_at_axis,
319 const RootDomainMap& root_map);
320
321 // Runs a best effort replay that ignores broadcast axes that appear in
322 // consumer that are not mapped to producer in root_map.
323 static BestEffortReplay replayPasC(
324 const TensorView* producer,
325 const TensorView* consumer,
326 int consumer_compute_at_axis,
327 const RootDomainMap& root_map);
328
329 // Find the first position i where td1[i] is not the same as td2[i]. "Same"
330 // means the DAG and input IDs to generate td1[i] and td2[i] are the same.
331 // td1 and td2 are assumed to have some matching iter domains, as this is a
332 // strict same-ness check.
333 static int findFirstMismatchedID(
334 const TensorDomain* td1,
335 const TensorDomain* td2);
336};
337
338} // namespace cuda
339} // namespace fuser
340} // namespace jit
341} // namespace torch
342