#include <lower_misaligned_vectorization.h>

#include <index_compute.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <predicate_compute.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

class MisalignedVectorizationModifier : public kir::ExprMutator {
 public:
  MisalignedVectorizationModifier() = delete;

  static std::vector<Expr*> processMisalignedVectorization(
      const std::vector<Expr*>& exprs) {
    FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization");
    MisalignedVectorizationModifier mvm(exprs);
    return mvm.exprs_;
  }

 private:
  MisalignedVectorizationModifier(const std::vector<Expr*>& exprs) {
    FUSER_PERF_SCOPE("GpuLower::Lower::MisalignedVectorizationModifier");
    // Run through loop nests
    // Find for-loops with misaligned vectorization domains
    kir::ExprMutator::traverseAndInsert(exprs);
  }

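  // For-loops that directly contain a misaligned-vectorize child loop are
  // rewritten as a whole; all other loops are traversed recursively.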
  void handle(kir::ForLoop* fl) final {
    kir::Scope* scope = scope_.empty() ? nullptr : scope_.back();
    if (containsAnyDirectChildMisalignedVectorize(fl)) {
      for_loops_.push_back(fl);
      auto new_fl = handleMisalignedVectorize(for_loops_, fl);
      for_loops_.pop_back();

      kir::ExprMutator::registerReplace(fl, new_fl, scope);
    } else {
      kir::ExprMutator::handle(fl);
    }
  }

  struct ReferenceTensors {
    // Input TensorView of the vectorized set operation
    TensorView* in_tv = nullptr;
    // Output TensorView of the vectorized set operation
    TensorView* out_tv = nullptr;
    // TensorView in global memory
    TensorView* global_tv = nullptr;
    // TensorView with the vectorized IterDomain that is not in global memory
    TensorView* vec_tv = nullptr;
  };

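  // Collect the tensors involved in a vectorized set expression. The
  // expression must copy between global and local memory; global_tv may have
  // a misaligned base address, and vec_tv carries the vectorized IterDomain.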
  ReferenceTensors getReferenceTensors(Expr* vectorized_expr) {
    TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);
    TORCH_INTERNAL_ASSERT(
        vectorized_expr->outputs().front()->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(vectorized_expr->inputs().front()->isA<TensorView>());

    auto in_tv = vectorized_expr->inputs().front()->as<TensorView>();
    auto out_tv = vectorized_expr->outputs().front()->as<TensorView>();

    const bool global_vectorize_write_op =
        (out_tv->getMemoryType() == MemoryType::Global &&
         in_tv->getMemoryType() == MemoryType::Local);
    const bool global_vectorize_read_op =
        (out_tv->getMemoryType() == MemoryType::Local &&
         in_tv->getMemoryType() == MemoryType::Global);
    TORCH_INTERNAL_ASSERT(
        global_vectorize_write_op || global_vectorize_read_op,
        "Unsupported vectorize memory configuration detected.");

    // TensorView on global memory. This is the tensor that may have
    // a non-aligned base address.
    auto global_tv =
        (out_tv->getMemoryType() == MemoryType::Global) ? out_tv : in_tv;

    // TensorView with the misaligned vectorized IterDomain. It is the consumer
    // of a vectorized load or the producer of a vectorized store. It is
    // assumed that when the output TV is not in global memory, this
    // expression is a vectorized load, so the output TV is vec_tv.
    auto vec_tv =
        (out_tv->getMemoryType() != MemoryType::Global) ? out_tv : in_tv;

    return {in_tv, out_tv, global_tv, vec_tv};
  }

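  // Runtime scalars computed once per misaligned vectorized loop nest. They
  // describe how the innermost root domain is split into initial, vectorized,
  // and remainder sections.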
  struct VectorizeData {
    Val* vector_size = nullptr;
    Val* shift = nullptr;
    Val* extent = nullptr;
    Val* remainder = nullptr;
    Val* extent_minus_remainder = nullptr;
    Val* last_root_domain_index = nullptr;
    Val* last_root_domain_index_shift = nullptr;
  };

  // Create constants for handling misaligned addresses
  VectorizeData createVectorizeConstants(
      const std::vector<kir::ForLoop*>& for_loop_structure,
      const ReferenceTensors& tensors,
      kir::IfThenElse* parent_scope_ite) {
    // Generate vectorize index
    auto indices = (tensors.out_tv->getMemoryType() == MemoryType::Global)
        ? Index::getConsumerStridedIndices(tensors.out_tv, for_loop_structure)
        : Index::getProducerStridedIndices(
              tensors.in_tv, tensors.out_tv, for_loop_structure);

    // >>>>>>>>>>>>>
    // Number of elements in the vectorized access
    auto vector_size =
        tensors.vec_tv->domain()->domain().back()->extent()->as<Int>();

    // Size in bytes of the data type of each element
    Int* data_size_in_bytes =
        IrBuilder::create<Int>(dataTypeSize(tensors.vec_tv->dtype()));

    // The number of bytes in the vectorized access
    auto vector_size_in_bytes =
        IrBuilder::mulExpr(vector_size, data_size_in_bytes);

    auto index =
        IrBuilder::create<kir::TensorIndex>(tensors.global_tv, indices);
    auto address = createNamedScalarFromValue(
        parent_scope_ite->thenBody(), index, "address", true);

    // offset_size = (address % vector_size_bytes) / data_type_size_bytes
    // shift_init = vector_size - offset_size
    auto a = IrBuilder::modExpr(address, vector_size_in_bytes);
    auto b = IrBuilder::divExpr(a, data_size_in_bytes);
    auto c = IrBuilder::subExpr(vector_size, b);
    auto shift_init = createNamedScalarFromValue(
        parent_scope_ite->thenBody(), c, "shift_val");

    // shift = (shift_init == vector_size) ? 0 : shift_init
    // The number of elements until the first aligned address
    auto shift_pred = IrBuilder::eqExpr(shift_init, vector_size);
    auto shift_val = IrBuilder::whereExpr(
        shift_pred, GpuLower::current()->kernel()->zeroVal(), shift_init);

    // >>>>>>>>>>>>>
    auto shift = createNamedScalarFromValue(
        parent_scope_ite->thenBody(), shift_val, "shift");

    // >>>>>>>>>>>>>
    // Get full extent for the inner-most, merged root domain
    auto extent = getVectorizeExtent(tensors.in_tv, tensors.out_tv);

    // remainder = (extent - shift) % vector_size
    // The number of trailing elements not covered by vectorized accesses
    auto remaining_extent = IrBuilder::subExpr(extent, shift);
    auto remainder_val = IrBuilder::modExpr(remaining_extent, vector_size);
    auto remainder = createNamedScalarFromValue(
        parent_scope_ite->thenBody(), remainder_val, "remainder");

    // (extent - remainder) is the upper bound for the vectorized section
    auto extent_remainder_val = IrBuilder::subExpr(extent, remainder);

    // >>>>>>>>>>>>>
    auto extent_minus_remainder = createNamedScalarFromValue(
        parent_scope_ite->thenBody(),
        extent_remainder_val,
        "extent_minus_remainder");

    // >>>>>>>>>>>>>
    auto last_root_domain_index = createNamedScalarFromValue(
        parent_scope_ite->thenBody(), indices.back(), "last_root_domain_index");

    // >>>>>>>>>>>>>
    auto last_root_domain_index_shift =
        IrBuilder::addExpr(last_root_domain_index, shift);

    return {
        vector_size,
        shift,
        extent,
        remainder,
        extent_minus_remainder,
        last_root_domain_index,
        last_root_domain_index_shift};
  }

  // Vectorized section: [shift, extent - remainder)
  // From the first to the last aligned address
  kir::IfThenElse* createVectorizeSection(
      const std::vector<kir::ForLoop*>& child_loops,
      const VectorizeData& params) {
    auto vectorized_child_loops = cloneForLoops(
        child_loops, params.vector_size, nullptr, true, params.shift);

    // Vectorize Range: [shift, extent - remainder)
    // (last_root_domain_index + shift) < (extent - remainder)
    Val* vectorize_cond = IrBuilder::ltExpr(
        params.last_root_domain_index_shift, params.extent_minus_remainder);

    kir::Predicate* vectorize_pred =
        IrBuilder::create<kir::Predicate>(vectorize_cond->as<Bool>());
    kir::IfThenElse* vectorize_ite =
        IrBuilder::create<kir::IfThenElse>(vectorize_pred);

    for (auto cloned_loop : vectorized_child_loops) {
      vectorize_ite->thenBody().push_back(cloned_loop);
    }

    return vectorize_ite;
  }

  // Initial section: [0, shift)
  // From the initial address until the first aligned address
  kir::IfThenElse* createInitialSection(
      const std::vector<kir::ForLoop*>& child_loops,
      const VectorizeData& params) {
    auto pre_child_loops = cloneForLoops(
        child_loops, params.vector_size, params.shift, false, nullptr);

    // Initial Range: [0, shift)
    // last_root_domain_index == 0
    Val* initial_cond = IrBuilder::eqExpr(
        params.last_root_domain_index,
        GpuLower::current()->kernel()->zeroVal());

    kir::Predicate* initial_pred =
        IrBuilder::create<kir::Predicate>(initial_cond->as<Bool>());
    kir::IfThenElse* initial_ite =
        IrBuilder::create<kir::IfThenElse>(initial_pred);

    for (auto cloned_loop : pre_child_loops) {
      initial_ite->thenBody().push_back(cloned_loop);
    }

    return initial_ite;
  }

  // Remainder section: [extent - remainder, extent)
  // From the last aligned address until the end of the extent
  kir::IfThenElse* createRemainderSection(
      const std::vector<kir::ForLoop*>& child_loops,
      const VectorizeData& params) {
    auto post_child_loops = cloneForLoops(
        child_loops, params.vector_size, params.remainder, false, params.shift);

    // Remainder Range: [extent - remainder, extent)
    // (extent - remainder) <= last_root_domain_index + shift < extent
    Val* lower_bound = IrBuilder::geExpr(
        params.last_root_domain_index_shift, params.extent_minus_remainder);
    Val* upper_bound =
        IrBuilder::ltExpr(params.last_root_domain_index_shift, params.extent);
    Val* remainder_cond = IrBuilder::andExpr(lower_bound, upper_bound);

    kir::Predicate* remainder_pred =
        IrBuilder::create<kir::Predicate>(remainder_cond->as<Bool>());
    kir::IfThenElse* remainder_ite =
        IrBuilder::create<kir::IfThenElse>(remainder_pred);

    for (auto cloned_loop : post_child_loops) {
      remainder_ite->thenBody().push_back(cloned_loop);
    }

    return remainder_ite;
  }

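  // Rewrite a for-loop containing misaligned vectorized accesses. The new
  // parent loop computes the alignment shift for the global tensor once and
  // then runs the initial, vectorized, and remainder sections, each guarded
  // by its own predicate.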
  kir::ForLoop* handleMisalignedVectorize(
      std::vector<kir::ForLoop*> for_loop_structure,
      const kir::ForLoop* parent_for_loop) {
    auto child_loops = findChildForLoops(parent_for_loop);

    // Assumption: All vectorize operations have the same shift
    auto vectorized_expr =
        findFirstVectorizedSetOp(for_loop_structure, child_loops);
    TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);

    auto reference_tensors = getReferenceTensors(vectorized_expr);

    // The parent_for_loop contains allocate, read, compute, write operations
    const auto new_parent_for_loop =
        IrBuilder::create<kir::ForLoop>(parent_for_loop);

    // Transfer all expressions except for-loops to the new parent for-loop.
    // All expressions are placed at the beginning of the new for-loop.
    copyExprsExceptForLoops(parent_for_loop, new_parent_for_loop);

    // Get the predicate for all but the last root domain
    auto pred_except_last_root_domain = IrBuilder::create<kir::Predicate>(
        PredicateType::Misaligned,
        vectorized_expr,
        GpuLower::current()->kernel()->trueVal());
    kir::IfThenElse* pred_ite =
        IrBuilder::create<kir::IfThenElse>(pred_except_last_root_domain);
    new_parent_for_loop->body().push_back(pred_ite);

    auto constants = createVectorizeConstants(
        for_loop_structure, reference_tensors, pred_ite);

    // The last root domain is divided into three sections.
    // | Initial (no shift) | Vectorized (shift) | Remainder (shift) |

    // Vectorized set operation with vectorize shift
    auto vectorize_ite = createVectorizeSection(child_loops, constants);
    pred_ite->thenBody().push_back(vectorize_ite);

    // Standard set operation without vectorize shift
    auto initial_ite = createInitialSection(child_loops, constants);
    pred_ite->thenBody().push_back(initial_ite);

    // Standard set operation with vectorize shift
    auto remainder_ite = createRemainderSection(child_loops, constants);
    pred_ite->thenBody().push_back(remainder_ite);

    return new_parent_for_loop;
  }

  // Determine whether the expression is a UnaryOpType::Set operation AND
  // the output TensorView domain is vectorized
  bool isVectorizeSetOp(kir::ForLoop* fl, Expr* expr) {
    if (fl->iter_domain()->getParallelType() !=
        ParallelType::MisalignedVectorize) {
      return false;
    }

    if (expr->isA<UnaryOp>()) {
      auto unaryOp = expr->as<UnaryOp>();
      if (unaryOp->out()->isA<TensorView>()) {
        auto out_tv = unaryOp->out()->as<TensorView>();
        return unaryOp->getUnaryOpType() == UnaryOpType::Set &&
            out_tv->domain()->hasVectorize();
      }
    }
    return false;
  }

  // Clone each for-loop.
  // loop_stop value - for (index = start; index < stop; index += step)
  // pred_stop value - predicate the loop body as (index < pred_stop) if non-null
  // vectorize flag - do not generate the for-loop header
  // shift value - add shift to global indices generated within the for-loop
  std::vector<kir::ForLoop*> cloneForLoops(
      const std::vector<kir::ForLoop*>& for_loops_,
      Val* loop_stop,
      Val* pred_stop,
      bool vectorize,
      Val* vectorize_shift) {
    std::vector<kir::ForLoop*> cloned_for_loops;

    for (auto fl : for_loops_) {
      auto first_expr = fl->body().exprs().front();
      bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);

      // If the for-loop contains a vectorized set operation, then
      // it should only contain a single expression
      TORCH_INTERNAL_ASSERT(
          !has_vectorize_op || fl->body().exprs().size() == 1);

      const auto new_loop = IrBuilder::create<kir::ForLoop>(
          fl->iter_domain(),
          fl->index(),
          GpuLower::current()->kernel()->zeroVal(),
          loop_stop,
          GpuLower::current()->kernel()->oneVal(),
          vectorize && has_vectorize_op,
          vectorize_shift,
          fl->isUnrollRequired(),
          fl->doubleBufferLoopStage());

      auto body = &new_loop->body();

      // Predicate the loop body if pred_stop is not null. This is to
      // make sure the loop itself is completely unrollable.
      if (pred_stop != nullptr) {
        auto body_pred = IrBuilder::create<kir::Predicate>(
            IrBuilder::ltExpr(new_loop->index(), pred_stop)->as<Bool>());
        auto body_ite = IrBuilder::create<kir::IfThenElse>(body_pred);
        body->push_back(body_ite);
        body = &body_ite->thenBody();
      }

      for (auto expr : fl->body().exprs()) {
        body->push_back(expr);
      }

      cloned_for_loops.push_back(new_loop);
    }
    return cloned_for_loops;
  }

  // Add all expressions except for-loops to the new parent for-loop
  void copyExprsExceptForLoops(
      const kir::ForLoop* for_loop,
      kir::ForLoop* new_loop) {
    for (auto expr : for_loop->body().exprs()) {
      if (!expr->isA<kir::ForLoop>()) {
        new_loop->body().push_back(expr);
      }
    }
  }

  // Find any child for-loops inside the parent for-loop
  std::vector<kir::ForLoop*> findChildForLoops(const kir::ForLoop* for_loop) {
    std::vector<kir::ForLoop*> loops;
    for (auto expr : for_loop->body().exprs()) {
      if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
        loops.push_back(nested_for_loop);
      }
    }
    return loops;
  }

  // Find the first vectorized set - either a read or a write.
  // Add the child for-loop to for_loop_structure.
  Expr* findFirstVectorizedSetOp(
      std::vector<kir::ForLoop*>& for_loop_structure,
      const std::vector<kir::ForLoop*>& for_loops_) {
    for (auto fl : for_loops_) {
      auto first_expr = fl->body().exprs().front();
      bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);
      if (has_vectorize_op) {
        for_loop_structure.push_back(fl);
        return first_expr;
      }
    }
    return nullptr;
  }

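  // Walks the root domains of the producer and consumer from the innermost
  // dimension outward, multiplying extents together while both domains remain
  // contiguous and to the right of the compute-at position.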
  // Get full extent for the inner-most, merged root domain
  Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) {
    const auto gpu_lower = GpuLower::current();

    auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv)
                   .mapProducerToConsumer(
                       producer_tv->domain(), consumer_tv->domain());

    auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo(
        {consumer_tv->domain()->domain().begin() +
             consumer_tv->getComputeAtPosition(),
         consumer_tv->domain()->domain().end()});
    auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo(
        {producer_tv->domain()->domain().begin() +
             producer_tv->getComputeAtPosition(),
         producer_tv->domain()->domain().end()});

    const auto& consumer_contig = consumer_tv->domain()->contiguity();
    const auto& producer_contig = producer_tv->domain()->contiguity();

    auto producer_root_domain = producer_tv->getMaybeRFactorDomain();

    // Calculate extent of merged root domains
    Val* extent = nullptr;
    auto consumer_root_idx =
        int(consumer_tv->getMaybeRFactorDomain().size()) - 1;
    for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) {
      auto producer_root_id = producer_root_domain.at(i);

      TORCH_INTERNAL_ASSERT(
          !gpu_lower->trivialReductionInfo().isDerived(producer_root_id),
          "No trivial reduction axis should exist: ",
          producer_root_id);

      // If the producer ID is reduction or broadcast, it should be safe
      // to ignore.
      if (producer_root_id->isReduction()) {
        continue;
      } else if (producer_root_id->isBroadcast()) {
        --consumer_root_idx;
        continue;
      }

      // There must be a matching consumer root ID as the producer ID is
      // not reduction and the expression between them is UnaryOpType::Set.
      auto it = p2c.find(producer_root_id);
      TORCH_INTERNAL_ASSERT(
          it != p2c.end(), "No matching consumer root ID found");
      auto consumer_root_id = it->second;

      // Don't extend the vectorization domain beyond the CA position
      if (std::find(
              consumer_root_right_of_ca_domains.begin(),
              consumer_root_right_of_ca_domains.end(),
              consumer_root_id) == consumer_root_right_of_ca_domains.end() ||
          std::find(
              producer_root_right_of_ca_domains.begin(),
              producer_root_right_of_ca_domains.end(),
              producer_root_id) == producer_root_right_of_ca_domains.end()) {
        break;
      }

      // We now know it's safe to extend the vectorization domain to these
      // axes. It shouldn't matter whether producer or consumer is used.
      if (extent == nullptr) {
        extent = consumer_root_id->extent();
      } else {
        extent = IrBuilder::mulExpr(extent, consumer_root_id->extent());
      }

      // If it's not contiguous, extending the vectorization domain
      // further is not possible
      if (!(producer_contig.at(i) && consumer_contig.at(consumer_root_idx))) {
        break;
      }

      --consumer_root_idx;
    }

    TORCH_INTERNAL_ASSERT(extent != nullptr);

    return extent;
  }

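  // Create a named scalar defined by `val`, allocate it in local memory, and
  // append the allocation and the defining expression to `body`. When
  // `address` is true, the scalar captures the address of `val` rather than
  // its value (used to inspect the base pointer for alignment).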
  Val* createNamedScalarFromValue(
      kir::Scope& body,
      Val* val,
      const std::string& name,
      bool address = false) {
    auto namedScalar = (address) ? IrBuilder::addressExprNamedScalar(name, val)
                                 : IrBuilder::setExprNamedScalar(name, val);
    TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr);

    auto alloc = IrBuilder::create<kir::Allocate>(
        namedScalar,
        MemoryType::Local,
        GpuLower::current()->kernel()->oneVal());
    body.push_back(alloc);
    body.push_back(namedScalar->definition());
    return namedScalar;
  }
};

} // namespace

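// Public entry point: rewrites loop nests that contain misaligned vectorized
// accesses. See MisalignedVectorizationModifier above for the transformation.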
std::vector<Expr*> processMisalignedVectorization(
    const std::vector<Expr*>& exprs) {
  return MisalignedVectorizationModifier::processMisalignedVectorization(exprs);
}

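// Returns true if any immediate child loop of the given for-loop is
// parallelized with ParallelType::MisalignedVectorize.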
bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) {
  for (auto expr : fl->body().exprs()) {
    if (expr->isA<kir::ForLoop>()) {
      auto child_fl = expr->as<kir::ForLoop>();
      if (child_fl->iter_domain()->getParallelType() ==
          ParallelType::MisalignedVectorize) {
        return true;
      }
    }
  }
  return false;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch