1 | #include <iter_visitor.h> |
2 | #include <kernel_ir_dispatch.h> |
3 | #include <lower2device.h> |
4 | #include <lower_magic_zero.h> |
5 | |
6 | #include <lower_index_hoist.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | namespace { |
14 | |
15 | // Return leaf domains of a given domain. |
16 | std::unordered_set<IterDomain*> getUsedLeafIds( |
17 | IterDomain* id, |
18 | TensorDomain* td) { |
19 | const auto all_vals_between = DependencyCheck::getAllValsBetween( |
20 | {id}, {td->domain().begin(), td->domain().end()}); |
21 | |
22 | std::unordered_set<IterDomain*> used_leaf_ids; |
23 | |
24 | for (const auto leaf : td->domain()) { |
25 | if (std::find(all_vals_between.begin(), all_vals_between.end(), leaf) != |
26 | all_vals_between.end()) { |
27 | used_leaf_ids.insert(leaf); |
28 | } |
29 | } |
30 | |
31 | TORCH_INTERNAL_ASSERT( |
32 | !used_leaf_ids.empty(), |
33 | "No used id found: " , |
34 | id->toString(), |
35 | ", " , |
36 | td->toString()); |
37 | |
38 | return used_leaf_ids; |
39 | } |
40 | |
41 | } // namespace |
42 | |
43 | CommonIndexKey::CommonIndexKey( |
44 | IterDomain* consumer_indexed_id, |
45 | TensorDomain* consumer_td, |
46 | TensorDomain* ref_td, |
47 | const std::unordered_map<IterDomain*, Val*>& ref_index_map, |
48 | const std::vector<kir::ForLoop*>& loops) { |
49 | auto gpu_lower = GpuLower::current(); |
50 | |
51 | concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID( |
52 | consumer_indexed_id, IdMappingMode::EXACT); |
53 | |
54 | const auto consumer_leaf_ids = |
55 | getUsedLeafIds(consumer_indexed_id, consumer_td); |
56 | |
57 | // Convert to Parallel concrete IDs to find matching loops. |
58 | std::unordered_set<IterDomain*> concrete_leaf_ids; |
59 | for (auto& id : consumer_leaf_ids) { |
60 | concrete_leaf_ids.insert( |
61 | gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP)); |
62 | } |
63 | |
64 | // Find used loops and their index vals |
65 | for (const auto i : c10::irange(loops.size())) { |
66 | auto loop = loops.at(i); |
67 | auto loop_id = gpu_lower->caMap()->getConcreteMappedID( |
68 | loop->iter_domain(), IdMappingMode::LOOP); |
69 | auto it = concrete_leaf_ids.find(loop_id); |
70 | if (it != concrete_leaf_ids.end()) { |
71 | // This leaf reference id is used for indexing the consumer id |
72 | used_loops_.push_back(loop); |
73 | auto index_it = ref_index_map.find(ref_td->axis(i)); |
74 | TORCH_INTERNAL_ASSERT( |
75 | index_it != ref_index_map.end(), |
76 | "Index not found for leaf ID, " , |
77 | ref_td->axis(i)->toString()); |
78 | loop_index_vals_.push_back(index_it->second); |
79 | } |
80 | } |
81 | |
82 | TORCH_INTERNAL_ASSERT( |
83 | !used_loops_.empty(), |
84 | "No loop used for indexing found. " , |
85 | consumer_indexed_id->toString()); |
86 | |
87 | TORCH_INTERNAL_ASSERT( |
88 | consumer_leaf_ids.size() == used_loops_.size(), |
89 | "consumer_leaf_ids.size() = " , |
90 | consumer_leaf_ids.size(), |
91 | ", used_loops_.size() == " , |
92 | used_loops_.size(), |
93 | ", loops.size() == " , |
94 | loops.size()); |
95 | } |
96 | |
97 | CommonIndexKey::CommonIndexKey( |
98 | IterDomain* consumer_indexed_id, |
99 | TensorDomain* consumer_td, |
100 | const std::vector<IterDomain*>& loop_domains, |
101 | const std::unordered_map<IterDomain*, Val*>& loop_index_map, |
102 | const std::vector<kir::ForLoop*>& loops) { |
103 | auto gpu_lower = GpuLower::current(); |
104 | |
105 | concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID( |
106 | consumer_indexed_id, IdMappingMode::EXACT); |
107 | |
108 | const auto consumer_leaf_ids = |
109 | getUsedLeafIds(consumer_indexed_id, consumer_td); |
110 | |
111 | // Convert to Parallel concrete IDs to find matching loops. |
112 | std::unordered_set<IterDomain*> concrete_leaf_ids; |
113 | for (auto& id : consumer_leaf_ids) { |
114 | concrete_leaf_ids.insert( |
115 | gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP)); |
116 | } |
117 | |
118 | // Find used loops and their index vals |
119 | for (const auto i : c10::irange(loops.size())) { |
120 | auto loop = loops.at(i); |
121 | auto loop_id = gpu_lower->caMap()->getConcreteMappedID( |
122 | loop->iter_domain(), IdMappingMode::LOOP); |
123 | auto it = concrete_leaf_ids.find(loop_id); |
124 | if (it != concrete_leaf_ids.end()) { |
125 | // This leaf reference id is used for indexing the consumer id |
126 | used_loops_.push_back(loop); |
127 | auto loop_concrete_id = gpu_lower->caMap()->getConcreteMappedID( |
128 | loop_domains.at(i), IdMappingMode::EXACT); |
129 | auto index_it = loop_index_map.find(loop_concrete_id); |
130 | TORCH_INTERNAL_ASSERT( |
131 | index_it != loop_index_map.end(), |
132 | "Index not found for leaf ID, " , |
133 | loop_domains.at(i)->toString(), |
134 | ", concrete ID: " , |
135 | loop_concrete_id->toString()); |
136 | loop_index_vals_.push_back(index_it->second); |
137 | } |
138 | } |
139 | |
140 | TORCH_INTERNAL_ASSERT( |
141 | !used_loops_.empty(), |
142 | "No loop used for indexing found. " , |
143 | consumer_indexed_id->toString()); |
144 | |
145 | TORCH_INTERNAL_ASSERT( |
146 | consumer_leaf_ids.size() == used_loops_.size(), |
147 | "consumer_leaf_ids.size() = " , |
148 | consumer_leaf_ids.size(), |
149 | ", used_loops_.size() == " , |
150 | used_loops_.size(), |
151 | ", loops.size() == " , |
152 | loops.size()); |
153 | } |
154 | |
155 | bool CommonIndexKey::operator==(const CommonIndexKey& other) const { |
156 | auto gpu_lower = GpuLower::current(); |
157 | |
158 | if (concrete_indexed_id_ != other.concrete_indexed_id_) { |
159 | return false; |
160 | } |
161 | |
162 | if (used_loops_.size() != other.used_loops_.size()) { |
163 | return false; |
164 | } |
165 | |
166 | // Check if both CommonIndexKeys use the same loops. If not, it's |
167 | // still valid to share the same hoisted index as long as: 1) each |
168 | // loop pair is mapped with the CA index map, and 2) they are not |
169 | // instantiated as actual loops. |
170 | for (const auto i : c10::irange(used_loops_.size())) { |
171 | auto lhs_loop = used_loops_.at(i); |
172 | auto rhs_loop = other.used_loops_.at(i); |
173 | if (lhs_loop == rhs_loop) { |
174 | continue; |
175 | } |
176 | if (gpu_lower->caMap()->areMapped( |
177 | lhs_loop->iter_domain(), |
178 | rhs_loop->iter_domain(), |
179 | IdMappingMode::EXACT) && |
180 | lhs_loop->isTrivial() && rhs_loop->isTrivial()) { |
181 | continue; |
182 | } |
183 | return false; |
184 | } |
185 | |
186 | for (const auto i : c10::irange(loop_index_vals_.size())) { |
187 | auto lhs_index = loop_index_vals_.at(i); |
188 | auto rhs_index = other.loop_index_vals_.at(i); |
189 | if (lhs_index == rhs_index) { |
190 | continue; |
191 | } |
192 | // Initial index variables can have some additions such as magic |
193 | // zero and "1" when used in producer indexing for double buffered |
194 | // tensors. Thus, the initial variables themselves may be |
195 | // different, and its components need to be examined. An easy way |
196 | // is to flatten them to strings as follows. |
197 | auto lhs_str = loop_index_vals_.at(i)->toInlineString(); |
198 | auto rhs_str = other.loop_index_vals_.at(i)->toInlineString(); |
199 | if (lhs_str == rhs_str) { |
200 | continue; |
201 | } |
202 | |
203 | return false; |
204 | } |
205 | |
206 | return true; |
207 | } |
208 | |
209 | std::string CommonIndexKey::toString() const { |
210 | TORCH_INTERNAL_ASSERT(concrete_indexed_id_ != nullptr); |
211 | std::stringstream ss; |
212 | ss << "CommonIndexKey: " << concrete_indexed_id_->toString(); |
213 | ss << ", { " ; |
214 | for (auto loop : used_loops_) { |
215 | ss << loop->iter_domain()->toString() << " " ; |
216 | } |
217 | ss << "}" ; |
218 | ss << ", { " ; |
219 | for (auto val : loop_index_vals_) { |
220 | ss << val->toString() << " " ; |
221 | } |
222 | ss << "}" ; |
223 | return ss.str(); |
224 | } |
225 | |
226 | std::pair<Val*, bool> CommonIndexMap::insert( |
227 | IterDomain* indexed_consumer_id, |
228 | TensorDomain* consumer_td, |
229 | TensorDomain* ref_td, |
230 | const std::unordered_map<IterDomain*, Val*>& ref_index_map, |
231 | const std::vector<kir::ForLoop*>& loops, |
232 | Val* index) { |
233 | if (index->definition() == nullptr) { |
234 | // Only defined val is eligible to hoist |
235 | return {index, false}; |
236 | } |
237 | |
238 | const CommonIndexKey key( |
239 | indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops); |
240 | |
241 | return tryInsertNewIndex(key, index); |
242 | } |
243 | |
244 | std::pair<Val*, bool> CommonIndexMap::insert( |
245 | IterDomain* indexed_consumer_id, |
246 | TensorDomain* consumer_td, |
247 | const std::vector<IterDomain*>& loop_domains, |
248 | const std::unordered_map<IterDomain*, Val*>& loop_index_map, |
249 | const std::vector<kir::ForLoop*>& loops, |
250 | Val* index) { |
251 | if (index->definition() == nullptr) { |
252 | // Only defined val is eligible to hoist |
253 | return {index, false}; |
254 | } |
255 | |
256 | const CommonIndexKey key( |
257 | indexed_consumer_id, consumer_td, loop_domains, loop_index_map, loops); |
258 | |
259 | return tryInsertNewIndex(key, index); |
260 | } |
261 | |
262 | std::pair<Val*, bool> CommonIndexMap::tryInsertNewIndex( |
263 | CommonIndexKey key, |
264 | Val* index) { |
265 | Val* hoisted_index = nullptr; |
266 | bool new_index_inserted = false; |
267 | |
268 | // Hoisting is not possible if any of used loops is grouped. |
269 | if (std::any_of( |
270 | key.usedLoops().begin(), key.usedLoops().end(), [](const auto loop) { |
271 | return loop->iter_domain()->getParallelType() == |
272 | ParallelType::Group; |
273 | })) { |
274 | return {index, false}; |
275 | } |
276 | |
277 | // If already mapped, return the previously mapped index |
278 | auto it = common_index_map_.find(key); |
279 | if (it != common_index_map_.end()) { |
280 | hoisted_index = it->second; |
281 | new_index_inserted = false; |
282 | ++use_counts_.at(key); |
283 | } else { |
284 | common_index_map_.emplace(key, index); |
285 | hoisted_index = index; |
286 | new_index_inserted = true; |
287 | use_counts_[key] = 1; |
288 | } |
289 | return {hoisted_index, new_index_inserted}; |
290 | } |
291 | |
292 | namespace { |
293 | |
//! Insertion point of the allocation of a hoisted index: the
//! expression to insert before (ref) and the scope containing it.
//! A null scope means the insertion happens at the top-level scope.
struct CommonIndexInsertionInfo {
  Expr* ref = nullptr;
  kir::Scope* scope = nullptr;
};
299 | |
// Inserts allocations of hoisted indices. For every common index used
// more than once, an Allocate node and the index's defining expression
// are inserted at the outermost scope where the index is valid.
class CommonIndexInserter : private kir::ExprMutator {
 public:
  //! Entry point. Mutates a copy of exprs, inserting allocations and
  //! definitions of multiply-used common indices, and returns the
  //! resulting expression list.
  static std::vector<Expr*> run(
      const std::vector<Expr*>& exprs,
      const CommonIndexMap& common_indices) {
    CommonIndexInserter inserter(exprs, common_indices);
    return inserter.exprs_;
  }

 private:
  CommonIndexInserter(
      const std::vector<Expr*>& exprs,
      const CommonIndexMap& common_index_map)
      : common_index_map_(common_index_map) {
    // Create a map to keys from loops where they should be inserted
    for (const auto& kv : common_index_map.commonIndexMap()) {
      const auto& key = kv.first;
      // Only consider indices used multiple times
      if (!usedMultipleTimes(key)) {
        continue;
      }
      TORCH_INTERNAL_ASSERT(!key.usedLoops().empty());
      // Insertion is triggered when visiting the innermost used loop
      auto insertion_loop = key.usedLoops().back();
      innermost_used_loop_map_[insertion_loop].push_back(key);
    }

    traverseAndInsert(exprs);
  }

  //! Pick where the allocation and definition of a hoisted index
  //! should be placed for the given key. Returns the expression to
  //! insert before and its scope (nullptr scope = top-level scope).
  CommonIndexInsertionInfo findInsertionPoint(
      const CommonIndexKey& key,
      kir::ForLoop* current_loop) const {
    CommonIndexInsertionInfo info;

    // Allocation must be inside any used non-trivial loop. Since the
    // loop index value is constant if a loop is trivial, allocation
    // does not need to be inside trivial loops.
    // The last non-trivial used loop wins, i.e., the innermost one.
    for (const auto loop : key.usedLoops()) {
      if (!loop->isTrivial()) {
        info.ref = loop->body()[0];
        info.scope = &(loop->body());
      }
    }

    // If no non-trivial used loop is found, insert at the top-level
    // scope just before the outer-most loop.
    if (info.ref == nullptr) {
      info.ref = scope_exprs_.empty() ? current_loop : scope_exprs_.at(0);
      info.scope = nullptr;
    }

    return info;
  }

  using kir::ExprMutator::handle;

  void handle(kir::ForLoop* loop) final {
    // Only act when this loop is the innermost used loop of some key
    auto innermost_loop_map_it = innermost_used_loop_map_.find(loop);
    if (innermost_loop_map_it == innermost_used_loop_map_.end()) {
      kir::ExprMutator::handle(loop);
      return;
    }

    for (const auto& key : innermost_loop_map_it->second) {
      auto common_index = common_index_map_.commonIndexMap().at(key);

      // Insert only when the index is used multiple times and is not
      // yet inserted.
      if (inserted_indices_.find(common_index) != inserted_indices_.end()) {
        continue;
      }

      // Make the type of the hoisted index be the index type of the
      // kernel, which can be either int64_t or int. Not very clean,
      // but this seems to be the quickest way to use the index type
      // as we don't have a scalar IR node for the index type.
      common_index->resolveIndexDtype();

      auto alloc = IrBuilder::create<kir::Allocate>(
          common_index,
          MemoryType::Local,
          GpuLower::current()->kernel()->oneVal());
      const auto common_index_def = common_index->definition();
      TORCH_INTERNAL_ASSERT(
          common_index_def != nullptr,
          "Hosted index must have a definition. ",
          common_index->toString());

      // Place the allocation, then the defining expression, before the
      // chosen insertion point.
      const auto insertion_info = findInsertionPoint(key, loop);
      registerInsertBefore(insertion_info.ref, alloc, insertion_info.scope);
      registerInsertBefore(
          insertion_info.ref, common_index_def, insertion_info.scope);

      // Track inserted index
      inserted_indices_.emplace(common_index);
    }

    kir::ExprMutator::handle(loop);
  }

  //! True if the key's hoisted index is used more than once, i.e.,
  //! hoisting it actually saves recomputation.
  bool usedMultipleTimes(const CommonIndexKey& key) {
    auto it = common_index_map_.useCounts().find(key);
    TORCH_INTERNAL_ASSERT(
        it != common_index_map_.useCounts().end(),
        "Key not found in the use-count map: ",
        key.toString());
    return it->second > 1;
  }

 private:
  const CommonIndexMap& common_index_map_;
  //! Map to CommonIndexKeys from their innermost used loops
  std::unordered_map<kir::ForLoop*, std::vector<CommonIndexKey>>
      innermost_used_loop_map_;
  //! Keep track of inserted indices
  std::unordered_set<Val*> inserted_indices_;
};
418 | |
419 | } // namespace |
420 | |
421 | std::vector<Expr*> allocateCommonIndices(const std::vector<Expr*>& exprs) { |
422 | return CommonIndexInserter::run(exprs, GpuLower::current()->commonIndexMap()); |
423 | } |
424 | |
425 | } // namespace cuda |
426 | } // namespace fuser |
427 | } // namespace jit |
428 | } // namespace torch |
429 | |