1#include <iter_visitor.h>
2#include <kernel_ir_dispatch.h>
3#include <lower2device.h>
4#include <lower_magic_zero.h>
5
6#include <lower_index_hoist.h>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13namespace {
14
15// Return leaf domains of a given domain.
16std::unordered_set<IterDomain*> getUsedLeafIds(
17 IterDomain* id,
18 TensorDomain* td) {
19 const auto all_vals_between = DependencyCheck::getAllValsBetween(
20 {id}, {td->domain().begin(), td->domain().end()});
21
22 std::unordered_set<IterDomain*> used_leaf_ids;
23
24 for (const auto leaf : td->domain()) {
25 if (std::find(all_vals_between.begin(), all_vals_between.end(), leaf) !=
26 all_vals_between.end()) {
27 used_leaf_ids.insert(leaf);
28 }
29 }
30
31 TORCH_INTERNAL_ASSERT(
32 !used_leaf_ids.empty(),
33 "No used id found: ",
34 id->toString(),
35 ", ",
36 td->toString());
37
38 return used_leaf_ids;
39}
40
41} // namespace
42
43CommonIndexKey::CommonIndexKey(
44 IterDomain* consumer_indexed_id,
45 TensorDomain* consumer_td,
46 TensorDomain* ref_td,
47 const std::unordered_map<IterDomain*, Val*>& ref_index_map,
48 const std::vector<kir::ForLoop*>& loops) {
49 auto gpu_lower = GpuLower::current();
50
51 concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID(
52 consumer_indexed_id, IdMappingMode::EXACT);
53
54 const auto consumer_leaf_ids =
55 getUsedLeafIds(consumer_indexed_id, consumer_td);
56
57 // Convert to Parallel concrete IDs to find matching loops.
58 std::unordered_set<IterDomain*> concrete_leaf_ids;
59 for (auto& id : consumer_leaf_ids) {
60 concrete_leaf_ids.insert(
61 gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP));
62 }
63
64 // Find used loops and their index vals
65 for (const auto i : c10::irange(loops.size())) {
66 auto loop = loops.at(i);
67 auto loop_id = gpu_lower->caMap()->getConcreteMappedID(
68 loop->iter_domain(), IdMappingMode::LOOP);
69 auto it = concrete_leaf_ids.find(loop_id);
70 if (it != concrete_leaf_ids.end()) {
71 // This leaf reference id is used for indexing the consumer id
72 used_loops_.push_back(loop);
73 auto index_it = ref_index_map.find(ref_td->axis(i));
74 TORCH_INTERNAL_ASSERT(
75 index_it != ref_index_map.end(),
76 "Index not found for leaf ID, ",
77 ref_td->axis(i)->toString());
78 loop_index_vals_.push_back(index_it->second);
79 }
80 }
81
82 TORCH_INTERNAL_ASSERT(
83 !used_loops_.empty(),
84 "No loop used for indexing found. ",
85 consumer_indexed_id->toString());
86
87 TORCH_INTERNAL_ASSERT(
88 consumer_leaf_ids.size() == used_loops_.size(),
89 "consumer_leaf_ids.size() = ",
90 consumer_leaf_ids.size(),
91 ", used_loops_.size() == ",
92 used_loops_.size(),
93 ", loops.size() == ",
94 loops.size());
95}
96
97CommonIndexKey::CommonIndexKey(
98 IterDomain* consumer_indexed_id,
99 TensorDomain* consumer_td,
100 const std::vector<IterDomain*>& loop_domains,
101 const std::unordered_map<IterDomain*, Val*>& loop_index_map,
102 const std::vector<kir::ForLoop*>& loops) {
103 auto gpu_lower = GpuLower::current();
104
105 concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID(
106 consumer_indexed_id, IdMappingMode::EXACT);
107
108 const auto consumer_leaf_ids =
109 getUsedLeafIds(consumer_indexed_id, consumer_td);
110
111 // Convert to Parallel concrete IDs to find matching loops.
112 std::unordered_set<IterDomain*> concrete_leaf_ids;
113 for (auto& id : consumer_leaf_ids) {
114 concrete_leaf_ids.insert(
115 gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP));
116 }
117
118 // Find used loops and their index vals
119 for (const auto i : c10::irange(loops.size())) {
120 auto loop = loops.at(i);
121 auto loop_id = gpu_lower->caMap()->getConcreteMappedID(
122 loop->iter_domain(), IdMappingMode::LOOP);
123 auto it = concrete_leaf_ids.find(loop_id);
124 if (it != concrete_leaf_ids.end()) {
125 // This leaf reference id is used for indexing the consumer id
126 used_loops_.push_back(loop);
127 auto loop_concrete_id = gpu_lower->caMap()->getConcreteMappedID(
128 loop_domains.at(i), IdMappingMode::EXACT);
129 auto index_it = loop_index_map.find(loop_concrete_id);
130 TORCH_INTERNAL_ASSERT(
131 index_it != loop_index_map.end(),
132 "Index not found for leaf ID, ",
133 loop_domains.at(i)->toString(),
134 ", concrete ID: ",
135 loop_concrete_id->toString());
136 loop_index_vals_.push_back(index_it->second);
137 }
138 }
139
140 TORCH_INTERNAL_ASSERT(
141 !used_loops_.empty(),
142 "No loop used for indexing found. ",
143 consumer_indexed_id->toString());
144
145 TORCH_INTERNAL_ASSERT(
146 consumer_leaf_ids.size() == used_loops_.size(),
147 "consumer_leaf_ids.size() = ",
148 consumer_leaf_ids.size(),
149 ", used_loops_.size() == ",
150 used_loops_.size(),
151 ", loops.size() == ",
152 loops.size());
153}
154
155bool CommonIndexKey::operator==(const CommonIndexKey& other) const {
156 auto gpu_lower = GpuLower::current();
157
158 if (concrete_indexed_id_ != other.concrete_indexed_id_) {
159 return false;
160 }
161
162 if (used_loops_.size() != other.used_loops_.size()) {
163 return false;
164 }
165
166 // Check if both CommonIndexKeys use the same loops. If not, it's
167 // still valid to share the same hoisted index as long as: 1) each
168 // loop pair is mapped with the CA index map, and 2) they are not
169 // instantiated as actual loops.
170 for (const auto i : c10::irange(used_loops_.size())) {
171 auto lhs_loop = used_loops_.at(i);
172 auto rhs_loop = other.used_loops_.at(i);
173 if (lhs_loop == rhs_loop) {
174 continue;
175 }
176 if (gpu_lower->caMap()->areMapped(
177 lhs_loop->iter_domain(),
178 rhs_loop->iter_domain(),
179 IdMappingMode::EXACT) &&
180 lhs_loop->isTrivial() && rhs_loop->isTrivial()) {
181 continue;
182 }
183 return false;
184 }
185
186 for (const auto i : c10::irange(loop_index_vals_.size())) {
187 auto lhs_index = loop_index_vals_.at(i);
188 auto rhs_index = other.loop_index_vals_.at(i);
189 if (lhs_index == rhs_index) {
190 continue;
191 }
192 // Initial index variables can have some additions such as magic
193 // zero and "1" when used in producer indexing for double buffered
194 // tensors. Thus, the initial variables themselves may be
195 // different, and its components need to be examined. An easy way
196 // is to flatten them to strings as follows.
197 auto lhs_str = loop_index_vals_.at(i)->toInlineString();
198 auto rhs_str = other.loop_index_vals_.at(i)->toInlineString();
199 if (lhs_str == rhs_str) {
200 continue;
201 }
202
203 return false;
204 }
205
206 return true;
207}
208
209std::string CommonIndexKey::toString() const {
210 TORCH_INTERNAL_ASSERT(concrete_indexed_id_ != nullptr);
211 std::stringstream ss;
212 ss << "CommonIndexKey: " << concrete_indexed_id_->toString();
213 ss << ", { ";
214 for (auto loop : used_loops_) {
215 ss << loop->iter_domain()->toString() << " ";
216 }
217 ss << "}";
218 ss << ", { ";
219 for (auto val : loop_index_vals_) {
220 ss << val->toString() << " ";
221 }
222 ss << "}";
223 return ss.str();
224}
225
226std::pair<Val*, bool> CommonIndexMap::insert(
227 IterDomain* indexed_consumer_id,
228 TensorDomain* consumer_td,
229 TensorDomain* ref_td,
230 const std::unordered_map<IterDomain*, Val*>& ref_index_map,
231 const std::vector<kir::ForLoop*>& loops,
232 Val* index) {
233 if (index->definition() == nullptr) {
234 // Only defined val is eligible to hoist
235 return {index, false};
236 }
237
238 const CommonIndexKey key(
239 indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops);
240
241 return tryInsertNewIndex(key, index);
242}
243
244std::pair<Val*, bool> CommonIndexMap::insert(
245 IterDomain* indexed_consumer_id,
246 TensorDomain* consumer_td,
247 const std::vector<IterDomain*>& loop_domains,
248 const std::unordered_map<IterDomain*, Val*>& loop_index_map,
249 const std::vector<kir::ForLoop*>& loops,
250 Val* index) {
251 if (index->definition() == nullptr) {
252 // Only defined val is eligible to hoist
253 return {index, false};
254 }
255
256 const CommonIndexKey key(
257 indexed_consumer_id, consumer_td, loop_domains, loop_index_map, loops);
258
259 return tryInsertNewIndex(key, index);
260}
261
262std::pair<Val*, bool> CommonIndexMap::tryInsertNewIndex(
263 CommonIndexKey key,
264 Val* index) {
265 Val* hoisted_index = nullptr;
266 bool new_index_inserted = false;
267
268 // Hoisting is not possible if any of used loops is grouped.
269 if (std::any_of(
270 key.usedLoops().begin(), key.usedLoops().end(), [](const auto loop) {
271 return loop->iter_domain()->getParallelType() ==
272 ParallelType::Group;
273 })) {
274 return {index, false};
275 }
276
277 // If already mapped, return the previously mapped index
278 auto it = common_index_map_.find(key);
279 if (it != common_index_map_.end()) {
280 hoisted_index = it->second;
281 new_index_inserted = false;
282 ++use_counts_.at(key);
283 } else {
284 common_index_map_.emplace(key, index);
285 hoisted_index = index;
286 new_index_inserted = true;
287 use_counts_[key] = 1;
288 }
289 return {hoisted_index, new_index_inserted};
290}
291
292namespace {
293
//! Insertion point of allocation: the hoisted-index allocation and its
//! defining expression are inserted just before `ref` within `scope`.
struct CommonIndexInsertionInfo {
  // Expression to insert before.
  Expr* ref = nullptr;
  // Enclosing scope; nullptr means the top-level (kernel) scope.
  kir::Scope* scope = nullptr;
};
299
300// Inserts allocations of hoisted indices
301class CommonIndexInserter : private kir::ExprMutator {
302 public:
303 static std::vector<Expr*> run(
304 const std::vector<Expr*>& exprs,
305 const CommonIndexMap& common_indices) {
306 CommonIndexInserter inserter(exprs, common_indices);
307 return inserter.exprs_;
308 }
309
310 private:
311 CommonIndexInserter(
312 const std::vector<Expr*>& exprs,
313 const CommonIndexMap& common_index_map)
314 : common_index_map_(common_index_map) {
315 // Create a map to keys from loops where they should be inserted
316 for (const auto& kv : common_index_map.commonIndexMap()) {
317 const auto& key = kv.first;
318 // Only consider indices used multiple times
319 if (!usedMultipleTimes(key)) {
320 continue;
321 }
322 TORCH_INTERNAL_ASSERT(!key.usedLoops().empty());
323 auto insertion_loop = key.usedLoops().back();
324 innermost_used_loop_map_[insertion_loop].push_back(key);
325 }
326
327 traverseAndInsert(exprs);
328 }
329
330 CommonIndexInsertionInfo findInsertionPoint(
331 const CommonIndexKey& key,
332 kir::ForLoop* current_loop) const {
333 CommonIndexInsertionInfo info;
334
335 // Allocation must be inside any used non-trivial loop. Since the
336 // loop index value is constant if a loop is trivial, allocation
337 // does not need to be inside trivial loops.
338 for (const auto loop : key.usedLoops()) {
339 if (!loop->isTrivial()) {
340 info.ref = loop->body()[0];
341 info.scope = &(loop->body());
342 }
343 }
344
345 // If no non-trivial used loop is found, insert at the top-level
346 // scope just before the outer-most loop.
347 if (info.ref == nullptr) {
348 info.ref = scope_exprs_.empty() ? current_loop : scope_exprs_.at(0);
349 info.scope = nullptr;
350 }
351
352 return info;
353 }
354
355 using kir::ExprMutator::handle;
356
357 void handle(kir::ForLoop* loop) final {
358 auto innermost_loop_map_it = innermost_used_loop_map_.find(loop);
359 if (innermost_loop_map_it == innermost_used_loop_map_.end()) {
360 kir::ExprMutator::handle(loop);
361 return;
362 }
363
364 for (const auto& key : innermost_loop_map_it->second) {
365 auto common_index = common_index_map_.commonIndexMap().at(key);
366
367 // Insert only when the index is used multiple times and is not
368 // yet inserted.
369 if (inserted_indices_.find(common_index) != inserted_indices_.end()) {
370 continue;
371 }
372
373 // Make the type of the hoisted index be the index type of the
374 // kernel, which can be either int64_t or int. Not very clean,
375 // but this seems to be the quickest way to use the index type
376 // as we don't have a scalar IR node for the index type.
377 common_index->resolveIndexDtype();
378
379 auto alloc = IrBuilder::create<kir::Allocate>(
380 common_index,
381 MemoryType::Local,
382 GpuLower::current()->kernel()->oneVal());
383 const auto common_index_def = common_index->definition();
384 TORCH_INTERNAL_ASSERT(
385 common_index_def != nullptr,
386 "Hosted index must have a definition. ",
387 common_index->toString());
388
389 const auto insertion_info = findInsertionPoint(key, loop);
390 registerInsertBefore(insertion_info.ref, alloc, insertion_info.scope);
391 registerInsertBefore(
392 insertion_info.ref, common_index_def, insertion_info.scope);
393
394 // Track inserted index
395 inserted_indices_.emplace(common_index);
396 }
397
398 kir::ExprMutator::handle(loop);
399 }
400
401 bool usedMultipleTimes(const CommonIndexKey& key) {
402 auto it = common_index_map_.useCounts().find(key);
403 TORCH_INTERNAL_ASSERT(
404 it != common_index_map_.useCounts().end(),
405 "Key not found in the use-count map: ",
406 key.toString());
407 return it->second > 1;
408 }
409
410 private:
411 const CommonIndexMap& common_index_map_;
412 //! Map to CommonIndexKeys from their innermost used loops
413 std::unordered_map<kir::ForLoop*, std::vector<CommonIndexKey>>
414 innermost_used_loop_map_;
415 //! Keep track of inserted indices
416 std::unordered_set<Val*> inserted_indices_;
417};
418
419} // namespace
420
// Entry point of the pass: insert allocations (and defining
// expressions) for hoisted indices recorded in the current GpuLower's
// common index map, returning the mutated expression list.
std::vector<Expr*> allocateCommonIndices(const std::vector<Expr*>& exprs) {
  return CommonIndexInserter::run(exprs, GpuLower::current()->commonIndexMap());
}
424
425} // namespace cuda
426} // namespace fuser
427} // namespace jit
428} // namespace torch
429