1 | #include <transform_rfactor.h> |
2 | |
3 | #include <arith.h> |
4 | #include <fusion.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_builder.h> |
7 | #include <ir_iostream.h> |
8 | #include <ir_utils.h> |
9 | #include <iter_visitor.h> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | namespace { |
17 | |
18 | // This class replays the root domains of the producer of an rfactor domain. |
19 | // Axes must be replayed to mark rfactor iter domains as being reductions in the |
20 | // producer, but converting the other reductions in the producer as iter |
21 | // domains. Those (previously reductions in the producer) iter domains are then |
// converted to reduction domains in the consumer. This breaks up the reduction
// into two stages, but maintains that the correct values are reduced across
// those stages.
25 | // |
26 | // The rfactor domain of the producer must match the consumers root domain to |
27 | // maintain producer-consumer mappings. The following uses the original domain |
28 | // being rfactored and marked iter domains as "static_rfactor_ids". These static |
29 | // IDs cannot be changed in the producer as it would invalidate the rfactor, no |
30 | // longer matching the consumer. |
31 | // |
32 | // To find the rfactor domain in the producer which will be used as the root |
33 | // domain in the consumer, we start at the roots of producer, and replay forward |
34 | // the root iter domains if that iter domain is marked as a "static_rfactor_id". |
35 | // To do this we maintain the ordering of the iter domains. For example: |
36 | // |
37 | // I1 |
38 | // /\ // |
39 | // I2 \ // |
40 | // /\ I3 |
41 | // / I4 / |
42 | // / \/ |
43 | // I5 I6 |
44 | // |
45 | // If rfactor_axes = {I6}, then "static_rfactor_id" IDs will be {I6, I4, I3, I2, |
46 | // I1}. Then, as we perform the replay the rfactor domain will be updated as: |
47 | // [I1] -> [I2, I3] -> [I5, I4, I3] -> [I5, I6] |
48 | // |
49 | // ReplayTransformations typically updates the leaf ids, but we'll simply use |
50 | // the mapping from the original tensor domain so we won't bother updating them |
51 | // in this replay. |
52 | class ReplayRFactor : public ReplayTransformations { |
53 | private: |
54 | // Perform the update of the rfactor domain by replacing "replace0" with |
55 | // "with0" and if not nullptr "with1", also removes "replace1" if not nullptr. |
56 | void updateRFactorDomain( |
57 | IterDomain* replace0, |
58 | IterDomain* replace1, |
59 | IterDomain* with0, |
60 | IterDomain* with1) { |
61 | TORCH_INTERNAL_ASSERT( |
62 | with0 != nullptr, |
63 | "The first provided IterDomain should be a real pointer," , |
64 | " the second iter domain provided can be a nullptr." ); |
65 | auto pos = |
66 | std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0); |
67 | TORCH_INTERNAL_ASSERT( |
68 | pos != rfactor_domain_.end(), |
69 | "Could not find iter domain: " , |
70 | replace0->toString(), |
71 | " in the rfactor domain to replace." ); |
72 | rfactor_domain_.insert(pos, with0); |
73 | if (with1 != nullptr) { |
74 | pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0); |
75 | rfactor_domain_.insert(pos, with1); |
76 | } |
77 | pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace0); |
78 | rfactor_domain_.erase(pos); |
79 | if (replace1 != nullptr) { |
80 | pos = std::find(rfactor_domain_.begin(), rfactor_domain_.end(), replace1); |
81 | TORCH_INTERNAL_ASSERT( |
82 | pos != rfactor_domain_.end(), |
83 | "Wanted to replace " , |
84 | replace1->toString(), |
85 | " but it's not in the rfactor domain." ); |
86 | rfactor_domain_.erase(pos); |
87 | } |
88 | } |
89 | |
90 | // Took a good bit of this from ReplayTransformations::handle(Split...) |
91 | void handle(Split* s) override { |
92 | // Grab input to the split operation |
93 | auto id_in = s->in(); |
94 | // Grab our mapping of that ID to the one we're replaying |
95 | auto it = id_map_.find(id_in); |
96 | // Make sure it exists in the map |
97 | TORCH_INTERNAL_ASSERT( |
98 | it != id_map_.end(), |
99 | "Transform traversal failed, dependencies not met." ); |
100 | // Grab the ID we're going to replay on |
101 | auto mapped = (*it).second; |
102 | // This ID should be a leaf ID (meaning it has no uses we generated) |
103 | TORCH_INTERNAL_ASSERT( |
104 | leaf_ids_.find(mapped) != leaf_ids_.end(), |
105 | "Transform traversal failed, modified a node but it was not a leaf node." ); |
106 | |
107 | // outer loop size |
108 | Val* remainder = ceilDiv(mapped->extent(), s->factor()); |
109 | |
110 | // Check if we need to mark the outputs as an rfactor domain meaning this |
111 | // transformation must be present in replays otherwise it breaks the compute |
112 | // definition of the fusion. Iter domains are actually not static, its the |
113 | // transformation that's static or not, so if one output is marked as a |
114 | // static id, then both must be. |
115 | bool static_rfactor_outputs = static_rfactor_ids_.count(s->outer()) || |
116 | static_rfactor_ids_.count(s->inner()); |
117 | |
118 | // Manually replay the split, making reduction = false and rfactor = true |
119 | // outer IterDomain |
120 | IterDomain* ido = |
121 | IterDomainBuilder( |
122 | s->container()->zeroVal(), |
123 | s->innerSplit() ? remainder->as<Int>() : s->factor()) |
124 | .iter_type( |
125 | rfactor_axes_.count(s->outer()) ? IterType::Reduction |
126 | : IterType::Iteration) |
127 | .is_rfactor_domain(static_rfactor_outputs) |
128 | .build(); |
129 | |
130 | // inner IterDomain |
131 | IterDomain* idi = |
132 | IterDomainBuilder( |
133 | s->container()->zeroVal(), |
134 | s->innerSplit() ? s->factor() : remainder->as<Int>()) |
135 | .iter_type( |
136 | rfactor_axes_.count(s->inner()) ? IterType::Reduction |
137 | : IterType::Iteration) |
138 | .is_rfactor_domain(static_rfactor_outputs) |
139 | .build(); |
140 | |
141 | // Generate the split node |
142 | IrBuilder::create<Split>( |
143 | s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); |
144 | |
145 | // Remove mapped id from leaf IDs |
146 | leaf_ids_.erase(mapped); |
147 | // Add outputs to leaf IDs |
148 | leaf_ids_[ido] = counter++; |
149 | leaf_ids_[idi] = counter++; |
150 | |
151 | // Update our ID map to include these outputs |
152 | id_map_[s->outer()] = ido; |
153 | id_map_[s->inner()] = idi; |
154 | |
155 | if (static_rfactor_ids_.count(s->in())) { |
156 | updateRFactorDomain(s->in(), nullptr, s->outer(), s->inner()); |
157 | } |
158 | } |
159 | |
160 | void handle(Merge* m) override { |
161 | auto id_outer = m->outer(); |
162 | auto id_inner = m->inner(); |
163 | auto it_outer = id_map_.find(id_outer); |
164 | auto it_inner = id_map_.find(id_inner); |
165 | TORCH_INTERNAL_ASSERT( |
166 | it_outer != id_map_.end() && it_inner != id_map_.end(), |
167 | "Transform traversal failed, dependencies not met." ); |
168 | |
169 | auto id_outer_mapped = (*it_outer).second; |
170 | auto id_inner_mapped = (*it_inner).second; |
171 | |
172 | TORCH_INTERNAL_ASSERT( |
173 | leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() && |
174 | leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(), |
175 | "Transform traversal failed, modified " , |
176 | id_outer_mapped, |
177 | " and " , |
178 | id_inner_mapped, |
179 | " however one or both are not leaf nodes." ); |
180 | |
181 | Val* merged_id_size = |
182 | mul(id_outer_mapped->extent(), id_inner_mapped->extent()); |
183 | |
184 | IterDomain* merged_id = |
185 | IterDomainBuilder(m->container()->zeroVal(), merged_id_size->as<Int>()) |
186 | .iter_type( |
187 | rfactor_axes_.count(m->out()) ? IterType::Reduction |
188 | : IterType::Iteration) |
189 | .is_rfactor_domain(static_rfactor_ids_.count(m->out())) |
190 | .build(); |
191 | |
192 | IrBuilder::create<Merge>( |
193 | m->container(), merged_id, id_outer_mapped, id_inner_mapped); |
194 | |
195 | // Remove inputs from the leaf IDs |
196 | leaf_ids_.erase(id_outer_mapped); |
197 | leaf_ids_.erase(id_inner_mapped); |
198 | |
199 | // Add the output to the leaf IDs |
200 | leaf_ids_[merged_id] = counter++; |
201 | |
202 | id_map_[m->out()] = merged_id; |
203 | |
204 | // Similar to split replay above, check if output needs to be marked as |
205 | // rfactor indicating this transofrmation is static. |
206 | if (static_rfactor_ids_.count(m->inner()) || |
207 | static_rfactor_ids_.count(m->outer())) { |
208 | TORCH_INTERNAL_ASSERT( |
209 | static_rfactor_ids_.count(m->inner()) == |
210 | static_rfactor_ids_.count(m->outer()), |
211 | "If one input to a merge is a static rfactor id, the other must be as well." ); |
212 | updateRFactorDomain(m->outer(), m->inner(), m->out(), nullptr); |
213 | } |
214 | } |
215 | |
216 | // The IterDomains in the original_domain that are being factored into the |
217 | // first stage of the two stage reduction (the producer). |
218 | std::unordered_set<IterDomain*> rfactor_axes_; |
219 | // Iter domains whose history cannot be changed as it would break rfactor |
220 | // dependencies. |
221 | std::unordered_set<IterDomain*> static_rfactor_ids_; |
222 | |
223 | public: |
224 | // The updated domain matching the producer's rfactor domain. This rfactor |
225 | // domain is relative to the iter domains in the origianl_domain and must be |
226 | // updated to grab the mapped id's later. |
227 | std::vector<IterDomain*> rfactor_domain_; |
228 | |
229 | ReplayRFactor( |
230 | // Original domain the rfactor is in reference to. |
231 | TensorDomain* original_domain, |
232 | // The root mapping from the original root domain, to the roots of the |
233 | // domain to be replayed. |
234 | std::unordered_map<IterDomain*, IterDomain*> id_map, |
235 | // The rfactor axes in original_domain->domain() to be factored into the |
236 | // two stage reduction. |
237 | std::unordered_set<IterDomain*> rfactor_axes, |
238 | // All the iter domains in original_domain that the rfactor axes are |
239 | // dependant on. |
240 | std::unordered_set<IterDomain*> static_rfactor_ids) |
241 | : ReplayTransformations( |
242 | original_domain->domain(), |
243 | std::move(id_map), |
244 | false), |
245 | rfactor_axes_(std::move(rfactor_axes)), |
246 | static_rfactor_ids_(static_rfactor_ids), |
247 | rfactor_domain_(original_domain->getMaybeRFactorDomain()) {} |
248 | }; |
249 | |
250 | } // namespace |
251 | |
252 | std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay( |
253 | TensorDomain* original_td, |
254 | std::vector<int> axes) { |
255 | FUSER_PERF_SCOPE("TransformRFactor::runReplay" ); |
256 | |
257 | TORCH_CHECK(!axes.empty(), "No axes provided to rfactor replay." ); |
258 | |
259 | int ndims = (int)original_td->nDims(); |
260 | |
261 | // Adjust and check provided axes |
262 | std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { |
263 | TORCH_CHECK( |
264 | i >= -ndims && i < ndims, |
265 | "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is " , |
266 | -ndims, |
267 | " to " , |
268 | ndims - 1); |
269 | return i < 0 ? i + ndims : i; |
270 | }); |
271 | |
272 | // remove duplicates, and put into a set for searching |
273 | std::unordered_set<int> axes_set(axes.begin(), axes.end()); |
274 | |
275 | TORCH_INTERNAL_ASSERT( |
276 | std::all_of( |
277 | axes_set.begin(), |
278 | axes_set.end(), |
279 | [original_td](int i) { return original_td->axis(i)->isReduction(); }), |
280 | "Cannot rfactor axes that are not reduction axes." ); |
281 | |
282 | // RFactor requires at least one reduction axis to be marked as factored out, |
283 | // and at least one reduction axis that won't. Otherwise it's just a pointwise |
284 | // cacheing operation. |
285 | bool found_non_rfactor_reduction = false; |
286 | |
287 | // Make a set of final axes that are marked to be rfactored |
288 | std::unordered_set<IterDomain*> rfactor_axes(axes_set.size()); |
289 | { |
290 | size_t i = 0; |
291 | for (auto id : original_td->domain()) { |
292 | if (axes_set.find(i++) != axes_set.end()) { |
293 | rfactor_axes.emplace(id); |
294 | } else if (id->isReduction()) { |
295 | found_non_rfactor_reduction = true; |
296 | } |
297 | } |
298 | } |
299 | |
300 | TORCH_CHECK( |
301 | found_non_rfactor_reduction, |
302 | "Must have at least one reduction axis not marked as rfactor." ); |
303 | |
304 | // Get root IterDomains of the rfactor domains, these will be the ones we will |
305 | // replay marked as rfactor axes, those marked in the axes set will be |
306 | // reduction=false |
307 | auto rfactor_root_vals = IterVisitor::getInputsTo( |
308 | std::vector<Val*>(rfactor_axes.begin(), rfactor_axes.end())); |
309 | auto rfactor_root_ids = ir_utils::filterByType<IterDomain>(rfactor_root_vals); |
310 | |
311 | // Put in a set to make searching easy |
312 | std::unordered_set<IterDomain*> rfactor_root_axes( |
313 | rfactor_root_ids.begin(), rfactor_root_ids.end()); |
314 | |
315 | TORCH_INTERNAL_ASSERT( |
316 | std::none_of( |
317 | rfactor_root_ids.begin(), |
318 | rfactor_root_ids.end(), |
319 | [](IterDomain* id) { return id->maybePartial(); }), |
320 | "rFactor of partial domains not allowed, but at least one found." ); |
321 | |
322 | auto original_td_root = original_td->getMaybeRFactorDomain(); |
323 | |
324 | // Generate a new TensorDomain and set up map from one root to this one. |
325 | std::vector<IterDomain*> new_producer_root(original_td_root.size(), nullptr); |
326 | std::unordered_map<IterDomain*, IterDomain*> original_to_producer_root_map; |
327 | |
328 | { |
329 | for (auto i : c10::irange(original_td_root.size())) { |
330 | auto id = original_td_root[i]; |
331 | // If this is an rfactor root, it will be a reduction in this stage |
332 | if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) { |
333 | new_producer_root[i] = IterDomainBuilder(id->start(), id->extent()) |
334 | .stop_offset(id->stopOffset()) |
335 | .iter_type(IterType::Reduction) |
336 | .is_rfactor_domain(true) |
337 | .build(); |
338 | // If this is not an rfactor root, but a reduction root, it should be |
339 | // turned into an iteration domain |
340 | } else if (id->isReduction()) { |
341 | new_producer_root[i] = IterDomainBuilder(id->start(), id->extent()) |
342 | .stop_offset(id->stopOffset()) |
343 | .build(); |
344 | } else { |
345 | new_producer_root[i] = id->cloneWithoutRFactor(); |
346 | } |
347 | original_to_producer_root_map[id] = new_producer_root[i++]; |
348 | } |
349 | } |
350 | |
351 | // Axes in the original_td that are in the history of the rfactored domains. |
352 | // These will mark which iter domains must be preserved as static |
353 | // transformations to preserve compute semantics. |
354 | auto all_deps_of_rfactor = DependencyCheck::getAllValsBetween( |
355 | {original_td->getMaybeRFactorDomain().begin(), |
356 | original_td->getMaybeRFactorDomain().end()}, |
357 | {rfactor_axes.begin(), rfactor_axes.end()}); |
358 | |
359 | auto all_id_deps_of_rfactor = |
360 | ir_utils::filterByType<IterDomain>(all_deps_of_rfactor); |
361 | |
362 | std::unordered_set<IterDomain*> static_rfactor_ids( |
363 | {all_id_deps_of_rfactor.begin(), all_id_deps_of_rfactor.end()}); |
364 | |
365 | // Replay producer dimensions. |
366 | ReplayRFactor replay_rfactor( |
367 | original_td, |
368 | original_to_producer_root_map, |
369 | rfactor_axes, |
370 | static_rfactor_ids); |
371 | |
372 | std::unordered_map<IterDomain*, IterDomain*> original_to_producer_id_map = |
373 | replay_rfactor.getReplay(); |
374 | |
375 | std::vector<IterDomain*> new_producer_domain(original_td->nDims(), nullptr); |
376 | { |
377 | for (auto i : c10::irange(original_td->nDims())) { |
378 | auto orig_id = original_td->axis(i); |
379 | auto replayed_id_it = original_to_producer_id_map.find(orig_id); |
380 | TORCH_INTERNAL_ASSERT( |
381 | replayed_id_it != original_to_producer_id_map.end(), |
382 | "Error during rfactor replay, missing an axis." ); |
383 | auto replayed_id = replayed_id_it->second; |
384 | replayed_id->parallelize(orig_id->getParallelType()); |
385 | if (orig_id->hasPaddingToMultipleOfWarp()) { |
386 | replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding()); |
387 | } |
388 | new_producer_domain[i++] = replayed_id; |
389 | } |
390 | } |
391 | |
392 | // Specify the rfactor domain of the producer which will match the consumer |
393 | // root domain. |
394 | std::vector<IterDomain*> new_producer_rfactor_domain; |
395 | new_producer_rfactor_domain.reserve(replay_rfactor.rfactor_domain_.size()); |
396 | std::transform( |
397 | replay_rfactor.rfactor_domain_.begin(), |
398 | replay_rfactor.rfactor_domain_.end(), |
399 | std::back_inserter(new_producer_rfactor_domain), |
400 | [&](IterDomain* id) { |
401 | auto replayed_id_it = original_to_producer_id_map.find(id); |
402 | TORCH_INTERNAL_ASSERT( |
403 | replayed_id_it != original_to_producer_id_map.end(), |
404 | "Error during rfactor replay, missing an axis." ); |
405 | return replayed_id_it->second; |
406 | }); |
407 | |
408 | TensorDomain* producer_domain = IrBuilder::create<TensorDomain>( |
409 | original_td->container(), |
410 | new_producer_root, |
411 | new_producer_rfactor_domain, |
412 | new_producer_domain, |
413 | std::vector<bool>(new_producer_rfactor_domain.size(), true)); |
414 | |
415 | // Producer has been finished, now work on consumer. |
416 | |
417 | // For convenience flip the original to producer map |
418 | std::unordered_map<IterDomain*, IterDomain*> producer_to_original_map; |
419 | for (auto entry : original_to_producer_id_map) { |
420 | producer_to_original_map[entry.second] = entry.first; |
421 | } |
422 | |
423 | std::vector<IterDomain*> new_consumer_root_domain; |
424 | new_consumer_root_domain.reserve(new_producer_rfactor_domain.size()); |
425 | std::unordered_map<IterDomain*, IterDomain*> original_to_consumer_root_map; |
426 | for (auto p_root_id : new_producer_rfactor_domain) { |
427 | if (p_root_id->isReduction()) { |
428 | continue; |
429 | } |
430 | auto p2o_it = producer_to_original_map.find(p_root_id); |
431 | TORCH_INTERNAL_ASSERT( |
432 | p2o_it != producer_to_original_map.end(), |
433 | "Missing mapping from original tensor domain to producer tensor domain." ); |
434 | auto original_id = p2o_it->second; |
435 | auto new_consumer_root = |
436 | IterDomainBuilder(original_id->start(), original_id->extent()) |
437 | .stop_offset(original_id->stopOffset()) |
438 | .iter_type(original_id->getIterType()) |
439 | .build(); |
440 | new_consumer_root_domain.push_back(new_consumer_root); |
441 | original_to_consumer_root_map[original_id] = new_consumer_root; |
442 | } |
443 | |
444 | ReplayTransformations consumer_replay( |
445 | original_td->domain(), original_to_consumer_root_map, false); |
446 | auto original_to_consumer_map = consumer_replay.getReplay(); |
447 | |
448 | std::vector<IterDomain*> new_consumer_domain; |
449 | |
450 | { |
451 | // Construct the new consumer domain |
452 | for (auto i : c10::irange(original_td->nDims())) { |
453 | auto orig_id = original_td->axis(i); |
454 | auto replayed_id_it = original_to_consumer_map.find(orig_id); |
455 | if (replayed_id_it != original_to_consumer_map.end()) { |
456 | auto replayed_id = replayed_id_it->second; |
457 | new_consumer_domain.push_back(replayed_id); |
458 | replayed_id->parallelize(orig_id->getParallelType()); |
459 | if (orig_id->hasPaddingToMultipleOfWarp()) { |
460 | replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding()); |
461 | } |
462 | } |
463 | } |
464 | } |
465 | |
466 | auto consumer_domain = IrBuilder::create<TensorDomain>( |
467 | original_td->container(), |
468 | new_consumer_root_domain, |
469 | new_consumer_domain, |
470 | std::vector<bool>(new_consumer_root_domain.size(), true)); |
471 | |
472 | return std::make_pair(producer_domain, consumer_domain); |
473 | } |
474 | |
475 | } // namespace cuda |
476 | } // namespace fuser |
477 | } // namespace jit |
478 | } // namespace torch |
479 | |