#include <transform_rfactor.h>

#include <arith.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_builder.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// This class replays the root domains of the producer of an rfactor domain.
// Axes must be replayed to mark the rfactor iter domains as reductions in the
// producer, while converting the other reductions in the producer into iter
// domains. Those (previously reduction) iter domains are then converted to
// reduction domains in the consumer. This breaks the reduction up into two
// stages, while ensuring the correct values are reduced across those stages.
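//
// As an illustrative sketch (the tensor names here are hypothetical, not taken
// from this file): rfactoring axis R2 of
//   tv1[I0, R1, R2] = sum(tv0)
// produces a producer that reduces only the rfactored axis,
//   tv2[I0, I1, R2(rf)] = sum(tv0)   // stage 1: R1 becomes an iter domain
// and a consumer that reduces the remaining reduction axis,
//   tv1[I0, R1] = sum(tv2)           // stage 2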
//
// The rfactor domain of the producer must match the consumer's root domain to
// maintain producer-consumer mappings. The iter domains in the original domain
// that the rfactor axes depend on are marked as "static_rfactor_ids". These
// static IDs cannot be changed in the producer, as doing so would invalidate
// the rfactor and it would no longer match the consumer.
//
// To find the rfactor domain in the producer, which will be used as the root
// domain in the consumer, we start at the roots of the producer and replay
// forward any root iter domain that is marked as a "static_rfactor_id". While
// doing this we maintain the ordering of the iter domains. For example:
//
//       I1
//      /  \
//    I2    \
//   /  \    I3
//  I5   I4  /
//        \ /
//         I6
//
// If rfactor_axes = {I6}, then the "static_rfactor_id" IDs will be {I6, I4,
// I3, I2, I1}. Then, as we perform the replay, the rfactor domain is updated
// as:
//   [I1] -> [I2, I3] -> [I5, I4, I3] -> [I5, I6]
//
// ReplayTransformations typically updates the leaf ids, but we'll simply use
// the mapping from the original tensor domain, so we won't bother updating
// them in this replay.
class ReplayRFactor : public ReplayTransformations {
 private:
  // Perform the update of the rfactor domain by replacing "replace0" with
  // "with0" (and "with1" if it is not nullptr), and also removing "replace1"
  // if it is not nullptr.
  void updateRFactorDomain(
      IterDomain* replace0,
      IterDomain* replace1,
      IterDomain* with0,
      IterDomain* with1) {
    TORCH_INTERNAL_ASSERT(
        with0 != nullptr,
        "The first provided IterDomain should be a real pointer,",
        " the second iter domain provided can be a nullptr.");
    auto pos =
        std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0);
    TORCH_INTERNAL_ASSERT(
        pos != rfactor_domain_.end(),
        "Could not find iter domain: ",
        replace0->toString(),
        " in the rfactor domain to replace.");
    rfactor_domain_.insert(pos, with0);
    if (with1 != nullptr) {
      pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0);
      rfactor_domain_.insert(pos, with1);
    }
    pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0);
    rfactor_domain_.erase(pos);
    if (replace1 != nullptr) {
      pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace1);
      TORCH_INTERNAL_ASSERT(
          pos != rfactor_domain_.end(),
          "Wanted to replace ",
          replace1->toString(),
          " but it's not in the rfactor domain.");
      rfactor_domain_.erase(pos);
    }
  }

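  // For illustration: with rfactor_domain_ = [I1, J], a split replay calls
  // updateRFactorDomain(I1, nullptr, I2, I3) and leaves [I2, I3, J]; a merge
  // replay on that result calls updateRFactorDomain(I2, I3, I6, nullptr) and
  // leaves [I6, J].
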
  // Took a good bit of this from ReplayTransformations::handle(Split...)
  void handle(Split* s) override {
    // Grab input to the split operation
    auto id_in = s->in();
    // Grab our mapping of that ID to the one we're replaying
    auto it = id_map_.find(id_in);
    // Make sure it exists in the map
    TORCH_INTERNAL_ASSERT(
        it != id_map_.end(),
        "Transform traversal failed, dependencies not met.");
    // Grab the ID we're going to replay on
    auto mapped = (*it).second;
    // This ID should be a leaf ID (meaning it has no uses we generated)
    TORCH_INTERNAL_ASSERT(
        leaf_ids_.find(mapped) != leaf_ids_.end(),
        "Transform traversal failed, modified a node but it was not a leaf node.");

    // outer loop size
    Val* remainder = ceilDiv(mapped->extent(), s->factor());

    // Check if we need to mark the outputs as an rfactor domain, meaning this
    // transformation must be present in replays, otherwise it breaks the
    // compute definition of the fusion. Iter domains are actually not static;
    // it's the transformation that's static or not, so if one output is marked
    // as a static id, then both must be.
    bool static_rfactor_outputs = static_rfactor_ids_.count(s->outer()) ||
        static_rfactor_ids_.count(s->inner());

    // Manually replay the split; an output is a reduction only if it is an
    // rfactor axis, and both outputs are marked as rfactor domains if the
    // transformation is static.
    // outer IterDomain
    IterDomain* ido =
        IterDomainBuilder(
            s->container()->zeroVal(),
            s->innerSplit() ? remainder->as<Int>() : s->factor())
            .iter_type(
                rfactor_axes_.count(s->outer()) ? IterType::Reduction
                                                : IterType::Iteration)
            .is_rfactor_domain(static_rfactor_outputs)
            .build();

    // inner IterDomain
    IterDomain* idi =
        IterDomainBuilder(
            s->container()->zeroVal(),
            s->innerSplit() ? s->factor() : remainder->as<Int>())
            .iter_type(
                rfactor_axes_.count(s->inner()) ? IterType::Reduction
                                                : IterType::Iteration)
            .is_rfactor_domain(static_rfactor_outputs)
            .build();

    // Generate the split node
    IrBuilder::create<Split>(
        s->container(), ido, idi, mapped, s->factor(), s->innerSplit());

    // Remove mapped id from leaf IDs
    leaf_ids_.erase(mapped);
    // Add outputs to leaf IDs
    leaf_ids_[ido] = counter++;
    leaf_ids_[idi] = counter++;

    // Update our ID map to include these outputs
    id_map_[s->outer()] = ido;
    id_map_[s->inner()] = idi;

    if (static_rfactor_ids_.count(s->in())) {
      updateRFactorDomain(s->in(), nullptr, s->outer(), s->inner());
    }
  }
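
  // Note on extents (an illustrative sketch, not taken from this file):
  // replaying an inner split by factor 4 of a mapped id with extent 12
  // produces ido with extent ceilDiv(12, 4) = 3 and idi with extent 4; for an
  // outer split the factor instead goes to ido and the remainder to idi.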

  void handle(Merge* m) override {
    auto id_outer = m->outer();
    auto id_inner = m->inner();
    auto it_outer = id_map_.find(id_outer);
    auto it_inner = id_map_.find(id_inner);
    TORCH_INTERNAL_ASSERT(
        it_outer != id_map_.end() && it_inner != id_map_.end(),
        "Transform traversal failed, dependencies not met.");

    auto id_outer_mapped = (*it_outer).second;
    auto id_inner_mapped = (*it_inner).second;

    TORCH_INTERNAL_ASSERT(
        leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
            leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
        "Transform traversal failed, modified ",
        id_outer_mapped,
        " and ",
        id_inner_mapped,
        " however one or both are not leaf nodes.");

    Val* merged_id_size =
        mul(id_outer_mapped->extent(), id_inner_mapped->extent());

    IterDomain* merged_id =
        IterDomainBuilder(m->container()->zeroVal(), merged_id_size->as<Int>())
            .iter_type(
                rfactor_axes_.count(m->out()) ? IterType::Reduction
                                              : IterType::Iteration)
            .is_rfactor_domain(static_rfactor_ids_.count(m->out()))
            .build();

    IrBuilder::create<Merge>(
        m->container(), merged_id, id_outer_mapped, id_inner_mapped);

    // Remove inputs from the leaf IDs
    leaf_ids_.erase(id_outer_mapped);
    leaf_ids_.erase(id_inner_mapped);

    // Add the output to the leaf IDs
    leaf_ids_[merged_id] = counter++;

    id_map_[m->out()] = merged_id;

    // Similar to the split replay above, check if the output needs to be
    // marked as rfactor, indicating this transformation is static.
    if (static_rfactor_ids_.count(m->inner()) ||
        static_rfactor_ids_.count(m->outer())) {
      TORCH_INTERNAL_ASSERT(
          static_rfactor_ids_.count(m->inner()) ==
              static_rfactor_ids_.count(m->outer()),
          "If one input to a merge is a static rfactor id, the other must be as well.");
      updateRFactorDomain(m->outer(), m->inner(), m->out(), nullptr);
    }
  }
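
  // For illustration: merging a mapped outer id of extent 3 with a mapped
  // inner id of extent 4 yields merged_id with extent 3 * 4 = 12; it becomes
  // a reduction only if m->out() is one of the rfactor axes.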
215
216 // The IterDomains in the original_domain that are being factored into the
217 // first stage of the two stage reduction (the producer).
218 std::unordered_set<IterDomain*> rfactor_axes_;
219 // Iter domains whose history cannot be changed as it would break rfactor
220 // dependencies.
221 std::unordered_set<IterDomain*> static_rfactor_ids_;
222
223 public:
224 // The updated domain matching the producer's rfactor domain. This rfactor
225 // domain is relative to the iter domains in the origianl_domain and must be
226 // updated to grab the mapped id's later.
227 std::vector<IterDomain*> rfactor_domain_;
228
229 ReplayRFactor(
230 // Original domain the rfactor is in reference to.
231 TensorDomain* original_domain,
232 // The root mapping from the original root domain, to the roots of the
233 // domain to be replayed.
234 std::unordered_map<IterDomain*, IterDomain*> id_map,
235 // The rfactor axes in original_domain->domain() to be factored into the
236 // two stage reduction.
237 std::unordered_set<IterDomain*> rfactor_axes,
238 // All the iter domains in original_domain that the rfactor axes are
239 // dependant on.
240 std::unordered_set<IterDomain*> static_rfactor_ids)
241 : ReplayTransformations(
242 original_domain->domain(),
243 std::move(id_map),
244 false),
245 rfactor_axes_(std::move(rfactor_axes)),
246 static_rfactor_ids_(static_rfactor_ids),
247 rfactor_domain_(original_domain->getMaybeRFactorDomain()) {}
248};
249
250} // namespace
251
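// Usage sketch (illustrative; the tensor shown is hypothetical): for an
// original domain of the form [I0, R1, R2], runReplay(td, {2}) returns a
// {producer, consumer} pair where the producer reduces the rfactored axis
// (R2) and keeps R1 as an iter domain, while the consumer's root domain is
// built from the producer's rfactor domain and reduces R1.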
std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
    TensorDomain* original_td,
    std::vector<int> axes) {
  FUSER_PERF_SCOPE("TransformRFactor::runReplay");

  TORCH_CHECK(!axes.empty(), "No axes provided to rfactor replay.");

  int ndims = (int)original_td->nDims();

  // Adjust and check provided axes
  std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) {
    TORCH_CHECK(
        i >= -ndims && i < ndims,
        "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is ",
        -ndims,
        " to ",
        ndims - 1);
    return i < 0 ? i + ndims : i;
  });
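
  // For example (illustrative): with ndims == 3, an axis of -1 normalizes to
  // 2, while an axis of 3 or -4 fails the range check above.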

  // remove duplicates, and put into a set for searching
  std::unordered_set<int> axes_set(axes.begin(), axes.end());

  TORCH_INTERNAL_ASSERT(
      std::all_of(
          axes_set.begin(),
          axes_set.end(),
          [original_td](int i) { return original_td->axis(i)->isReduction(); }),
      "Cannot rfactor axes that are not reduction axes.");

  // RFactor requires at least one reduction axis to be marked as factored out,
  // and at least one reduction axis that won't be. Otherwise it's just a
  // pointwise caching operation.
  bool found_non_rfactor_reduction = false;

  // Make a set of final axes that are marked to be rfactored
  std::unordered_set<IterDomain*> rfactor_axes(axes_set.size());
  {
    size_t i = 0;
    for (auto id : original_td->domain()) {
      if (axes_set.find(i++) != axes_set.end()) {
        rfactor_axes.emplace(id);
      } else if (id->isReduction()) {
        found_non_rfactor_reduction = true;
      }
    }
  }

  TORCH_CHECK(
      found_non_rfactor_reduction,
      "Must have at least one reduction axis not marked as rfactor.");

  // Get the root IterDomains of the rfactor domains. These are the roots we
  // will replay marked as rfactor axes (reductions in the producer); reduction
  // roots not in this set will be replayed with reduction=false.
  auto rfactor_root_vals = IterVisitor::getInputsTo(
      std::vector<Val*>(rfactor_axes.begin(), rfactor_axes.end()));
  auto rfactor_root_ids = ir_utils::filterByType<IterDomain>(rfactor_root_vals);

  // Put in a set to make searching easy
  std::unordered_set<IterDomain*> rfactor_root_axes(
      rfactor_root_ids.begin(), rfactor_root_ids.end());

  TORCH_INTERNAL_ASSERT(
      std::none_of(
          rfactor_root_ids.begin(),
          rfactor_root_ids.end(),
          [](IterDomain* id) { return id->maybePartial(); }),
      "rFactor of partial domains not allowed, but at least one found.");

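  // In the ReplayRFactor example diagram above, with rfactor_axes = {I6},
  // getInputsTo({I6}) would return just the root, {I1}.
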
  auto original_td_root = original_td->getMaybeRFactorDomain();

  // Generate a new TensorDomain and set up map from one root to this one.
  std::vector<IterDomain*> new_producer_root(original_td_root.size(), nullptr);
  std::unordered_map<IterDomain*, IterDomain*> original_to_producer_root_map;

  {
    for (auto i : c10::irange(original_td_root.size())) {
      auto id = original_td_root[i];
      // If this is an rfactor root, it will be a reduction in this stage
      if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) {
        new_producer_root[i] = IterDomainBuilder(id->start(), id->extent())
                                   .stop_offset(id->stopOffset())
                                   .iter_type(IterType::Reduction)
                                   .is_rfactor_domain(true)
                                   .build();
        // If this is not an rfactor root, but a reduction root, it should be
        // turned into an iteration domain
      } else if (id->isReduction()) {
        new_producer_root[i] = IterDomainBuilder(id->start(), id->extent())
                                   .stop_offset(id->stopOffset())
                                   .build();
      } else {
        new_producer_root[i] = id->cloneWithoutRFactor();
      }
      original_to_producer_root_map[id] = new_producer_root[i];
    }
  }
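
  // Continuing the sketch: for an original root of [I0, R1, R2] with rfactor
  // root axes {R2}, the producer root built above is [I0, I1, R2], where R2
  // is a reduction marked as an rfactor domain and R1 has become the
  // iteration domain I1.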

  // Axes in the original_td that are in the history of the rfactored domains.
  // These will mark which iter domains must be preserved as static
  // transformations to preserve compute semantics.
  auto all_deps_of_rfactor = DependencyCheck::getAllValsBetween(
      {original_td->getMaybeRFactorDomain().begin(),
       original_td->getMaybeRFactorDomain().end()},
      {rfactor_axes.begin(), rfactor_axes.end()});

  auto all_id_deps_of_rfactor =
      ir_utils::filterByType<IterDomain>(all_deps_of_rfactor);

  std::unordered_set<IterDomain*> static_rfactor_ids(
      {all_id_deps_of_rfactor.begin(), all_id_deps_of_rfactor.end()});

  // Replay producer dimensions.
  ReplayRFactor replay_rfactor(
      original_td,
      original_to_producer_root_map,
      rfactor_axes,
      static_rfactor_ids);

  std::unordered_map<IterDomain*, IterDomain*> original_to_producer_id_map =
      replay_rfactor.getReplay();

  std::vector<IterDomain*> new_producer_domain(original_td->nDims(), nullptr);
  {
    for (auto i : c10::irange(original_td->nDims())) {
      auto orig_id = original_td->axis(i);
      auto replayed_id_it = original_to_producer_id_map.find(orig_id);
      TORCH_INTERNAL_ASSERT(
          replayed_id_it != original_to_producer_id_map.end(),
          "Error during rfactor replay, missing an axis.");
      auto replayed_id = replayed_id_it->second;
      replayed_id->parallelize(orig_id->getParallelType());
      if (orig_id->hasPaddingToMultipleOfWarp()) {
        replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding());
      }
      new_producer_domain[i] = replayed_id;
    }
  }

  // Specify the rfactor domain of the producer which will match the consumer
  // root domain.
  std::vector<IterDomain*> new_producer_rfactor_domain;
  new_producer_rfactor_domain.reserve(replay_rfactor.rfactor_domain_.size());
  std::transform(
      replay_rfactor.rfactor_domain_.begin(),
      replay_rfactor.rfactor_domain_.end(),
      std::back_inserter(new_producer_rfactor_domain),
      [&](IterDomain* id) {
        auto replayed_id_it = original_to_producer_id_map.find(id);
        TORCH_INTERNAL_ASSERT(
            replayed_id_it != original_to_producer_id_map.end(),
            "Error during rfactor replay, missing an axis.");
        return replayed_id_it->second;
      });

  TensorDomain* producer_domain = IrBuilder::create<TensorDomain>(
      original_td->container(),
      new_producer_root,
      new_producer_rfactor_domain,
      new_producer_domain,
      std::vector<bool>(new_producer_rfactor_domain.size(), true));

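  // In the simple sketch above (no extra transformations), the producer's
  // rfactor domain is just [I0, I1, R2]; its non-reduction ids are used to
  // build the consumer's root domain below.
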
  // Producer has been finished, now work on consumer.

  // For convenience flip the original to producer map
  std::unordered_map<IterDomain*, IterDomain*> producer_to_original_map;
  for (auto entry : original_to_producer_id_map) {
    producer_to_original_map[entry.second] = entry.first;
  }

  std::vector<IterDomain*> new_consumer_root_domain;
  new_consumer_root_domain.reserve(new_producer_rfactor_domain.size());
  std::unordered_map<IterDomain*, IterDomain*> original_to_consumer_root_map;
  for (auto p_root_id : new_producer_rfactor_domain) {
    if (p_root_id->isReduction()) {
      continue;
    }
    auto p2o_it = producer_to_original_map.find(p_root_id);
    TORCH_INTERNAL_ASSERT(
        p2o_it != producer_to_original_map.end(),
        "Missing mapping from original tensor domain to producer tensor domain.");
    auto original_id = p2o_it->second;
    auto new_consumer_root =
        IterDomainBuilder(original_id->start(), original_id->extent())
            .stop_offset(original_id->stopOffset())
            .iter_type(original_id->getIterType())
            .build();
    new_consumer_root_domain.push_back(new_consumer_root);
    original_to_consumer_root_map[original_id] = new_consumer_root;
  }

  ReplayTransformations consumer_replay(
      original_td->domain(), original_to_consumer_root_map, false);
  auto original_to_consumer_map = consumer_replay.getReplay();

  std::vector<IterDomain*> new_consumer_domain;

  {
    // Construct the new consumer domain
    for (auto i : c10::irange(original_td->nDims())) {
      auto orig_id = original_td->axis(i);
      auto replayed_id_it = original_to_consumer_map.find(orig_id);
      if (replayed_id_it != original_to_consumer_map.end()) {
        auto replayed_id = replayed_id_it->second;
        new_consumer_domain.push_back(replayed_id);
        replayed_id->parallelize(orig_id->getParallelType());
        if (orig_id->hasPaddingToMultipleOfWarp()) {
          replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding());
        }
      }
    }
  }

  auto consumer_domain = IrBuilder::create<TensorDomain>(
      original_td->container(),
      new_consumer_root_domain,
      new_consumer_domain,
      std::vector<bool>(new_consumer_root_domain.size(), true));

  return std::make_pair(producer_domain, consumer_domain);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch