1 | #include <ir_utils.h> |
2 | #include <iter_visitor.h> |
3 | #include <lower2device.h> |
4 | |
5 | #include <contiguity.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
12 | OrderedIdInformation::OrderedIdInformation( |
13 | const std::vector<IterDomain*>& ids, |
14 | const std::vector<IterDomain*>& root_domain, |
15 | std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info) |
16 | : active_ids_(root_domain), concrete_info_(concrete_info) { |
17 | if (ids.empty() || root_domain.empty()) { |
18 | return; |
19 | } |
20 | |
21 | // Grab root ids and initialize them. |
22 | for (const auto root_i : c10::irange(root_domain.size())) { |
23 | auto root_id = root_domain[root_i]->as<IterDomain>(); |
24 | |
25 | // Initialize id_to_root_ids to map roots to themselves |
26 | id_to_root_ids_[root_id] = {root_id}; |
27 | |
28 | // Initialize roots as being made up of correctly ordered transforms. |
29 | consistently_ordered_ids_.emplace(root_id); |
30 | |
31 | exclusively_consumes_roots_.emplace(root_id); |
32 | } |
33 | |
34 | // Iterate from the root domain to the provided ids and fill |
35 | // consistently_ordered_ids_, id_to_root_ids_, and exclusively_consumes_roots_ |
36 | // for all the IDs |
37 | auto exprs = StmtSort::getExprsBetween( |
38 | ids[0]->fusion(), |
39 | {root_domain.begin(), root_domain.end()}, |
40 | {ids.begin(), ids.end()}); |
41 | |
42 | for (auto expr : exprs) { |
43 | OptInDispatch::handle(expr); |
44 | } |
45 | } |
46 | |
47 | bool OrderedIdInformation::checkExclusivelyConsumesRoots(IterDomain* id) { |
48 | TORCH_INTERNAL_ASSERT( |
49 | std::find(active_ids_.begin(), active_ids_.end(), id) != |
50 | active_ids_.end(), |
51 | "Error replaying transforms in contiguous ID checker, expected " , |
52 | id->toString(), |
53 | " to be in the active ID set." ); |
54 | |
55 | auto root_id_it = id_to_root_ids_.find(id); |
56 | TORCH_INTERNAL_ASSERT( |
57 | root_id_it != id_to_root_ids_.end(), |
58 | "Error replaying transforms in contiguous ID checker, couldn't find mapped roots of " , |
59 | id->toString()); |
60 | |
61 | const auto& root_ids = root_id_it->second; |
62 | |
63 | // Check all the roots of all other ids, to see if any root_ids in id are also |
64 | // in them. |
65 | for (auto other_active_id : active_ids_) { |
66 | if (other_active_id == id || other_active_id == nullptr) { |
67 | continue; |
68 | } |
69 | |
70 | auto root_id_it = id_to_root_ids_.find(other_active_id); |
71 | TORCH_INTERNAL_ASSERT( |
72 | root_id_it != id_to_root_ids_.end(), |
73 | "Error replaying transforms in contiguous ID checker, couldn't find mapped roots of " , |
74 | other_active_id->toString()); |
75 | |
76 | const auto& other_root_ids = root_id_it->second; |
77 | |
78 | for (auto other_root_id : other_root_ids) { |
79 | if (root_ids.has(other_root_id)) { |
80 | return false; |
81 | } |
82 | } |
83 | } |
84 | return true; |
85 | } |
86 | |
// Replays a Merge over the active ID set, deciding whether the merge output
// is "consistently ordered" (its inputs sit adjacently, outer directly left
// of inner, modulo skippable reduction/unconcretized-broadcast axes) and
// updating active_ids_, id_to_root_ids_, and exclusively_consumes_roots_.
void OrderedIdInformation::handle(Merge* merge) {
  // Find inputs in the active_ids_ vector
  const auto inner_it =
      std::find(active_ids_.begin(), active_ids_.end(), merge->inner());
  const auto outer_it =
      std::find(active_ids_.begin(), active_ids_.end(), merge->outer());

  // If either aren't in active_ids_ it means the inputs were detected to not
  // be ordered correctly before hitting this expression.
  if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) {
    return;
  }

  auto inner_pos = std::distance(active_ids_.begin(), inner_it);
  auto outer_pos = std::distance(active_ids_.begin(), outer_it);

  // Find inputs in the ordered transforms map
  const auto inner_ordered_it = consistently_ordered_ids_.find(merge->inner());
  const auto outer_ordered_it = consistently_ordered_ids_.find(merge->outer());

  bool inner_ordered = inner_ordered_it != consistently_ordered_ids_.end();
  bool outer_ordered = outer_ordered_it != consistently_ordered_ids_.end();

  // Get root ids of the two inputs
  const auto inner_root_ids_it = id_to_root_ids_.find(merge->inner());
  const auto outer_root_ids_it = id_to_root_ids_.find(merge->outer());

  TORCH_INTERNAL_ASSERT(
      inner_root_ids_it != id_to_root_ids_.end() &&
          outer_root_ids_it != id_to_root_ids_.end(),
      "Error replaying transforms in contiguous ID checker.");

  const auto& inner_root_ids = inner_root_ids_it->second;
  const auto& outer_root_ids = outer_root_ids_it->second;

  // TODO: Concretization may prevent contiguous indexing or vectorization.
  // It prevents contiguous indexing if the concretization is within the IDs
  // that are used for indexing.
  // For vectorization it just means we need to make sure the extents of the
  // axes to the right of the broadcast root domain in the contiguous merge is
  // bigger than the vectorization dimension. And that the tensor buffer
  // supports the vector word size (always done).
  bool outer_is_concretized_bcast = merge->outer()->isBroadcast() &&
      concrete_info_->isConcretized(merge->outer());

  bool inner_is_concretized_bcast = merge->inner()->isBroadcast() &&
      concrete_info_->isConcretized(merge->inner());

  // Update maps
  // Find the position inner would have to have to be considered ordered:
  // scan right from outer, skipping axes a merge may harmlessly jump over.
  auto pos_after_outer = outer_pos + 1;
  for (; pos_after_outer < int64_t(active_ids_.size()); pos_after_outer++) {
    if (active_ids_[pos_after_outer] == nullptr) {
      // Can't be considered ordered after a nullptr (tombstone from a prior
      // unordered merge).
      break;
    }
    if (active_ids_[pos_after_outer]->isReduction() ||
        ((active_ids_[pos_after_outer]->isBroadcast() &&
          !concrete_info_->isConcretized(active_ids_[pos_after_outer])))) {
      // Skip reduction or broadcast axes that aren't concretized in the fusion
      continue;
    }
    break;
  }

  // The output is ordered as long as the inputs were ordered and outer
  // position is directly left of the inner position.
  bool out_ordered = inner_ordered && outer_ordered;
  out_ordered = out_ordered &&
      // If inner_pos is before outer_pos it's not ordered correctly. If for
      // some reason it's the same, that would be an error.
      inner_pos > outer_pos &&
      // Inner could be a broadcast, so doesn't have to be right on
      // pos_after_outer as that ID (if it exists) should not be a broadcast.
      // However, merging over a broadcast should be fine.
      inner_pos <= pos_after_outer && !inner_is_concretized_bcast &&
      !outer_is_concretized_bcast;

  if (out_ordered) {
    consistently_ordered_ids_.emplace(merge->out());
  }

  // Don't just remove active_ids_, as if we have something like:
  // [i0, i1, i2, i3]
  // ->merge(0, 2)
  // ->merge(1)
  // The latter merge looks like it's ordered correctly, if we update the
  // active map as:
  // [i0, i1, i2, i3] -> [i0*i2, i1, i3]
  // However if we instead mark it as:
  // [i0, i1, i2, i3] -> [i0*i2, i1, nullptr, i3]
  // Or:
  // [i0, i1, i2, i3] -> [nullptr, i1, i0*i2, i3]
  // It's clear the second merge is not ordered correctly. Doesn't matter which
  // direction we put the iter domain in, prefer putting it in outer as we
  // often are looking for inner dimensions that are contiguous. We don't want
  // to always do this, as it could make ordered merges look non-ordered.
  // For example: [i0, i1, i2, i3]
  // ->merge(0)
  // ->merge(1)
  // ->merge(0)
  // If it's updated as:
  // [i0, i1, i2, i3]
  // -> [i0*i1, nullptr, i2, i3]
  // -> [i0*i1, nullptr, i2*i3, nullptr]
  // Now the final merge looks non-ordered but it is. So only insert a nullptr
  // entry if the out is not ordered.
  active_ids_[outer_pos] = merge->out();

  if (!out_ordered) {
    // Leave a tombstone at inner's slot so later merges across it are
    // detected as unordered.
    active_ids_[inner_pos] = nullptr;
  } else {
    active_ids_.erase(active_ids_.begin() + inner_pos);
    // Note: after the erase above, each erase below removes the entry that
    // currently sits at outer_pos + 1; the loop runs once per axis that was
    // strictly between outer and inner.
    for (auto i = outer_pos + 1; i < inner_pos; i++) {
      // If there's broadcast axes between outer and inner and the merge was
      // contiguous, there may be broadcasts between outer and inner that
      // cannot be ordered merged anywhere else so remove them.
      active_ids_.erase(active_ids_.begin() + outer_pos + 1);
    }
  }

  // Update the root_id entry for the output: the merge output covers the
  // union of both inputs' roots.
  VectorOfUniqueEntries<IterDomain*> root_ids = inner_root_ids;
  root_ids.pushBack(outer_root_ids);

  id_to_root_ids_[merge->out()] = root_ids;

  // Need to check this after updating active_ids_ and id_to_root_ids_
  if (checkExclusivelyConsumesRoots(merge->out())) {
    exclusively_consumes_roots_.emplace(merge->out());
  }
}
219 | |
220 | void OrderedIdInformation::handle(Split* split) { |
221 | // Find the input in the active_ids_ vector |
222 | const auto in_it = |
223 | std::find(active_ids_.begin(), active_ids_.end(), split->in()); |
224 | |
225 | if (in_it == active_ids_.end()) { |
226 | return; |
227 | } |
228 | |
229 | auto in_pos = std::distance(active_ids_.begin(), in_it); |
230 | |
231 | // Find the input in the ordered transforms map |
232 | const auto in_ordered_it = consistently_ordered_ids_.find(split->in()); |
233 | |
234 | bool in_ordered = in_ordered_it != consistently_ordered_ids_.end(); |
235 | |
236 | // Get root ids of the input |
237 | const auto in_root_ids_it = id_to_root_ids_.find(split->in()); |
238 | |
239 | TORCH_INTERNAL_ASSERT( |
240 | in_root_ids_it != id_to_root_ids_.end(), |
241 | "Error replaying transforms in contiguous ID checker." ); |
242 | |
243 | VectorOfUniqueEntries<IterDomain*> in_root_ids = in_root_ids_it->second; |
244 | |
245 | // Update map for outputs |
246 | // Remove inputs from the active_ids_ and insert the output ID |
247 | active_ids_[in_pos] = split->outer(); |
248 | active_ids_.insert(active_ids_.begin() + in_pos + 1, split->inner()); |
249 | |
250 | // The outputs are ordered as long as the input is ordered. |
251 | if (in_ordered) { |
252 | consistently_ordered_ids_.emplace(split->outer()); |
253 | consistently_ordered_ids_.emplace(split->inner()); |
254 | } |
255 | |
256 | // Update the root_id entry for the outputs. |
257 | id_to_root_ids_[split->outer()] = in_root_ids; |
258 | id_to_root_ids_[split->inner()] = in_root_ids; |
259 | } |
260 | |
261 | // Swizzle generally can't be contiguous because of the non-affine nature of it, |
262 | // but we can still analyze the operation in the same way as merge/split. |
263 | void OrderedIdInformation::handle(Swizzle2D* swizzle) { |
264 | // Find inputs in the active_ids_ vector |
265 | const auto in_x_it = |
266 | std::find(active_ids_.begin(), active_ids_.end(), swizzle->inX()); |
267 | const auto in_y_it = |
268 | std::find(active_ids_.begin(), active_ids_.end(), swizzle->inY()); |
269 | |
270 | if (in_x_it == active_ids_.end() || in_y_it == active_ids_.end()) { |
271 | return; |
272 | } |
273 | |
274 | auto in_x_pos = std::distance(active_ids_.begin(), in_x_it); |
275 | auto in_y_pos = std::distance(active_ids_.begin(), in_y_it); |
276 | |
277 | // Find inputs in the ordered transforms map |
278 | const auto in_x_ordered_it = consistently_ordered_ids_.find(swizzle->inX()); |
279 | const auto in_y_ordered_it = consistently_ordered_ids_.find(swizzle->inY()); |
280 | |
281 | bool in_x_ordered = in_x_ordered_it != consistently_ordered_ids_.end(); |
282 | bool in_y_ordered = in_y_ordered_it != consistently_ordered_ids_.end(); |
283 | |
284 | // Get root ids of the two inputs |
285 | const auto in_x_root_ids_it = id_to_root_ids_.find(swizzle->inX()); |
286 | const auto in_y_root_ids_it = id_to_root_ids_.find(swizzle->inY()); |
287 | |
288 | TORCH_INTERNAL_ASSERT( |
289 | in_x_root_ids_it != id_to_root_ids_.end() && |
290 | in_y_root_ids_it != id_to_root_ids_.end(), |
291 | "Error replaying transforms in contiguous ID checker." ); |
292 | |
293 | const auto& in_x_root_ids = in_x_root_ids_it->second; |
294 | const auto& in_y_root_ids = in_y_root_ids_it->second; |
295 | |
296 | // Update map for outputs |
297 | // Remove inputs from the active_ids_ and insert the output ID |
298 | active_ids_[in_x_pos] = swizzle->outX(); |
299 | active_ids_[in_y_pos] = swizzle->outY(); |
300 | |
301 | // In the case of no real swizzle we can forward properties on each domain |
302 | // independently. |
303 | if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle) { |
304 | if (in_x_ordered) { |
305 | consistently_ordered_ids_.emplace(swizzle->outX()); |
306 | } |
307 | |
308 | if (exclusivelyConsumesRoots(swizzle->inX())) { |
309 | exclusively_consumes_roots_.emplace(swizzle->outX()); |
310 | } |
311 | |
312 | if (in_y_ordered) { |
313 | consistently_ordered_ids_.emplace(swizzle->outY()); |
314 | } |
315 | |
316 | if (exclusivelyConsumesRoots(swizzle->inY())) { |
317 | exclusively_consumes_roots_.emplace(swizzle->outY()); |
318 | } |
319 | |
320 | id_to_root_ids_[swizzle->outX()] = in_x_root_ids; |
321 | id_to_root_ids_[swizzle->outY()] = in_y_root_ids; |
322 | } else { |
323 | VectorOfUniqueEntries<IterDomain*> root_ids = in_x_root_ids; |
324 | root_ids.pushBack(in_y_root_ids); |
325 | id_to_root_ids_[swizzle->outX()] = root_ids; |
326 | id_to_root_ids_[swizzle->outY()] = root_ids; |
327 | } |
328 | } |
329 | |
330 | NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( |
331 | // TODO: Revisit reduction rfactor axes and propagation. Should probably use |
332 | // ca_map to propogate non divisibility dependencies across exact map. Still |
333 | // need to think through divisible split and non divisible dependencies to |
334 | // see if there's conflicts where a split might look non divisible but |
335 | // actually is divisible and one's overruling the other. |
336 | const std::vector<IterDomain*>& ids, |
337 | const std::vector<IterDomain*>& root_domain, |
338 | const std::unordered_set<Split*>& divisible_splits) { |
339 | if (ids.empty() || root_domain.empty()) { |
340 | return; |
341 | } |
342 | auto transforms = StmtSort::getExprsBetween( |
343 | ids[0]->fusion(), |
344 | {root_domain.begin(), root_domain.end()}, |
345 | {ids.begin(), ids.end()}); |
346 | for (auto transform : transforms) { |
347 | auto inp_ids = ir_utils::filterByType<IterDomain>(transform->inputs()); |
348 | for (auto inp_id : inp_ids) { |
349 | if (std::find(root_domain.begin(), root_domain.end(), inp_id) != |
350 | root_domain.end()) { |
351 | // This generally shouldn't happen as there shouldn't be |
352 | // transformations before the root ids, but in case for some reason |
353 | // we eventually do have cases like that, we should reset the |
354 | // root_ids if for some reason they've been placed in the non |
355 | // divisible split set. |
356 | depends_on_non_divisible_split.erase(inp_id); |
357 | } |
358 | } |
359 | |
360 | bool inputs_non_divisible = |
361 | std::any_of(inp_ids.begin(), inp_ids.end(), [this](IterDomain* inp_id) { |
362 | return depends_on_non_divisible_split.find(inp_id) != |
363 | depends_on_non_divisible_split.end(); |
364 | }); |
365 | |
366 | auto out_ids = ir_utils::filterByType<IterDomain>(transform->outputs()); |
367 | |
368 | if (inputs_non_divisible) { |
369 | // If any inputs are known to be dependent on a divisible split |
370 | // Mark outputs as dependent on a non_divisible split |
371 | depends_on_non_divisible_split.insert(out_ids.begin(), out_ids.end()); |
372 | continue; |
373 | } |
374 | |
375 | if (!transform->isA<Split>()) { |
376 | continue; |
377 | } |
378 | |
379 | auto split = transform->as<Split>(); |
380 | // If this transform is a non-divisible split |
381 | if (divisible_splits.find(split) == divisible_splits.end()) { |
382 | // Mark outputs as dependent on a non_divisible split |
383 | auto out_ids = ir_utils::filterByType<IterDomain>(transform->outputs()); |
384 | depends_on_non_divisible_split.insert(out_ids.begin(), out_ids.end()); |
385 | } |
386 | } |
387 | } |
388 | |
// Convenience constructor that builds the ComputeAtMap, HaloInfo, and
// ConcretizedBroadcastDomains analyses itself before running build().
// Prefer the overload taking those analyses when they already exist, to
// avoid recomputing them.
ContigIDs::ContigIDs(
    const std::vector<IterDomain*>& ids,
    const std::vector<IterDomain*>& root_domain,
    const std::vector<bool>& root_contiguity,
    const std::unordered_set<IterDomain*>& final_ids,
    const std::unordered_map<IterDomain*, Val*>& index_map,
    const std::unordered_set<Split*>& divisible_splits,
    std::unordered_map<IterDomain*, IterDomain*> p2c_id_map,
    bool ignore_indexability,
    bool ignore_consistent_ordering)
    : root_domain_(root_domain),
      root_contiguity_(root_contiguity),
      final_ids_(final_ids),
      index_map_(index_map),
      divisible_splits_(divisible_splits),
      p2c_id_map_(std::move(p2c_id_map)),
      ignore_indexability_(ignore_indexability),
      ignore_consistent_ordering_(ignore_consistent_ordering),
      non_divisible_id_info_(ids, root_domain_, divisible_splits_) {
  if (ids.size() > 0) {
    // This constructor doesn't provide the following information so it needs
    // to be built. Requires at least one ID to reach the Fusion; with empty
    // ids these analyses are skipped and build() returns early anyway.
    ca_map_ = std::make_shared<ComputeAtMap>(ids[0]->fusion());
    halo_info_ = std::make_shared<HaloInfo>(ids[0]->fusion(), ca_map_);
    concrete_info_ =
        std::make_shared<ConcretizedBroadcastDomains>(ids[0]->fusion());

    consistent_transform_info_ = std::make_unique<const OrderedIdInformation>(
        ids, root_domain, concrete_info_);
  }
  build(ids);
}
421 | |
// Full constructor taking pre-built ComputeAtMap / HaloInfo /
// ConcretizedBroadcastDomains analyses; runs build() on the given ids.
ContigIDs::ContigIDs(
    const std::vector<IterDomain*>& ids,
    const std::vector<IterDomain*>& root_domain,
    const std::vector<bool>& root_contiguity,
    const std::unordered_set<IterDomain*>& final_ids,
    const std::unordered_map<IterDomain*, Val*>& index_map,
    const std::unordered_set<Split*>& divisible_splits,
    std::shared_ptr<const ComputeAtMap> ca_map,
    std::shared_ptr<const HaloInfo> halo_info,
    std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
    std::unordered_map<IterDomain*, IterDomain*> p2c_id_map,
    bool ignore_indexability,
    bool ignore_consistent_ordering)
    : root_domain_(root_domain),
      root_contiguity_(root_contiguity),
      final_ids_(final_ids),
      index_map_(index_map),
      divisible_splits_(divisible_splits),
      ca_map_(ca_map),
      halo_info_(halo_info),
      concrete_info_(concrete_info),
      p2c_id_map_(std::move(p2c_id_map)),
      ignore_indexability_(ignore_indexability),
      ignore_consistent_ordering_(ignore_consistent_ordering),
      // NOTE: uses the concrete_info_ member (initialized just above), not
      // the parameter — equivalent here, but order in the init list matters.
      consistent_transform_info_(std::make_unique<const OrderedIdInformation>(
          ids,
          root_domain,
          concrete_info_)),
      non_divisible_id_info_(ids, root_domain, divisible_splits_) {
  build(ids);
}
453 | |
// Returns an empty analysis in which no ID is contiguous: with empty ids and
// root domain, build() returns immediately, so every query on the result is
// a miss.
ContigIDs ContigIDs::getNonContigIDs() {
  return ContigIDs({}, {}, {}, {}, {}, {});
}
457 | |
458 | void ContigIDs::build(const std::vector<IterDomain*>& ids) { |
459 | if (ids.empty() || root_domain_.empty()) { |
460 | return; |
461 | } |
462 | |
463 | TORCH_INTERNAL_ASSERT( |
464 | root_domain_.size() == root_contiguity_.size(), |
465 | "Arguments don't match " , |
466 | root_domain_.size(), |
467 | " != " , |
468 | root_contiguity_.size()); |
469 | |
470 | for (const auto root_domain_i : c10::irange(root_domain_.size())) { |
471 | auto root_domain_id = root_domain_[root_domain_i]->as<IterDomain>(); |
472 | root_to_indexed_id_[root_domain_id] = root_domain_id; |
473 | // Initialize to false |
474 | is_contig_root_[root_domain_id] = false; |
475 | // If a root domain has halo, can't use merged domain even if |
476 | // both inputs are contiguous. HaloInfo is also initialized for |
477 | // rfactor root domains, which should just return "zero" |
478 | // RootAxisInfo. This should be safe as no rfactor tensor should |
479 | // need halo. |
480 | if (root_contiguity_[root_domain_i] && |
481 | !halo_info_->getRootAxisInfo(root_domain_id).hasHalo()) { |
482 | contig_ids_.emplace(root_domain_id); |
483 | is_contig_root_[root_domain_id] = true; |
484 | within_contig_ids_[root_domain_id] = std::unordered_set<IterDomain*>(); |
485 | } |
486 | } |
487 | |
488 | if (!contig_ids_.empty()) { |
489 | auto exprs = StmtSort::getExprsBetween( |
490 | ids[0]->fusion(), |
491 | {root_domain_.begin(), root_domain_.end()}, |
492 | {ids.begin(), ids.end()}); |
493 | for (auto expr : exprs) { |
494 | handle(expr); |
495 | } |
496 | } |
497 | } |
498 | |
// Decides whether a merge output is contiguously indexable. If it is, the
// output replaces all the IDs it covers in contig_ids_, and every covered
// root is redirected to index through it via root_to_indexed_id_.
void ContigIDs::handle(Merge* merge) {
  // If output is not consistently ordered or doesn't solely consume all root
  // domains in its dependencies, then it can't be a contiguously indexable
  // iterdomain.
  if (!(ignore_consistent_ordering_ ||
        consistent_transform_info_->isConsistentlyOrdered(merge->out()))) {
    return;
  }

  if (!consistent_transform_info_->exclusivelyConsumesRoots(merge->out())) {
    return;
  }

  // If output is not "directly indexable" then it's definitely not
  // contiguously indexable.
  if (!ignore_indexability_ && !isIndexable(merge->out())) {
    return;
  }

  // If inputs are marked as final, stop
  if (final_ids_.count(merge->inner()) || final_ids_.count(merge->outer())) {
    return;
  }

  // Check root domains for contiguity
  auto root_ids_it =
      consistent_transform_info_->idToRootIds().find(merge->out());

  TORCH_INTERNAL_ASSERT(
      root_ids_it != consistent_transform_info_->idToRootIds().end(),
      "\nError in contiguous analysis, merge info doesn't exist for:\n",
      merge->toString(),
      "\nId: ",
      merge->out()->toString());

  // Working copy: entries are erased below as roots are visited in root
  // domain order.
  VectorOfUniqueEntries<IterDomain*> root_ids = root_ids_it->second;

  // ignore_consistent_ordering_ distinguishes predicate computation from
  // actual indexing (see comment below).
  bool is_indexing_pass = !ignore_consistent_ordering_;

  IterDomain* last_root = nullptr;
  for (auto root_id_i : c10::irange(root_domain_.size())) {
    auto root_id = root_domain_[root_id_i];
    if (root_ids.has(root_id)) {
      // ID found, remove it
      root_ids.erase(root_id);
      // If we're indexing:
      // we could still potentially consider this ID linearly indexable, as we
      // could multiply the index by the last root's stride.
      //
      // If we're computing predicates (ignore_consistent_ordering_==true),
      // then we don't have this same constraint, we can just ignore
      // contiguity of the roots all together.
      //
      // A non-contiguous root is only fatal if roots remain to its right
      // (root_ids still non-empty after erasing this one).
      if (!root_contiguity_[root_id_i] && is_indexing_pass) {
        if (!root_ids.empty()) {
          return;
        }
      }
      last_root = root_id;
    }
  }

  // If there's a non_divisible split in the history of merge->out then it
  // can't be contiguously indexable.
  if (non_divisible_id_info_.dependsOnNonDivisibleSplit(merge->out())) {
    return;
  }

  // Now we know merge->out is a contiguously indexable ID

  TORCH_INTERNAL_ASSERT(
      last_root != nullptr,
      "Issue processing root ids for ",
      merge->out()->toString());

  // Reset root_ids (the loop above consumed the working copy).
  root_ids = root_ids_it->second;
  // All covered roots now index through the merge output.
  for (auto root_id : root_ids) {
    root_to_indexed_id_[root_id] = merge->out();
  }

  // Everything between the roots and merge->out is "within" the new
  // contiguous ID and must no longer be tracked as contiguous itself.
  auto all_within_vals = DependencyCheck::getAllValsBetween(
      {root_domain_.begin(), root_domain_.end()}, {merge->out()});
  auto all_within_ids = ir_utils::filterByType<IterDomain>(all_within_vals);

  std::unordered_set<IterDomain*> within_id_set(
      all_within_ids.begin(), all_within_ids.end());

  within_id_set.erase(merge->out());
  within_contig_ids_[merge->out()] = within_id_set;
  for (auto id : all_within_ids) {
    contig_ids_.erase(id);
  }

  contig_ids_.emplace(merge->out());
}
594 | |
595 | IterDomain* ContigIDs::getMappedId(IterDomain* id) const { |
596 | auto it = p2c_id_map_.find(id); |
597 | if (it != p2c_id_map_.end()) { |
598 | return it->second; |
599 | } else { |
600 | return id; |
601 | } |
602 | } |
603 | |
604 | bool ContigIDs::isIndexable(IterDomain* id) const { |
605 | // If ID is mapped to consumer through persmissive map but not exact map it |
606 | // will not be mapped through to the exact map through the p2c map. Therefore |
607 | // reject because it involves broadcast resolution. |
608 | if (!ca_map_->idExistsInMap(getMappedId(id))) { |
609 | return false; |
610 | } |
611 | auto c_id = |
612 | ca_map_->getConcreteMappedID(getMappedId(id), IdMappingMode::EXACT); |
613 | return index_map_.find(c_id) != index_map_.end(); |
614 | } |
615 | |
616 | } // namespace cuda |
617 | } // namespace fuser |
618 | } // namespace jit |
619 | } // namespace torch |
620 | |