#include <lower_misaligned_vectorization.h>

#include <index_compute.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <predicate_compute.h>

#include <algorithm>
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | namespace { |
19 | |
20 | class MisalignedVectorizationModifier : public kir::ExprMutator { |
21 | public: |
22 | MisalignedVectorizationModifier() = delete; |
23 | |
24 | static std::vector<Expr*> processMisalignedVectorization( |
25 | const std::vector<Expr*>& exprs) { |
26 | FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization" ); |
27 | MisalignedVectorizationModifier mvm(exprs); |
28 | return mvm.exprs_; |
29 | } |
30 | |
31 | private: |
32 | MisalignedVectorizationModifier(const std::vector<Expr*>& exprs) { |
33 | FUSER_PERF_SCOPE("GpuLower::Lower::MisalignedVectorizationModifier" ); |
34 | // Run through loop nests |
35 | // Find for-loops with misaligned vectorization domains |
36 | kir::ExprMutator::traverseAndInsert(exprs); |
37 | } |
38 | |
39 | void handle(kir::ForLoop* fl) final { |
40 | kir::Scope* scope = scope_.empty() ? nullptr : scope_.back(); |
41 | if (containsAnyDirectChildMisalignedVectorize(fl)) { |
42 | for_loops_.push_back(fl); |
43 | auto new_fl = handleMisalignedVectorize(for_loops_, fl); |
44 | for_loops_.pop_back(); |
45 | |
46 | kir::ExprMutator::registerReplace(fl, new_fl, scope); |
47 | } else { |
48 | kir::ExprMutator::handle(fl); |
49 | } |
50 | } |
51 | |
52 | struct ReferenceTensors { |
53 | // Input TensorView to Vectorize Set operation |
54 | TensorView* in_tv = nullptr; |
55 | // Output TensorView to Vectorize Set operation |
56 | TensorView* out_tv = nullptr; |
57 | // TensorView in global memory |
58 | TensorView* global_tv = nullptr; |
59 | // TensorView with vectorize IterDomain and not in global memory |
60 | TensorView* vec_tv = nullptr; |
61 | }; |
62 | |
63 | ReferenceTensors getReferenceTensors(Expr* vectorized_expr) { |
64 | TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr); |
65 | TORCH_INTERNAL_ASSERT( |
66 | vectorized_expr->outputs().front()->isA<TensorView>()); |
67 | TORCH_INTERNAL_ASSERT(vectorized_expr->inputs().front()->isA<TensorView>()); |
68 | |
69 | auto in_tv = vectorized_expr->inputs().front()->as<TensorView>(); |
70 | auto out_tv = vectorized_expr->outputs().front()->as<TensorView>(); |
71 | |
72 | const bool global_vectorize_write_op = |
73 | (out_tv->getMemoryType() == MemoryType::Global && |
74 | in_tv->getMemoryType() == MemoryType::Local); |
75 | const bool global_vectorize_read_op = |
76 | (out_tv->getMemoryType() == MemoryType::Local && |
77 | in_tv->getMemoryType() == MemoryType::Global); |
78 | TORCH_INTERNAL_ASSERT( |
79 | global_vectorize_write_op || global_vectorize_read_op, |
80 | "Unsupported vectorize memory configuration detected." ); |
81 | |
82 | // TensorView on global memory. This is the tensor that may have |
83 | // a non-aligned base address. |
84 | auto global_tv = |
85 | (out_tv->getMemoryType() == MemoryType::Global) ? out_tv : in_tv; |
86 | |
87 | // TensorView with the misaligned vec iterDomain. It is the consumer |
88 | // of vectorized load or the producer of vectorized store. It is |
89 | // assumed that when the output TV is not on global memory, this |
90 | // expression is a vectorized load, so the output TV is vec_tv. |
91 | auto vec_tv = |
92 | (out_tv->getMemoryType() != MemoryType::Global) ? out_tv : in_tv; |
93 | |
94 | return {in_tv, out_tv, global_tv, vec_tv}; |
95 | } |
96 | |
97 | struct VectorizeData { |
98 | Val* vector_size = nullptr; |
99 | Val* shift = nullptr; |
100 | Val* extent = nullptr; |
101 | Val* remainder = nullptr; |
102 | Val* extent_minus_remainder = nullptr; |
103 | Val* last_root_domain_index = nullptr; |
104 | Val* last_root_domain_index_shift = nullptr; |
105 | }; |
106 | |
107 | // Create constants for handling misaligned addresses |
108 | VectorizeData createVectorizeConstants( |
109 | const std::vector<kir::ForLoop*>& for_loop_structure, |
110 | const ReferenceTensors& tensors, |
111 | kir::IfThenElse* parent_scope_ite) { |
112 | // Generate vectorize index |
113 | auto indices = (tensors.out_tv->getMemoryType() == MemoryType::Global) |
114 | ? Index::getConsumerStridedIndices(tensors.out_tv, for_loop_structure) |
115 | : Index::getProducerStridedIndices( |
116 | tensors.in_tv, tensors.out_tv, for_loop_structure); |
117 | |
118 | // >>>>>>>>>>>>> |
119 | // Number of elements in vectorize access |
120 | auto vector_size = |
121 | tensors.vec_tv->domain()->domain().back()->extent()->as<Int>(); |
122 | |
123 | // Size of memory type for the elements |
124 | Int* data_size_in_bytes = |
125 | IrBuilder::create<Int>(dataTypeSize(tensors.vec_tv->dtype())); |
126 | |
127 | // The number of bytes in the vectorize access |
128 | auto vector_size_in_bytes = |
129 | IrBuilder::mulExpr(vector_size, data_size_in_bytes); |
130 | |
131 | auto index = |
132 | IrBuilder::create<kir::TensorIndex>(tensors.global_tv, indices); |
133 | auto address = createNamedScalarFromValue( |
134 | parent_scope_ite->thenBody(), index, "address" , true); |
135 | |
136 | // offset_size = (address % vector_size_bytes) / data_type_size_bytes |
137 | // shift_init = vector_size - offset_size |
138 | auto a = IrBuilder::modExpr(address, vector_size_in_bytes); |
139 | auto b = IrBuilder::divExpr(a, data_size_in_bytes); |
140 | auto c = IrBuilder::subExpr(vector_size, b); |
141 | auto shift_init = createNamedScalarFromValue( |
142 | parent_scope_ite->thenBody(), c, "shift_val" ); |
143 | |
144 | // shift = (shift_init == vector_size) ? 0 : shift_init |
145 | // The number of elements until the first aligned address |
146 | auto shift_pred = IrBuilder::eqExpr(shift_init, vector_size); |
147 | auto shift_val = IrBuilder::whereExpr( |
148 | shift_pred, GpuLower::current()->kernel()->zeroVal(), shift_init); |
149 | |
150 | // >>>>>>>>>>>>> |
151 | auto shift = createNamedScalarFromValue( |
152 | parent_scope_ite->thenBody(), shift_val, "shift" ); |
153 | |
154 | // >>>>>>>>>>>>> |
155 | // Get full extent for the inner-most, merged root domain |
156 | auto extent = getVectorizeExtent(tensors.in_tv, tensors.out_tv); |
157 | |
158 | // remainder = (extent - shift) % vector_size |
159 | // The number of elements remaining not accessed by vectorized operations |
160 | auto remaining_extent = IrBuilder::subExpr(extent, shift); |
161 | auto remainder_val = IrBuilder::modExpr(remaining_extent, vector_size); |
162 | auto remainder = createNamedScalarFromValue( |
163 | parent_scope_ite->thenBody(), remainder_val, "remainder" ); |
164 | |
165 | // (extent - remainder) is the upper-bound for the vectorize section |
166 | auto extent_remainder_val = IrBuilder::subExpr(extent, remainder); |
167 | |
168 | // >>>>>>>>>>>>> |
169 | auto extent_minus_remainder = createNamedScalarFromValue( |
170 | parent_scope_ite->thenBody(), |
171 | extent_remainder_val, |
172 | "extent_minus_remainder" ); |
173 | |
174 | // >>>>>>>>>>>>> |
175 | auto last_root_domain_index = createNamedScalarFromValue( |
176 | parent_scope_ite->thenBody(), indices.back(), "last_root_domain_index" ); |
177 | |
178 | // >>>>>>>>>>>>> |
179 | auto last_root_domain_index_shift = |
180 | IrBuilder::addExpr(last_root_domain_index, shift); |
181 | |
182 | return { |
183 | vector_size, |
184 | shift, |
185 | extent, |
186 | remainder, |
187 | extent_minus_remainder, |
188 | last_root_domain_index, |
189 | last_root_domain_index_shift}; |
190 | } |
191 | |
192 | // Vectorized : [shift - (extent-remainder)) |
193 | // From the first to the last aligned address |
194 | kir::IfThenElse* createVectorizeSection( |
195 | const std::vector<kir::ForLoop*>& child_loops, |
196 | const VectorizeData& params) { |
197 | auto vectorized_child_loops = cloneForLoops( |
198 | child_loops, params.vector_size, nullptr, true, params.shift); |
199 | |
200 | // Vectorize Range: [shift - (extent-remainder)) |
201 | // (last_root_domain_index + shift) < (extent - remainder) |
202 | Val* vectorize_cond = IrBuilder::ltExpr( |
203 | params.last_root_domain_index_shift, params.extent_minus_remainder); |
204 | |
205 | kir::Predicate* vectorize_pred = |
206 | IrBuilder::create<kir::Predicate>(vectorize_cond->as<Bool>()); |
207 | kir::IfThenElse* vectorize_ite = |
208 | IrBuilder::create<kir::IfThenElse>(vectorize_pred); |
209 | |
210 | for (auto cloned_loop : vectorized_child_loops) { |
211 | vectorize_ite->thenBody().push_back(cloned_loop); |
212 | } |
213 | |
214 | return vectorize_ite; |
215 | } |
216 | |
217 | // Initial : [0 - shift) |
218 | // From the initial address until the first aligned address |
219 | kir::IfThenElse* createInitialSection( |
220 | const std::vector<kir::ForLoop*>& child_loops, |
221 | const VectorizeData& params) { |
222 | auto pre_child_loops = cloneForLoops( |
223 | child_loops, params.vector_size, params.shift, false, nullptr); |
224 | |
225 | // Initial Range: [0 - shift) |
226 | // last_root_domain_index == 0 |
227 | Val* initial_cond = IrBuilder::eqExpr( |
228 | params.last_root_domain_index, |
229 | GpuLower::current()->kernel()->zeroVal()); |
230 | |
231 | kir::Predicate* initial_pred = |
232 | IrBuilder::create<kir::Predicate>(initial_cond->as<Bool>()); |
233 | kir::IfThenElse* initial_ite = |
234 | IrBuilder::create<kir::IfThenElse>(initial_pred); |
235 | |
236 | for (auto cloned_loop : pre_child_loops) { |
237 | initial_ite->thenBody().push_back(cloned_loop); |
238 | } |
239 | |
240 | return initial_ite; |
241 | } |
242 | |
243 | // Remainder : [(extent-remainder) - extent) |
244 | // From the last aligned address until the end of the extent |
245 | kir::IfThenElse* createRemainderSection( |
246 | const std::vector<kir::ForLoop*>& child_loops, |
247 | const VectorizeData& params) { |
248 | auto post_child_loops = cloneForLoops( |
249 | child_loops, params.vector_size, params.remainder, false, params.shift); |
250 | |
251 | // Remainder Range: [(extent-remainder) - extent) |
252 | // (extent - remainder) <= last_root_domain_index + shift < extent |
253 | Val* lower_bound = IrBuilder::geExpr( |
254 | params.last_root_domain_index_shift, params.extent_minus_remainder); |
255 | Val* upper_bound = |
256 | IrBuilder::ltExpr(params.last_root_domain_index_shift, params.extent); |
257 | Val* remainder_cond = IrBuilder::andExpr(lower_bound, upper_bound); |
258 | |
259 | kir::Predicate* remainder_pred = |
260 | IrBuilder::create<kir::Predicate>(remainder_cond->as<Bool>()); |
261 | kir::IfThenElse* remainder_ite = |
262 | IrBuilder::create<kir::IfThenElse>(remainder_pred); |
263 | |
264 | for (auto cloned_loop : post_child_loops) { |
265 | remainder_ite->thenBody().push_back(cloned_loop); |
266 | } |
267 | |
268 | return remainder_ite; |
269 | } |
270 | |
271 | kir::ForLoop* handleMisalignedVectorize( |
272 | std::vector<kir::ForLoop*> for_loop_structure, |
273 | const kir::ForLoop* parent_for_loop) { |
274 | auto child_loops = findChildForLoops(parent_for_loop); |
275 | |
276 | // Assumption: All vectorize operations have the same shift |
277 | auto vectorized_expr = |
278 | findFirstVectorizedSetOp(for_loop_structure, child_loops); |
279 | TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr); |
280 | |
281 | auto reference_tensors = getReferenceTensors(vectorized_expr); |
282 | |
283 | // The parent_for_loop contains allocate, read, compute, write operations |
284 | const auto new_parent_for_loop = |
285 | IrBuilder::create<kir::ForLoop>(parent_for_loop); |
286 | |
287 | // Transfer all expressions except for-loops to new parent for-loop |
288 | // All expressions are placed at the beginning of the new for-loop |
289 | copyExprsExceptForLoops(parent_for_loop, new_parent_for_loop); |
290 | |
291 | // Get the predicate for all but the last root domain |
292 | auto pred_except_last_root_domain = IrBuilder::create<kir::Predicate>( |
293 | PredicateType::Misaligned, |
294 | vectorized_expr, |
295 | GpuLower::current()->kernel()->trueVal()); |
296 | kir::IfThenElse* pred_ite = |
297 | IrBuilder::create<kir::IfThenElse>(pred_except_last_root_domain); |
298 | new_parent_for_loop->body().push_back(pred_ite); |
299 | |
300 | auto constants = createVectorizeConstants( |
301 | for_loop_structure, reference_tensors, pred_ite); |
302 | |
303 | // The last root domain is divided into three sections. |
304 | // | Initial - N/A Shift | Vectorize - Shift | Remainder - Shift | |
305 | |
306 | // Vectorized set operation with vectorize shift |
307 | auto vectorize_ite = createVectorizeSection(child_loops, constants); |
308 | pred_ite->thenBody().push_back(vectorize_ite); |
309 | |
310 | // Standard set operation without vectorize shift |
311 | auto initial_ite = createInitialSection(child_loops, constants); |
312 | pred_ite->thenBody().push_back(initial_ite); |
313 | |
314 | // Standard set operation with vectorize shift |
315 | auto remainder_ite = createRemainderSection(child_loops, constants); |
316 | pred_ite->thenBody().push_back(remainder_ite); |
317 | |
318 | return new_parent_for_loop; |
319 | } |
320 | |
321 | // Determine that the expression is UnaryOpType::Set AND |
322 | // the output TensorView domain is vectorized |
323 | bool isVectorizeSetOp(kir::ForLoop* fl, Expr* expr) { |
324 | if (fl->iter_domain()->getParallelType() != |
325 | ParallelType::MisalignedVectorize) { |
326 | return false; |
327 | } |
328 | |
329 | if (expr->isA<UnaryOp>()) { |
330 | auto unaryOp = expr->as<UnaryOp>(); |
331 | if (unaryOp->out()->isA<TensorView>()) { |
332 | auto out_tv = unaryOp->out()->as<TensorView>(); |
333 | return unaryOp->getUnaryOpType() == UnaryOpType::Set && |
334 | out_tv->domain()->hasVectorize(); |
335 | } |
336 | } |
337 | return false; |
338 | } |
339 | |
340 | // Clone each for loop |
341 | // loop_stop value - for (index = start; index < stop; index += step) |
342 | // pred_stop value - Predicate loop body as (index < pred_stop) if non null |
343 | // vectorize flag - Do not generate for loop header |
344 | // shift value - Add shift to global indices generated within for loop |
345 | std::vector<kir::ForLoop*> cloneForLoops( |
346 | const std::vector<kir::ForLoop*>& for_loops_, |
347 | Val* loop_stop, |
348 | Val* pred_stop, |
349 | bool vectorize, |
350 | Val* vectorize_shift) { |
351 | std::vector<kir::ForLoop*> cloned_for_loops; |
352 | |
353 | for (auto fl : for_loops_) { |
354 | auto first_expr = fl->body().exprs().front(); |
355 | bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); |
356 | |
357 | // If the for loop contains a vectorize Set operation, then |
358 | // it should only contain a single expression |
359 | TORCH_INTERNAL_ASSERT( |
360 | !has_vectorize_op || fl->body().exprs().size() == 1); |
361 | |
362 | const auto new_loop = IrBuilder::create<kir::ForLoop>( |
363 | fl->iter_domain(), |
364 | fl->index(), |
365 | GpuLower::current()->kernel()->zeroVal(), |
366 | loop_stop, |
367 | GpuLower::current()->kernel()->oneVal(), |
368 | vectorize && has_vectorize_op, |
369 | vectorize_shift, |
370 | fl->isUnrollRequired(), |
371 | fl->doubleBufferLoopStage()); |
372 | |
373 | auto body = &new_loop->body(); |
374 | |
375 | // Predicate the loop body if pred_stop is not null. This is to |
376 | // make sure the loop itself is completely unrollable. |
377 | if (pred_stop != nullptr) { |
378 | auto body_pred = IrBuilder::create<kir::Predicate>( |
379 | IrBuilder::ltExpr(new_loop->index(), pred_stop)->as<Bool>()); |
380 | auto body_ite = IrBuilder::create<kir::IfThenElse>(body_pred); |
381 | body->push_back(body_ite); |
382 | body = &body_ite->thenBody(); |
383 | } |
384 | |
385 | for (auto expr : fl->body().exprs()) { |
386 | body->push_back(expr); |
387 | } |
388 | |
389 | cloned_for_loops.push_back(new_loop); |
390 | } |
391 | return cloned_for_loops; |
392 | } |
393 | |
394 | // Add all expressions except for loops to new parent for loop |
395 | void copyExprsExceptForLoops( |
396 | const kir::ForLoop* for_loop, |
397 | kir::ForLoop* new_loop) { |
398 | std::vector<kir::ForLoop*> loops; |
399 | for (auto expr : for_loop->body().exprs()) { |
400 | if (!expr->isA<kir::ForLoop>()) { |
401 | new_loop->body().push_back(expr); |
402 | } |
403 | } |
404 | } |
405 | |
406 | // Find any child for loops inside parent for loop |
407 | std::vector<kir::ForLoop*> findChildForLoops(const kir::ForLoop* for_loop) { |
408 | std::vector<kir::ForLoop*> loops; |
409 | for (auto expr : for_loop->body().exprs()) { |
410 | if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
411 | loops.push_back(nested_for_loop); |
412 | } |
413 | } |
414 | return loops; |
415 | } |
416 | |
417 | // Find the first vectorize set - either read or write |
418 | // Add child For-Loop to for_loop_structure |
419 | // Enable vectorize flag in child For-Loop |
420 | Expr* findFirstVectorizedSetOp( |
421 | std::vector<kir::ForLoop*>& for_loop_structure, |
422 | const std::vector<kir::ForLoop*>& for_loops_) { |
423 | for (auto fl : for_loops_) { |
424 | auto first_expr = fl->body().exprs().front(); |
425 | bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); |
426 | if (has_vectorize_op) { |
427 | for_loop_structure.push_back(fl); |
428 | return first_expr; |
429 | } |
430 | } |
431 | return nullptr; |
432 | } |
433 | |
434 | // Get full extent for the inner-most, merged root domain |
435 | Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) { |
436 | const auto gpu_lower = GpuLower::current(); |
437 | |
438 | auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv) |
439 | .mapProducerToConsumer( |
440 | producer_tv->domain(), consumer_tv->domain()); |
441 | |
442 | auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo( |
443 | {consumer_tv->domain()->domain().begin() + |
444 | consumer_tv->getComputeAtPosition(), |
445 | consumer_tv->domain()->domain().end()}); |
446 | auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo( |
447 | {producer_tv->domain()->domain().begin() + |
448 | producer_tv->getComputeAtPosition(), |
449 | producer_tv->domain()->domain().end()}); |
450 | |
451 | const auto& consumer_contig = consumer_tv->domain()->contiguity(); |
452 | const auto& producer_contig = producer_tv->domain()->contiguity(); |
453 | |
454 | auto producer_root_domain = producer_tv->getMaybeRFactorDomain(); |
455 | |
456 | // Calculate extent of merged root domains |
457 | Val* extent = nullptr; |
458 | auto consumer_root_idx = |
459 | int(consumer_tv->getMaybeRFactorDomain().size()) - 1; |
460 | for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) { |
461 | auto producer_root_id = producer_root_domain.at(i); |
462 | |
463 | TORCH_INTERNAL_ASSERT( |
464 | !gpu_lower->trivialReductionInfo().isDerived(producer_root_id), |
465 | "No trivial reduction axis should exist: " , |
466 | producer_root_id); |
467 | |
468 | // If the producer ID is reduction or broadcast, it should be safe |
469 | // to ignore. |
470 | if (producer_root_id->isReduction()) { |
471 | continue; |
472 | } else if (producer_root_id->isBroadcast()) { |
473 | --consumer_root_idx; |
474 | continue; |
475 | } |
476 | |
477 | // There must be a matching consumer root ID as the producer ID is |
478 | // not reduction and the expression between them is UnaryOpType::Set. |
479 | auto it = p2c.find(producer_root_id); |
480 | TORCH_INTERNAL_ASSERT( |
481 | it != p2c.end(), "No matching consumer root ID found" ); |
482 | auto consumer_root_id = it->second; |
483 | |
484 | // Don't extend the vectorization domain beyond the CA position |
485 | if (std::find( |
486 | consumer_root_right_of_ca_domains.begin(), |
487 | consumer_root_right_of_ca_domains.end(), |
488 | consumer_root_id) == consumer_root_right_of_ca_domains.end() || |
489 | std::find( |
490 | producer_root_right_of_ca_domains.begin(), |
491 | producer_root_right_of_ca_domains.end(), |
492 | producer_root_id) == producer_root_right_of_ca_domains.end()) { |
493 | break; |
494 | } |
495 | |
496 | // We now know it's safe to extend the vectorization domain to these |
497 | // axes. It shouldn't matter whether producer or consumer is used. |
498 | if (extent == nullptr) { |
499 | extent = consumer_root_id->extent(); |
500 | } else { |
501 | extent = IrBuilder::mulExpr(extent, consumer_root_id->extent()); |
502 | } |
503 | |
504 | // If it's not contiguous, extending the vectorization domain |
505 | // further is not possible |
506 | if (!(producer_contig.at(i) && consumer_contig.at(consumer_root_idx))) { |
507 | break; |
508 | } |
509 | |
510 | --consumer_root_idx; |
511 | } |
512 | |
513 | TORCH_INTERNAL_ASSERT(extent != nullptr); |
514 | |
515 | return extent; |
516 | } |
517 | |
518 | Val* createNamedScalarFromValue( |
519 | kir::Scope& body, |
520 | Val* val, |
521 | const std::string& name, |
522 | bool address = false) { |
523 | auto namedScalar = (address) ? IrBuilder::addressExprNamedScalar(name, val) |
524 | : IrBuilder::setExprNamedScalar(name, val); |
525 | TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr); |
526 | |
527 | auto alloc = IrBuilder::create<kir::Allocate>( |
528 | namedScalar, |
529 | MemoryType::Local, |
530 | GpuLower::current()->kernel()->oneVal()); |
531 | body.push_back(alloc); |
532 | body.push_back(namedScalar->definition()); |
533 | return namedScalar; |
534 | } |
535 | }; |
536 | |
537 | } // namespace |
538 | |
539 | std::vector<Expr*> processMisalignedVectorization( |
540 | const std::vector<Expr*>& exprs) { |
541 | return MisalignedVectorizationModifier::processMisalignedVectorization(exprs); |
542 | } |
543 | |
544 | bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) { |
545 | for (auto expr : fl->body().exprs()) { |
546 | if (expr->isA<kir::ForLoop>()) { |
547 | auto child_fl = expr->as<kir::ForLoop>(); |
548 | if (child_fl->iter_domain()->getParallelType() == |
549 | ParallelType::MisalignedVectorize) { |
550 | return true; |
551 | } |
552 | } |
553 | } |
554 | return false; |
555 | } |
556 | |
557 | } // namespace cuda |
558 | } // namespace fuser |
559 | } // namespace jit |
560 | } // namespace torch |
561 | |