#include <instrumentation.h>
#include <ir_iostream.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_allocation.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

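// Inserts allocation and initialization expressions into kernel IR.
// For each TensorView output of a tensor op, an appropriate
// kir::Allocate node (and, when an initialization value exists, an
// initialization expression wrapped in the required loop nest) is
// created and registered for insertion at the position computed by
// fillAllocationInformation. Information about local allocations is
// also recorded in GpuLower for later lowering passes.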
class AllocationInserter : public kir::ExprMutator {
 private:
  using kir::ExprMutator::handle;

  // Expanded version of BasicAllocInfo in lower_utils.h that tracks
  // additional information
  struct AllocationInformation {
    // The for loop that the initialization of this allocation must be
    // placed in, nullptr if not within a loop
    kir::ForLoop* init_for_loop = nullptr;

    // The expression that the initialization of this allocation must
    // be placed before
    Expr* init_place_before = nullptr;

    // Keep track of the actual allocation loop. This can be different
    // from init_for_loop only with unswitched shared memory allocations,
    // which are moved to outer loops to avoid duplicated allocations
    // (see issue #1133).
    kir::ForLoop* alloc_for_loop = nullptr;

    // The expression that this allocation must be placed
    // before. Similar to alloc_for_loop, this is different from
    // init_place_before only with unswitched shared memory allocations.
    Expr* alloc_place_before = nullptr;

    // The allocation position relative to the buffer
    size_t alloc_pos = 0;

    // The buffer this allocation is for
    TensorView* buffer = nullptr;

    // Info to transfer to GPU lower
    bool has_halo = false;

    // Local IterDomains that this allocation covers
    std::unique_ptr<std::vector<IterDomain*>> allocation_domains;
  };

  // Find the allocation point.
  // Given info.buffer, fills info.alloc_pos, info.init_for_loop,
  // info.init_place_before, info.alloc_for_loop, and info.alloc_place_before.
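  //
  // As a rough illustration (not generated code), for a buffer whose
  // allocation position corresponds to the outer of two loops around
  // expr, the computed fields end up pointing at:
  //
  //   for (i0 ...) {     // init_for_loop / alloc_for_loop
  //     <insert here>    // before init_place_before / alloc_place_before,
  //     for (i1 ...) {   //   which is this inner loop
  //       ... = expr(...);
  //     }
  //   }
  //
  // For a fully inlined buffer (init_for_loop is the innermost loop),
  // init_place_before is expr itself.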
  void fillAllocationInformation(AllocationInformation& info, Expr* expr) {
    auto loop_alloc_info =
        lower_utils::getAllocInformation(info.buffer, for_loops_);

    info.init_for_loop = loop_alloc_info.init_for_loop;
    info.alloc_for_loop = loop_alloc_info.alloc_for_loop;
    info.alloc_pos = loop_alloc_info.alloc_pos;

    auto next_fl = [](kir::ForLoop* fl,
                      const std::vector<kir::ForLoop*>& fls) {
      for (auto i : c10::irange(fls.size())) {
        if (fl == fls[i]) {
          if (i + 1 < fls.size()) {
            return fls[i + 1];
          }
        }
      }
      TORCH_INTERNAL_ASSERT(false, "Could not find desired loop.");
    };

    if (info.init_for_loop == nullptr) {
      info.init_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr;
    } else {
      if (info.init_for_loop == for_loops_.back()) {
        // Inline allocation, place before expr
        info.init_place_before = expr;
      } else {
        // Place allocation after the last computeAt axis
        // TODO: may be more efficient to place before the first non-computeAt
        // axis
        info.init_place_before = next_fl(info.init_for_loop, for_loops_);
      }
    }

    // Set the allocation loop and the place_before expression in the
    // same way as the initialization loop and place_before expression
    if (info.alloc_for_loop == info.init_for_loop) {
      info.alloc_for_loop = info.init_for_loop;
      info.alloc_place_before = info.init_place_before;
    } else {
      if (info.alloc_for_loop == nullptr) {
        info.alloc_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr;
      } else {
        // Since there must be an inner unswitched domain,
        // alloc_for_loop should never be the inner-most loop.
        TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops_.back());
        info.alloc_place_before = next_fl(info.alloc_for_loop, for_loops_);
      }
    }
  }

  // Create initialization expression if init_val is non-null.
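  //
  // The initialization is wrapped in a loop nest over the allocated
  // non-reduction, non-broadcast axes. As a rough sketch (not the
  // actual generated code), for two such axes id0 and id1:
  //
  //   for (i0 in id0) {
  //     for (i1 in id1) {
  //       buffer = init_val; // UnaryOpType::Set
  //     }
  //   }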
  Expr* createInitExpr(AllocationInformation& info, Val* init_val) {
    if (init_val == nullptr) {
      return nullptr;
    }

    std::vector<IterDomain*> init_dims;
    for (const auto axis_i :
         c10::irange(info.alloc_pos, info.buffer->nDims())) {
      if (info.buffer->axis(axis_i)->isReduction() ||
          info.buffer->axis(axis_i)->isBroadcast()) {
        continue;
      }
      auto concrete_id = gpu_lower->caMap()->getConcreteMappedID(
          info.buffer->axis(axis_i), IdMappingMode::LOOP);
      init_dims.push_back(concrete_id);
    }
    Expr* init_expr =
        IrBuilder::create<UnaryOp>(UnaryOpType::Set, info.buffer, init_val);
    for (auto init_loop_it = init_dims.rbegin();
         init_loop_it != init_dims.rend();
         ++init_loop_it) {
      auto id = *init_loop_it;
      kir::ForLoop* new_loop = nullptr;
      auto extent_with_halo = gpu_lower->haloInfo()->getExtent(id);
      if (extent_with_halo) {
        new_loop = IrBuilder::create<kir::ForLoop>(
            id,
            IrBuilder::create<Int>(c10::nullopt),
            nullptr,
            extent_with_halo,
            nullptr,
            false,
            nullptr,
            false,
            DoubleBufferLoopStage::NotApplicable);
      } else {
        new_loop = IrBuilder::create<kir::ForLoop>(id);
      }
      new_loop->body().push_back(init_expr);
      init_expr = new_loop;
    }
    return init_expr;
  }

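  // Compute the allocation extents of a global-memory buffer from its
  // root (or rfactor, if present) domain, skipping reduction, stride,
  // and broadcast axes and extending each extent by its halo width
  // where applicable.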
  std::vector<Val*> getGlobalAllocationSizes(AllocationInformation& info) {
    const auto& domain = info.buffer->domain();
    const auto& maybe_rfactor_domain = domain->hasRFactor()
        ? domain->getRFactorDomain()
        : domain->getRootDomain();

    std::vector<Val*> alloc_dims;

    for (const auto id : maybe_rfactor_domain) {
      if (id->isReduction() || id->isStride() || id->isBroadcast()) {
        continue;
      }
      auto extent = id->extent();
      // Use halo-extended extent if found
      auto halo_extent = gpu_lower->haloInfo()->getRootAxisInfo(id);
      if (halo_extent.hasHalo()) {
        extent = IrBuilder::addExpr(
            extent, IrBuilder::create<Int>(halo_extent.width()));
      }
      alloc_dims.push_back(extent);
    }

    return alloc_dims;
  }

  // Get allocation extents of root axes with halo
  //
  // Allocation can be done with leaf IDs with halo as well, but the
  // allocation size could be larger than necessary.
  //
  // For example, suppose the shift offset of an axis is 1. When it is
  // split by N, the halo size of the inner output is N+1. When the
  // allocation only has the inner split output, the allocation size
  // would be N+1. Suppose that ID is further split by M, the output
  // extents would be N/M and M+1. The allocation size based on the
  // leaves would be N/M*(M+1) or N+N/M, which is larger than N+1.
  //
  // This function tries to propagate halo information back to root
  // axes to avoid inflating allocations. It fails when merged domains
  // are split and only one of the split outputs is used for
  // allocations since in such a case we can't un-merge and properly
  // determine the extents of the merge inputs. Currently, that
  // results in an exception, but it may be more reasonable to simply
  // fall back to the leaf-based allocation.
  //
  // See the FusionShiftDoubleSplit test for an example case.
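  //
  // Concretely, with N = 8 and M = 2 in the example above, the
  // leaf-based size would be 8/2 * (2+1) = 12 elements, while the
  // root-based size recovered by this propagation is 8+1 = 9.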
  std::vector<Val*> getNonGlobalAllocExprWithHalo(
      TensorView* tv,
      const std::vector<IterDomain*>& alloc_domains) {
    std::vector<Val*> start_vals;
    std::transform(
        alloc_domains.begin(),
        alloc_domains.end(),
        std::back_inserter(start_vals),
        [](IterDomain* dom) { return dom->as<Val>(); });

    // Get all exprs involved in generating the allocation IDs
    auto exprs = StmtSort::getExprs(tv->fusion(), start_vals);

    // Get the halo extent if found
    auto getExtent = [this](IterDomain* id) {
      auto extent = gpu_lower->haloInfo()->getExtent(id);
      if (extent == nullptr) {
        extent = id->extent();
      }
      return extent;
    };

    std::unordered_map<IterDomain*, Val*> known_extents;

    // IterDomains that are allocated fully. For example, if an ID is
    // split and only one of the outputs is used for allocation, that's
    // not considered full. Only full domains can be unmerged, which is
    // needed to propagate the halo information back to root domains.
    std::unordered_set<IterDomain*> full_domains;

    for (auto alloc_domain : alloc_domains) {
      known_extents.insert({alloc_domain, getExtent(alloc_domain)});
      full_domains.insert(alloc_domain);
    }

    for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
      auto expr = *it;
      if (auto merge = dynamic_cast<Merge*>(expr)) {
        auto out_it = known_extents.find(merge->out());
        // If nothing is known about the out id, no propagation can be
        // done. Note that's not necessarily an error.
        if (out_it == known_extents.end()) {
          continue;
        }
        // Similarly, if the extent of the out id is not the full extent,
        // we can't un-merge it.
        if (full_domains.find(merge->out()) == full_domains.end()) {
          continue;
        }
        // Since the extent of the out id is full, the extent of each
        // of the input axes is also full
        known_extents.insert({merge->inner(), getExtent(merge->inner())});
        full_domains.insert(merge->inner());
        known_extents.insert({merge->outer(), getExtent(merge->outer())});
        full_domains.insert(merge->outer());
        known_extents.erase(out_it);
      } else if (auto split = dynamic_cast<Split*>(expr)) {
        auto inner = split->inner();
        const auto inner_it = known_extents.find(inner);
        auto outer = split->outer();
        const auto outer_it = known_extents.find(outer);
        if (inner_it != known_extents.end() &&
            outer_it != known_extents.end()) {
          if (full_domains.find(inner) != full_domains.end() &&
              full_domains.find(outer) != full_domains.end()) {
            known_extents.insert({split->in(), getExtent(split->in())});
            full_domains.insert(split->in());
          } else {
            known_extents.insert(
                {split->in(),
                 IrBuilder::mulExpr(outer_it->second, inner_it->second)});
          }
          known_extents.erase(inner_it);
          known_extents.erase(outer_it);
        } else if (inner_it != known_extents.end()) {
          known_extents.insert({split->in(), inner_it->second});
          known_extents.erase(inner_it);
        } else if (outer_it != known_extents.end()) {
          known_extents.insert({split->in(), outer_it->second});
          known_extents.erase(outer_it);
        }
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unexpected expr: ", expr);
      }
    }

    std::vector<Val*> alloc_dims;

    for (auto root_axis : tv->getRootDomain()) {
      auto it = known_extents.find(root_axis);
      if (it == known_extents.end()) {
        continue;
      }
      alloc_dims.push_back(it->second);
      known_extents.erase(it);
    }

    // known_extents should contain only mappings for root axes, so if
    // anything remains in the map after the loop above, it's an error
    if (!known_extents.empty()) {
      std::stringstream ss;
      for (auto kv : known_extents) {
        ss << kv.first << " ";
      }
      TORCH_INTERNAL_ASSERT(
          false, "Non-root axes found for TV", tv->name(), ": ", ss.str());
    }

    return alloc_dims;
  }

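  // Compute the allocation extents of a shared- or local-memory buffer.
  // Axes outside the allocation position are included only when the
  // buffer is shared with respect to them; axes whose parallelization
  // is already covered by the memory type (BIDs for shared memory,
  // TIDs and BIDs for local memory) are skipped. If any allocated axis
  // has halo, the extents are recomputed from the root domain via
  // getNonGlobalAllocExprWithHalo.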
  std::vector<Val*> getNonGlobalAllocExpr(AllocationInformation& info) {
    const auto memory_type = info.buffer->getMemoryType();
    TORCH_INTERNAL_ASSERT(
        memory_type != MemoryType::Global,
        "Invalid memory type: ",
        memory_type);

    std::vector<Val*> alloc_dims;

    bool has_halo = false;
    std::vector<IterDomain*> alloc_domains;

    info.allocation_domains = std::make_unique<std::vector<IterDomain*>>();

    for (const auto axis_i : c10::irange(info.buffer->nDims())) {
      const auto local_id = info.buffer->axis(axis_i);

      // Don't use reduction/stride/broadcast axes in the allocation
      // computation
      if (local_id->isReduction() || local_id->isStride() ||
          local_id->isBroadcast()) {
        continue;
      }

      auto concrete_id = gpu_lower->caMap()->getConcreteMappedID(
          info.buffer->axis(axis_i), IdMappingMode::LOOP);
      const bool is_block_dim =
          isParallelTypeBlockDim(concrete_id->getParallelType());
      const bool is_thread_dim =
          isParallelTypeThreadDim(concrete_id->getParallelType());
      const bool is_thread =
          isParallelTypeThread(concrete_id->getParallelType());

      if (axis_i < info.alloc_pos) {
        // Even when the axis is outside the allocation position, if the
        // tensor is shared with respect to the axis, the buffer size
        // needs to be expanded for the axis. Sharing occurs in two
        // cases: 1) the tensor is on shared memory with the axis
        // parallelized by TIDs, and 2) the tensor is on global memory
        // with the axis parallelized by TIDs or BIDs.
        if (!((memory_type == MemoryType::Shared && is_thread_dim) ||
              (memory_type == MemoryType::Global && is_thread))) {
          continue;
        }
        alloc_domains.push_back(info.buffer->axis(axis_i));
      } else {
        if (
            // If shared memory, don't use any IDs bound to a grid dimension
            (memory_type == MemoryType::Shared && is_block_dim) ||
            // If local memory, don't use any IDs bound to a grid or block
            // dimension
            (memory_type == MemoryType::Local && is_thread)) {
          continue;
        }
        alloc_domains.push_back(info.buffer->axis(axis_i));
      }

      auto extent = concrete_id->extent();

      if (gpu_lower->haloInfo()->getExtent(info.buffer->axis(axis_i)) !=
          nullptr) {
        has_halo = true;
      }

      alloc_dims.push_back(extent);
      info.allocation_domains->push_back(local_id);
    }

    // When an axis with halo extension is detected, propagate back
    // the halo extents from leaf IDs to root IDs
    if (has_halo) {
      info.has_halo = true;
      return getNonGlobalAllocExprWithHalo(info.buffer, alloc_domains);
    }

    return alloc_dims;
  }

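  // Build the kir::Allocate node for the buffer, or return nullptr for
  // fusion outputs, which do not need an allocation here. The extents
  // come from getGlobalAllocationSizes or getNonGlobalAllocExpr
  // depending on the memory type, with an extra trailing dimension
  // appended for double/circular buffering.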
  kir::Allocate* createAllocExpr(AllocationInformation& info, bool is_output) {
    if (is_output) {
      return nullptr;
    }

    std::vector<Val*> alloc_dims;
    const MemoryType memory_type = info.buffer->getMemoryType();

    if (memory_type == MemoryType::Global) {
      alloc_dims = getGlobalAllocationSizes(info);
    } else {
      alloc_dims = getNonGlobalAllocExpr(info);
    }

    if (alloc_dims.size() == 0 &&
        info.buffer->domain()->noReductions().size() != 0) {
      alloc_dims.push_back(info.buffer->container()->oneVal());
    }

    // Multiply the allocation size by the number of buffering stages
    // if double- or circular-buffered. Record the original size for
    // indexing.
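    //
    // For example, a double-buffered buffer with extents {E0, E1} is
    // allocated as {E0, E1, 2}, and a circular-buffered one of depth D
    // as {E0, E1, D}; the product E0 * E1 is what gets recorded as the
    // original size.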
    if (info.buffer->isDoubleBuffered() || info.buffer->isCircularBuffered()) {
      Val* original_alloc_size = nullptr;
      for (auto alloc_dim : alloc_dims) {
        if (original_alloc_size == nullptr) {
          original_alloc_size = alloc_dim;
        } else {
          original_alloc_size =
              IrBuilder::mulExpr(original_alloc_size, alloc_dim);
        }
      }
      GpuLower::current()->doubleBufferInfo().setOriginalAllocSize(
          info.buffer, original_alloc_size);
      int double_buffer_stage = 2;
      if (info.buffer->isCircularBuffered()) {
        double_buffer_stage = info.buffer->circularBufferDepth();
      }
      alloc_dims.push_back(IrBuilder::create<Int>(double_buffer_stage));
    }

    // Create the allocation node
    return IrBuilder::create<kir::Allocate>(
        info.buffer, info.buffer->getMemoryType(), alloc_dims);
  }

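  // For each TensorView output of a tensor op, determine its
  // initialization value (if any), build the allocation and
  // initialization expressions, record the allocation info in GpuLower,
  // and register the new expressions for insertion at the positions
  // computed by fillAllocationInformation.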
  void handle(Expr* expr) override {
    if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
      ExprMutator::handle(expr);
      return;
    }

    // Found where the allocation needs to be inserted

    for (const auto i : c10::irange(expr->outputs().size())) {
      auto out = expr->output(i);
      if (!out->isA<TensorView>()) {
        continue;
      }

      auto out_tv = out->as<TensorView>();
      auto default_val =
          gpu_lower->predicateElimination().getInitValue(out_tv);

      Val* init = nullptr;
      if (expr->isA<ReductionOp>() && out_tv->hasReduction()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Reduction should not have a default initialization value for predicate elimination.");
        init = expr->as<ReductionOp>()->init();
      } else if (expr->isA<GroupedReductionOp>() && out_tv->hasReduction()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Reduction should not have a default initialization value for predicate elimination.");
        init = expr->as<GroupedReductionOp>()->initVal(i);
      } else if (expr->isA<MmaOp>()) {
        init = expr->as<MmaOp>()->init();
      } else if (expr->isA<WelfordOp>()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Welford should not have a default initialization value for predicate elimination.");
        const auto welford = expr->as<WelfordOp>();
        if (out->name() == welford->outVar()->name()) {
          init = welford->initVar() == nullptr ? IrBuilder::create<Double>(0)
                                               : welford->initVar();
        } else if (out->name() == welford->outAvg()->name()) {
          init = welford->initAvg() == nullptr ? IrBuilder::create<Double>(0)
                                               : welford->initAvg();
        } else {
          TORCH_INTERNAL_ASSERT(
              out->name() == welford->outN()->name(), "Unreachable");
          init = welford->initN();
        }
      } else if (expr->isA<GroupedWelfordOp>()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Welford should not have a default initialization value for predicate elimination.");
        init = expr->as<GroupedWelfordOp>()->getInitValOfOutput(out);
      } else if (default_val != nullptr) {
        init = default_val;
      }

      const bool is_output = out->isFusionOutput();

      // Don't need to alloc outputs, and if we don't need to initialize we're
      // done.
      if (is_output && init == nullptr) {
        continue;
      }

      AllocationInformation allocation;
      allocation.buffer = out_tv;
      fillAllocationInformation(allocation, expr);

      auto alloc_expr = createAllocExpr(allocation, is_output);
      auto init_expr = createInitExpr(allocation, init);

      // Write information to GPULower
      writeInfoToGPULower(allocation, alloc_expr);

      // Register allocations before initializations to keep them in the right
      // order
      if (alloc_expr != nullptr) {
        if (allocation.buffer->getMemoryType() == MemoryType::Shared) {
          // Shared allocations go at the beginning of the scope
          TORCH_INTERNAL_ASSERT(!exprs_.empty());
          registerInsertBefore(exprs_[0], alloc_expr, nullptr);
        } else {
          TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr);
          kir::Scope* scope = allocation.alloc_for_loop == nullptr
              ? nullptr
              : &allocation.alloc_for_loop->body();
          registerInsertBefore(
              allocation.alloc_place_before, alloc_expr, scope);
        }
      }

      if (init_expr != nullptr) {
        TORCH_INTERNAL_ASSERT(allocation.init_place_before != nullptr);
        kir::Scope* scope = allocation.init_for_loop == nullptr
            ? nullptr
            : &allocation.init_for_loop->body();
        registerInsertBefore(allocation.init_place_before, init_expr, scope);
      }
    }
  }

  // Sends alloc_expr, allocation.has_halo, and
  // allocation.allocation_domains to GpuLower
  void writeInfoToGPULower(
      const AllocationInformation& allocation,
      kir::Allocate* alloc_expr) {
    auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap();
    if (alloc_expr == nullptr) {
      // Skip output allocation.
      return;
    }
    TORCH_INTERNAL_ASSERT(
        !lower_alloc_info_map.count(alloc_expr),
        "duplicated allocation info entry");

    // Create info entry for GPULower
    auto lower_alloc_info_ptr = std::make_unique<LocalAllocationInfo>();
    lower_alloc_info_ptr->alloc_expr = alloc_expr;
    lower_alloc_info_ptr->has_halo = allocation.has_halo;
    if (allocation.allocation_domains) {
      lower_alloc_info_ptr->alloc_domains = *(allocation.allocation_domains);
    }

    // Write entry to the stored map
    lower_alloc_info_map[alloc_expr] = std::move(lower_alloc_info_ptr);
  }

  void handle(kir::IfThenElse*) final {
    TORCH_INTERNAL_ASSERT(
        false,
        "Pass does not support conditional statements, ",
        "this pass should be run before any conditionals are placed in code.");
  }

  AllocationInserter(const std::vector<Expr*>& exprs)
      : gpu_lower(GpuLower::current()) {
    kir::ExprMutator::traverseAndInsert(exprs);
  }

 private:
  GpuLower* gpu_lower;

 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
    AllocationInserter inserter(exprs);
    return inserter.exprs_;
  }
};

} // namespace

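// Entry point of the allocation insertion pass: returns the given
// expressions with kir::Allocate nodes (and any needed initialization
// expressions) inserted for tensor outputs.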
std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations");
  return AllocationInserter::insert(exprs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch