1#include <codegen.h>
2#include <expr_evaluator.h>
3#include <instrumentation.h>
4#include <kernel_expr_evaluator.h>
5#include <kernel_ir.h>
6#include <kernel_ir_dispatch.h>
7#include <scheduler/mma_utils.h>
8#include <type.h>
9#include <utils.h>
10
11#include <array>
12#include <cmath>
13#include <sstream>
14#include <vector>
15
16namespace torch {
17namespace jit {
18namespace fuser {
19namespace cuda {
20namespace codegen {
21
22namespace {
23
24std::string ptrType(DataType dt) {
25 std::stringstream ss;
26 ss << dt << "*";
27 return ss.str();
28}
29
30//! Utility class to build an argument list
31class ArgumentBuilder {
32 public:
33 //! Build an argument list where each argument is separated with a comma
34 ArgumentBuilder() = default;
35
36 //! Build an argument list where each argument has its own line
37 ArgumentBuilder(int indent_level, const char* tab) {
38 std::stringstream ss;
39 for (const auto i : c10::irange(indent_level)) {
40 (void)i; // Suppress unused variable warning
41 ss << tab;
42 }
43 sep_ = ",\n" + ss.str();
44 }
45
46 //! Add a new argument
47 template <typename T>
48 ArgumentBuilder& arg(const T& x) {
49 addSeparator();
50 return append(x);
51 }
52
53 //! Append to the last argument
54 template <typename T>
55 ArgumentBuilder& append(const T& arg) {
56 ss_ << arg;
57 return *this;
58 }
59
60 //! Get a string of the argument list
61 std::string str() const {
62 return ss_.str();
63 }
64
65 friend std::ostream& operator<<(std::ostream& os, const ArgumentBuilder& ab) {
66 return os << ab.str();
67 }
68
69 private:
70 void addSeparator() {
71 if (ss_.tellp() != 0) {
72 ss_ << sep_;
73 }
74 }
75
76 private:
77 std::string sep_ = ", ";
78 std::stringstream ss_;
79};
80
//! Append to the last argument (specialization so bool is emitted as "true"/"false")
82template <>
83ArgumentBuilder& ArgumentBuilder::append<bool>(const bool& arg) {
84 ss_ << (arg ? "true" : "false");
85 return *this;
86}
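
//! Illustrative usage of ArgumentBuilder (not part of the generator itself;
//! shown as a sketch of the expected output):
//!
//!   ArgumentBuilder args;
//!   args.arg("x").arg(true).append("[0]");     // str() == "x, true[0]"
//!
//!   ArgumentBuilder multi_line(/*indent_level=*/2, "  ");
//!   multi_line.arg("x").arg("y");              // str() == "x,\n    y"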
87
88//! Returns "template_name<template_arg>"
89template <typename TemplateNameT, typename TemplateArgT>
90std::string genTemplate(
91 const TemplateNameT& template_name,
92 const TemplateArgT& template_arg) {
93 std::stringstream ss;
94 ss << template_name << "<" << template_arg << ">";
95 return ss.str();
96}
97
98//! Returns "func_name(func_arg)"
99template <typename FuncNameT, typename FuncArgT>
100std::string genCall(const FuncNameT& func_name, const FuncArgT& func_arg) {
101 std::stringstream ss;
102 ss << func_name << "(" << func_arg << ")";
103 return ss.str();
104}
105
106//! Returns "func_name<template_arg>(func_arg)"
107template <typename FuncNameT, typename TemplateArgT, typename FuncArgT>
108std::string genCall(
109 const FuncNameT& func_name,
110 const TemplateArgT& template_arg,
111 const FuncArgT& func_arg) {
112 std::stringstream ss;
113 ss << func_name << "<" << template_arg << ">(" << func_arg << ")";
114 return ss.str();
115}
116
117//! A utility class to check if an expression of a particular type exists
118class ExprFinder : kir::ConstIrVisitor {
119 public:
  //! Returns true if expr, or any expression nested in it, has a type
  //! included in expr_types
122 static bool exists(
123 const Expr* expr,
124 const std::unordered_set<ExprType>& expr_types) {
125 ExprFinder finder(expr_types);
126 finder.handle(std::vector<const Expr*>{expr});
127 return finder.is_found_;
128 }
129
130 private:
131 ExprFinder(const std::unordered_set<ExprType>& expr_types)
132 : expr_types_(expr_types) {}
133
134 using kir::ConstIrVisitor::handle;
135
136 void handle(const Expr* expr) final {
137 if (expr_types_.find(expr->etype()) != expr_types_.end()) {
138 is_found_ = true;
139 return;
140 }
141 kir::ConstIrVisitor::handle(expr);
142 }
143
144 private:
145 const std::unordered_set<ExprType>& expr_types_;
146 bool is_found_ = false;
147};
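
//! Illustrative usage (the expression type is an arbitrary example):
//!   bool has_grid_reduction =
//!       ExprFinder::exists(expr, {ExprType::GridReduction});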
148
149class CudaKernelGenerator : private OptOutConstDispatch {
150 static constexpr const char* kTab = " ";
151
152 public:
153 static std::string generateKernelDefinition(
154 const kir::Kernel* kernel,
155 const std::string& kernel_name) {
156 CudaKernelGenerator codegen(kernel);
157 codegen.genDeclaration(kernel_name);
158 codegen.startBlock();
159 codegen.genPrologue();
160 codegen.genBody();
161 codegen.endBlock();
162 TORCH_CHECK(codegen.block_nest_level_ == 0);
163 return codegen.code_.str();
164 }
165
166 private:
167 explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {
168 initStringStreamFormat(code_);
169 }
170
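  // Configures a stream to print floating-point values in scientific
  // notation with round-trip (max_digits10) precision, using the
  // locale-independent "C" locale.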
171 void initStringStreamFormat(std::stringstream& ss) {
172 const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
173 ss.imbue(std::locale("C"));
174 ss << std::scientific << std::setprecision(digits);
175 }
176
177 // Generates the kernel function declaration
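  // e.g. (illustrative):
  //   __global__ void kernel1(Tensor<float, 2> T0, double d1, Tensor<float, 2> T3)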
178 void genDeclaration(const std::string& kernel_name) {
179 const auto& kernel_summary = kernel_->summary();
180
181 code_ << "__global__ void " << kernel_name << "(";
182
183 std::unordered_set<Val*> unique_args;
184
185 std::vector<Val*> params;
186
187 // Inputs & Outputs
188 for (auto val : kernel_->inputs()) {
189 params.push_back(val);
190 }
191 for (auto val : kernel_->outputs()) {
192 TORCH_INTERNAL_ASSERT(
193 !val->isScalar(), "No scalar output is allowed: ", val->toString());
194 params.push_back(val);
195 }
196
197 // Generate parameter declarations
198 unsigned int duplicate_counter = 0;
199 for (auto i : c10::irange(params.size())) {
200 std::stringstream var_name_ss;
201 if (params[i]->isA<TensorView>()) {
202 var_name_ss << varName(params[i]->as<TensorView>());
203 } else {
204 var_name_ss << gen(params[i]);
205 }
206
      // If the value appears more than once among the arguments, rename it
      // to avoid name conflicts in the parameter list.
209 if (!unique_args.emplace(params[i]).second) {
210 var_name_ss << "_duplicate_" << duplicate_counter++;
211 }
212
213 if (const auto tv = dynamic_cast<TensorView*>(params[i])) {
214 if (tv->isCpuScalar()) {
215 code_ << " CpuScalarTensor<" << params[i]->dtype() << "> "
216 << var_name_ss.str();
217 } else {
218 code_
219 << "Tensor<" << params[i]->dtype() << ", "
220 << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size()
221 << "> " << var_name_ss.str();
222 }
223 } else {
224 TORCH_INTERNAL_ASSERT(params[i]->isScalar()); // NOLINT (LLVM bug 48525)
225 TORCH_INTERNAL_ASSERT(params[i]->definition() == nullptr);
226 code_ << params[i]->dtype() << " " << var_name_ss.str();
227 }
228
229 if (i + 1 != params.size()) {
230 code_ << ", ";
231 }
232 }
233
234 // Global buffers
235 for (auto allocate : kernel_summary.global_allocations) {
236 TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<TensorView>());
237 const auto tv = allocate->buffer()->as<TensorView>();
238 const auto& maybe_rfactor_domain = tv->domain()->hasRFactor()
239 ? tv->domain()->getRFactorDomain()
240 : tv->domain()->getRootDomain();
241 const auto nDims = std::count_if(
242 maybe_rfactor_domain.begin(),
243 maybe_rfactor_domain.end(),
244 [](const IterDomain* id) { return !id->isReduction(); });
245 code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> "
246 << varName(tv);
247 }
248
249 // Kernels generating random numbers take extra (seed, offset) arguments
250 if (kernel_summary.max_rng_offsets >= 0) {
251 code_ << ", at::PhiloxCudaState philox_args";
252 }
253
254 code_ << ") ";
255 }
256
257 // Generates setup code which is executed before the kernel body
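  // This includes the Philox RNG bookkeeping, the dynamic shared memory
  // base pointer, the shared memory workspace for block reductions and
  // welford, and the optional custom block_sync initialization.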
258 void genPrologue() {
259 const auto& kernel_summary = kernel_->summary();
260
261 // Random number generator (optional)
262 if (kernel_summary.max_rng_offsets >= 0) {
263 indent() << "auto philox_offset = philox_args.captured_ ?\n";
264 indent()
265 << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
266 indent() << " philox_args.offset_.val;\n";
267 indent() << "uint4 rng_result;\n";
268 indent() << "nvfuser_index_t rng_subseq = -1;\n";
269 indent() << "nvfuser_index_t rng_offset = -1;\n";
270 }
271
272 // Do we have any dynamic shared memory buffers?
273 const bool has_dynamic_smem =
274 !kernel_summary.dynamic_smem_allocations.empty();
275
276 // Do we have any reductions?
277 const bool has_reductions = kernel_summary.has_block_reductions ||
278 kernel_summary.has_grid_reductions;
279 const bool has_parallel_welford =
280 kernel_summary.has_block_welford || kernel_summary.has_grid_welford;
281
282 // Shared memory
283 if (has_dynamic_smem || has_reductions || has_parallel_welford) {
284 indent() << "alignas("
285#ifndef USE_ROCM
286 << 16 // always align to 16B for any shared mem allocation
287#else
288 << 8 // for HIP, we want 8-aligned even for smaller datatypes
289#endif
290 << ") extern __shared__ char array[];\n";
291
292 if (has_dynamic_smem) {
293 indent() << "unsigned smem_offset = 0;\n";
294 }
295
296 if (has_reductions || has_parallel_welford) {
297 indent() << "void* shared_mem = array;\n";
298 if (has_dynamic_smem) {
299 if (has_parallel_welford) {
300 indent() << "smem_offset += "
301 << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
302 << kernel_summary.largest_smem_data_type << "));\n";
303 } else {
304 indent() << "smem_offset += "
305 << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
306 << kernel_summary.largest_smem_data_type << "));\n";
307 }
308 }
309
310 if (has_parallel_welford) {
311 // Unpack shared mem pointer
312 auto space_type = kernel_summary.largest_smem_data_type;
313 indent()
314 << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n";
315 indent() << space_type << " *shared_mem_var = "
316 << "static_cast<" << space_type << "*>("
317 << "shared_mem);\n";
318 indent() << space_type
319 << " *shared_mem_avg = shared_mem_var + block_size;\n";
320 indent() << space_type
321 << " *shared_mem_n = shared_mem_avg + block_size;\n";
322 }
323 }
324 }
325
326 // Call the initialization function if using a custom block sync
327 if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
328 indent() << "block_sync::init();\n";
329 }
330 }
331
332 void genBody() {
333 for (auto expr : kernel_->topLevelExprs()) {
334 OptOutConstDispatch::handle(expr);
335 }
336 }
337
338 void startBlock(bool continuation = false) {
339 if (continuation) {
340 code_ << "{\n";
341 } else {
342 indent() << "{\n";
343 }
344 ++block_nest_level_;
345 }
346
347 void endBlock(const char* sep = "\n") {
348 --block_nest_level_;
349 TORCH_CHECK(block_nest_level_ >= 0);
350 indent() << "}" << sep;
351 }
352
353 std::ostream& indent() {
354 for (const auto i : c10::irange(block_nest_level_)) {
355 (void)i; // Suppress unused variable warning
356 code_ << kTab;
357 }
358 return code_;
359 }
360
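  // Generates code for a single statement into a temporary stream and
  // returns it as a string; code_ itself is left untouched by swapping the
  // two streams around the dispatch call.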
361 std::string gen(const Statement* stmt) {
362 std::stringstream tmp_code;
363 initStringStreamFormat(tmp_code);
364 std::swap(tmp_code, code_);
365 OptOutConstDispatch::handle(stmt);
366 std::swap(tmp_code, code_);
367 return tmp_code.str();
368 }
369
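  // Returns the variable name used for a value in the generated kernel,
  // e.g. "T2" for a TensorView, "ip1" for an IntPair, and a type-prefixed
  // name such as "i3" or "d4" for scalars.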
370 std::string varName(const Val* val) {
371 std::stringstream name;
372 if (val->isA<TensorView>()) {
373 name << "T";
374 } else if (val->isA<kir::IntPair>()) {
375 name << "ip";
376 } else {
377 name << typePrefix(val->dtype());
378 }
379 name << val->name();
380 return name.str();
381 }
382
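  // Generates stmt with print_inline_ enabled so that definitions are
  // expanded in place rather than referenced through named variables.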
383 std::string genInline(const Statement* stmt) {
384 const bool saved_inline = print_inline_;
385 print_inline_ = true;
386 auto result = gen(stmt);
387 print_inline_ = saved_inline;
388 // NOLINTNEXTLINE(performance-no-automatic-move)
389 return result;
390 }
391
392 void handle(const kir::Predicate* pred) final {
393 TORCH_INTERNAL_ASSERT(pred->hasValue());
394 code_ << gen(pred->value());
395 }
396
397 void handle(const Bool* pred) final {
398 const auto def = pred->definition();
399 const bool has_alloc = alloc_map_.find(pred) != alloc_map_.end();
400 if (def != nullptr && !has_alloc) {
401 code_ << "(" << gen(def) << ")";
402 } else if (pred->isConst()) {
403 code_ << (*pred->value() ? "true" : "false");
404 } else {
405 code_ << varName(pred);
406 }
407 }
408
409 void handle(const Double* d) final {
410 const auto def = d->definition();
411 const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
412 if (def != nullptr && !has_alloc) {
413 code_ << "(" << gen(def) << ")";
414 } else if (d->isConst()) {
415 auto val = *d->value();
      // note: the default inf/nan representation doesn't work and is
      // replaced with the macros `NAN`, `POS_INFINITY` and `NEG_INFINITY`
      // instead.
418 if (std::isinf(val)) {
419 if (val > 0) {
420 code_ << "POS_INFINITY";
421 } else {
422 code_ << "NEG_INFINITY";
423 }
424 } else if (std::isnan(val)) {
425 code_ << "NAN";
426 } else {
427 code_ << val;
428 }
429 } else {
430 code_ << varName(d);
431 }
432 }
433
434 void handle(const Int* i) final {
435 // Check the replacement map first. If there's an entry for i, use
436 // the corresponding replacement.
437 auto replace_it = index_replacement_map_.find(i);
438 if (replace_it != index_replacement_map_.end()) {
439 code_ << replace_it->second;
440 return;
441 }
442
443 const auto def = i->definition();
444 const bool has_alloc = alloc_map_.find(i) != alloc_map_.end();
445 if (def != nullptr && !has_alloc) {
446 code_ << "(" << genInline(def) << ")";
447 } else if (i->isConst()) {
448 code_ << *i->value();
449 } else {
450 code_ << varName(i);
451 }
452 }
453
454 void handle(const ComplexDouble* c) final {
455 const auto def = c->definition();
456 const bool has_alloc = alloc_map_.find(c) != alloc_map_.end();
457 if (def != nullptr && !has_alloc) {
458 code_ << "(" << gen(def) << ")";
459 } else if (c->isConst()) {
460 code_ << "std::complex<double>" << *c->value();
461 } else {
462 code_ << varName(c);
463 }
464 }
465
466 void handle(const NamedScalar* ns) final {
467 // dim3 components are unsigned int. Cast to signed integer to
468 // support negative indexing
469 if (ns->getParallelIndex().has_value() ||
470 ns->getParallelDim().has_value()) {
471 code_ << "((nvfuser_index_t)" << ns->name() << ")";
472 } else {
473 code_ << ns->name();
474 }
475 }
476
  //! Returns the sum of all non-zero indices in a TensorIndex,
  //! or "0" if the indices vector is empty or all indices are zero.
  //! Used for lowering generic tensor indices and mma fragment
  //! indices.
481 std::string genTensorIndex(const kir::TensorIndex* ti) {
482 bool first = true;
483 std::stringstream index;
484 for (auto* ind : ti->indices()) {
485 if (!ind->isZeroInt()) {
486 if (!first) {
487 index << " + ";
488 }
489 index << genInline(ind);
490 first = false;
491 }
492 }
493
494 if (first) {
495 index << "0";
496 }
497
498 return index.str();
499 }
500
501 void handle(const kir::TensorIndex* ti) final {
502 bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
503 kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
504 if (is_volatile) {
505 code_ << "*(volatile " << ti->getDataType().value() << "*)&";
506 }
507 code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
508 }
509
510 void handle(const ViewAsScalar* sv) final {
511 indent() << gen(sv->output(0)) << " = " << gen(sv->input(0)) << "["
512 << gen(sv->index()) << "];\n";
513 }
514
515 void handle(const IterDomain*) final {
516 TORCH_INTERNAL_ASSERT(false, "Unreachable");
517 }
518
519 void handle(const TensorDomain*) final {
520 TORCH_INTERNAL_ASSERT(false, "Unreachable");
521 }
522
523 void handle(const TensorView*) final {
524 TORCH_INTERNAL_ASSERT(false, "Unreachable");
525 }
526
  //! Utility for generating vectorized pointer access in ldmatrix and
  //! cp.async.
  //! TODO: this access pattern could be merged with the existing
  //! vectorization handling logic, but this path will be updated in
  //! follow-ups to optimize the generated assembly, so the two are
  //! kept as separate paths for now.
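  //! e.g. (illustrative) for a float value and vec_size == 8:
  //!   reinterpret_cast<Array<float,8,8>*>(&T0[i])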
533 std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
534 std::stringstream ss;
535
536 ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
537 << vec_size << ">*>(&" << gen(val) << ")";
538
539 return ss.str();
540 }
541
542 // Utility function to emit a cp.async intrinsic
543 void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
544 auto dtype = ldst->in()->getDataType().value();
545
546 if (ldst->predicate() == nullptr) {
547 // Out of line predicate variant
548 indent() << "Ampere::cpAsync("
549 << genVectorPointer(ldst->out(), dtype, vec_size) << ","
550 << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
551 } else {
552 // Inline predicate variant
553 indent() << "Ampere::cpAsync("
554 << genVectorPointer(ldst->out(), dtype, vec_size) << ","
555 << genVectorPointer(ldst->in(), dtype, vec_size) << ","
556 << genInline(ldst->predicate()) << ");\n";
557 }
558 }
559
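  // Utility function to emit a Turing ldMatrix (optionally transposed)
  // intrinsic for a vectorized load.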
560 void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
561 auto dtype = ldst->in()->getDataType().value();
562 indent() << "Turing::ldMatrix";
563 if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
564 code_ << "T";
565 }
566 code_ << " (";
567 code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
568 << ","
569 << "&" << gen(ldst->in()) << ");\n";
570 }
571
572 void handle(const FullOp* fop) final {
573 indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")"
574 << gen(fop->getFillValue()) << ";\n";
575 }
576
577 void handle(const ARangeOp* aop) final {
578 auto index =
579 genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
580 indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
581 code_ << "(" << index << ", " << gen(aop->start()) << ", "
582 << gen(aop->step()) << ");\n";
583 }
584
585 void handle(const EyeOp* aop) final {
586 auto index1 = gen(aop->getIndex1());
587 auto index2 = gen(aop->getIndex2());
588 indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
589 code_ << "(" << index1 << " == " << index2 << ");\n";
590 }
591
592 void handle(const UnaryOp* uop) final {
593 bool is_vector_op = false;
594 size_t vector_word_size = 1;
595
596 if (uop->out()->isA<kir::TensorIndex>()) {
597 auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
598 if (std::any_of(
599 out_tv->domain()->domain().begin(),
600 out_tv->domain()->domain().end(),
601 [&](IterDomain* id) { return id->isMma(); })) {
602 auto mma = dynamic_cast<MmaOp*>(
603 uop->out()->as<kir::TensorIndex>()->view()->definition());
604 TORCH_INTERNAL_ASSERT(
605 mma != nullptr, "CodeGen: mma op not in mma loop");
606 genMmaInitialization(mma, uop);
607 return;
608 }
609 }
610
611 if (vectorize_scope_ && uop->out()->isA<kir::TensorIndex>()) {
612 auto ti = uop->out()->as<kir::TensorIndex>();
613
614 bool vectorize_op = false;
615 bool misaligned_op = false;
616
617 for (auto id : ti->view()->domain()->domain()) {
618 if (!isParallelTypeVectorize(id->getParallelType())) {
619 continue;
620 }
621
622 ExpressionEvaluator expr_eval(id->fusion());
623 auto vector_size_optional = expr_eval.evaluate(id->extent());
624
625 TORCH_INTERNAL_ASSERT(
626 vector_size_optional.has_value(),
627 "Could not evaluate constant value bound to vectorized dim.");
628
629 vector_word_size = vector_size_optional->as<int64_t>();
630
631 vectorize_op = id->getParallelType() == ParallelType::Vectorize;
632 misaligned_op =
633 id->getParallelType() == ParallelType::MisalignedVectorize;
634 break;
635 }
636
637 if (vectorize_op) {
638 TORCH_INTERNAL_ASSERT(
639 uop->getUnaryOpType() == UnaryOpType::Set,
640 "Cannot vectorize operations that are not sets. ",
641 "Use cacheBefore and cacheAfter to store/load with vectorized reads into buffers.");
642 is_vector_op = true;
643 }
644
645 if (misaligned_op) {
646 is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set);
647 }
648
649 if (is_vector_op && !uop->in()->isScalar()) {
650 TORCH_INTERNAL_ASSERT(
651 uop->out()->dtype() == uop->in()->dtype(),
652 "Vectorized store/load requires input and output datatypes match.");
653 }
654
655 if (is_vector_op) {
656 auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
657 if (uop->in()->isScalar()) {
658 // Note:
659 // Double buffered local tensors need indexed initialization,
660 // so will need to use `arraySet` option.
661 if (out_tv->getMemoryType() == MemoryType::Local &&
662 !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
663 // Vectorized initialization
664 indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
665 } else {
            // Note: the arraySet path is currently not vectorized, so it
            // relies on the auto-vectorization pass of the CUDA compiler.
668 indent() << "arraySet<" << out_tv->getDataType().value() << ", "
669 << vector_word_size << ">(&" << gen(uop->out()) << ", "
670 << "(" << out_tv->getDataType().value() << ")"
671 << gen(uop->in()) << ");\n";
672 }
673 } else {
674 // Vectorized load
675 TORCH_INTERNAL_ASSERT(
676 uop->in()->isA<kir::TensorIndex>(),
677 "Invalid input to unary op with tensor output, found: ",
678 uop->in()->toString());
679
680 auto in_tv = uop->in()->as<kir::TensorIndex>()->view();
681 bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
682 in_tv->getMemoryType() == MemoryType::Local;
683
684 bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
685 in_tv->getMemoryType() == MemoryType::Global;
686
687 bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
688 in_tv->getMemoryType() == MemoryType::Global;
689
690 bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
691 kernel_->summary().sync_map.needsRawSync(out_tv).hasBID();
692
693 bool is_volatile_from =
694 in_tv->getMemoryType() == MemoryType::Global &&
695 kernel_->summary().sync_map.needsRawSync(in_tv).hasBID();
696
697 if (localToGlobal) {
698 indent() << "loadLocalToGlobal<" << uop->out()->dtype() << ", "
699 << vector_word_size << ", "
700 << (is_volatile_to ? "true" : "false") << ">(";
701 code_ << " &" << gen(uop->out()) << ", &" << gen(uop->in())
702 << ");\n";
703 } else if (globalToLocal) {
704 indent() << "loadGlobalToLocal<" << uop->out()->dtype() << ", "
705 << vector_word_size << ", "
706 << (is_volatile_from ? "true" : "false") << ">(&"
707 << gen(uop->out()) << ", ";
708 code_ << " &" << gen(uop->in()) << ");\n";
709 } else if (globalToGlobal) {
710 indent() << "loadGlobalToGlobal<" << uop->out()->dtype() << ", "
711 << vector_word_size << ", "
712 << (is_volatile_to ? "true" : "false") << ", "
713 << (is_volatile_from ? "true" : "false") << ">(";
714 code_ << " &" << gen(uop->out()) << ", ";
715 code_ << " &" << gen(uop->in()) << ");\n";
716 } else {
717 indent() << "loadGeneric<" << uop->out()->dtype() << ", "
718 << vector_word_size << ">(";
719 code_ << " &" << gen(uop->out()) << ", ";
720 code_ << " &" << gen(uop->in()) << ");\n";
721 }
722 }
723 return;
724 }
725 }
726
727 const auto op_type = uop->getUnaryOpType();
728
729 if (uop->out()->isA<NamedScalar>()) {
730 if (auto op = inline_op_str(op_type)) {
731 indent() << gen(uop->out()) << " = " << *op << genInline(uop->in())
732 << ";\n";
733 }
734 return;
735 }
736
737 if (!print_inline_) {
738 indent() << gen(uop->out());
739 if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
740 code_ << "\n";
741 indent() << kTab;
742 }
743 code_ << " = ";
744 }
745
746 if (auto op = inline_op_str(op_type)) {
747 if (alsoBooleanOperator(op_type) &&
748 uop->out()->dtype() == DataType::Bool) {
749 code_ << stringifyBooleanOp(op_type) << gen(uop->in());
750 } else {
751 code_ << *op << gen(uop->in());
752 }
753 } else {
754 if (op_type == UnaryOpType::Cast) {
755 const auto cast_str =
756 cast_func_str({uop->in()->dtype(), uop->out()->dtype()});
757 TORCH_INTERNAL_ASSERT(
758 cast_str.has_value(),
759 "Invalid cast. Input type: ",
760 uop->in()->dtype(),
761 ", output type: ",
762 uop->out()->dtype());
763 code_ << cast_str.value();
764 } else {
765 code_ << op_type;
766 if (needFloatSuffix(op_type) &&
767 uop->out()->dtype() == DataType::Float) {
768 code_ << "f";
769 }
770 }
771
772 code_ << "(" << gen(uop->in()) << ")";
773 }
774
775 if (!print_inline_) {
776 code_ << ";\n";
777 }
778 }
779
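  // Generates a Philox-based random number op. The cached rng_result is
  // only recomputed when the (subsequence, offset) pair differs from the
  // one used by the previously generated call.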
780 void handle(const RNGOp* rop) final {
781 // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
782 // innermost ID of size 4 (float) or size 2 (double)?
783 auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
784 int multiple = rop->dtype() == DataType::Double ? 2 : 4;
785 indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
786 << ";\n";
787 indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
788 << rop->name() << " / " << multiple << ";\n";
789 indent() << "nvfuser_index_t rng_component" << rop->name()
790 << " = linear_index" << rop->name() << " % " << multiple << ";\n";
791 indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
792 << rop->getRNGOffset() << ";\n";
793 indent() << "if (rng_subseq != rng_subseq" << rop->name()
794 << " || rng_offset != rng_offset" << rop->name() << ") {\n";
795 indent() << " auto seed = philox_args.captured_ ?\n"
796 << " static_cast<uint64_t>(*(philox_args.seed_.ptr)) : \n"
797 << " philox_args.seed_.val;\n";
798 indent() << " rng_result = philox(seed, rng_subseq" << rop->name()
799 << ", philox_offset / 4 + rng_offset" << rop->name() << ");\n";
800 indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n";
801 indent() << " rng_offset = rng_offset" << rop->name() << ";\n";
802 indent() << "}\n";
803 auto op_type = rop->getRNGOpType();
804 indent() << gen(rop->output(0)) << " = " << op_type;
805 if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
806 code_ << "f";
807 }
808 code_ << "(rng_result, rng_component" << rop->name();
809 switch (op_type) {
810 case RNGOpType::UniformRange: {
811 auto parameters = rop->getParameters();
812 TORCH_INTERNAL_ASSERT(parameters.size() == 2);
813 code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
814 break;
815 }
816 default:;
817 }
818 code_ << ");\n";
819 }
820
821 std::string genBinaryOp(
822 BinaryOpType op_type,
823 DataType data_type,
824 const std::string& lhs,
825 const std::string& rhs) {
826 std::stringstream expr;
827 if (auto op = inline_op_str(op_type)) {
828 expr << lhs << " ";
829 if (alsoBooleanOperator(op_type) && data_type == DataType::Bool) {
830 expr << stringifyBooleanOp(op_type);
831 } else {
832 expr << *op;
833 }
834 expr << " " << rhs;
835 } else {
836 if (integer_op_str(op_type) && isIntegralType(data_type)) {
837 auto int_op = integer_op_str(op_type);
838 expr << *int_op;
839 } else if (bool_op_str(op_type) && isBooleanType(data_type)) {
840 auto bool_op = bool_op_str(op_type);
841 expr << *bool_op;
842 } else {
843 expr << op_type;
844 if (needFloatSuffix(op_type) && data_type == DataType::Float) {
845 expr << "f";
846 }
847 }
848 expr << "(" << lhs << ", " << rhs << ")";
849 }
850 return expr.str();
851 }
852
853 // If one argument is a tensorview and the other is a scalar, make sure we
854 // cast the scalar to the tensorview type
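  // e.g. (illustrative): for a float TensorIndex combined with a double
  // scalar this returns "(float) ", which callers prepend to the scalar
  // operand.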
855 std::string scalarCast(Val* lhs, Val* rhs) {
    // Return if this is not a mixed scalar / tensor-index combination
857 if (!((lhs->isScalar() || rhs->isScalar()) &&
858 (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) {
859 return "";
860 }
861
862 // Looking for mixed tensorview scalar options where types don't match
863 // but are either both floating or both int types. We should cast
864 // scalar to tensorview type in these instances.
865 auto lhs_t = lhs->dtype();
866 auto rhs_t = rhs->dtype();
867
868 // If same type, don't cast anything
869 if (lhs_t == rhs_t) {
870 return "";
871 }
872
873 // Don't do anything when dealing with bools
874 if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) {
875 return "";
876 }
877
    // Don't cast when mixing floating point and integral types
879 if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) ||
880 (isIntegralType(lhs_t) != isIntegralType(rhs_t))) {
881 return "";
882 }
883
884 std::stringstream cast;
885 cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") ";
886 return cast.str();
887 }
888
889 // If possible, replace pow with mul. Return true when successful.
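  // Only constant exponents of 2 and 3 are handled, e.g. x ** 2 becomes
  // x * x and x ** 3 becomes x * x * x.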
890 bool genPowerWithMul(const BinaryOp* bop) {
891 if (bop->getBinaryOpType() != BinaryOpType::Pow) {
892 return false;
893 }
894
895 auto rhs = bop->rhs();
896 c10::optional<double> exponent;
897 if (auto val_int = dynamic_cast<Int*>(rhs)) {
898 if (val_int->isConst()) {
899 exponent = val_int->value().value();
900 }
901 } else if (auto val_float = dynamic_cast<Double*>(rhs)) {
902 if (val_float->isConst()) {
903 auto fp_exp = val_float->value().value();
904 double int_exp = 0;
905 if (std::modf(fp_exp, &int_exp) == 0) {
906 exponent = int_exp;
907 }
908 }
909 }
910
911 if (!exponent.has_value()) {
912 return false;
913 }
914
915 // Only **2 and **3 are considered
916 if (!(exponent.value() == 2 || exponent.value() == 3)) {
917 return false;
918 }
919
920 auto lhs = gen(bop->lhs());
921
922 if (print_inline_) {
923 code_ << lhs << " * " << lhs;
924 if (exponent.value() == 3) {
925 code_ << " * " << lhs;
926 }
927 } else {
928 indent() << gen(bop->out());
929 if (bop->out()->isScalar()) {
930 code_ << " = " << lhs << " * " << lhs;
931 if (exponent.value() == 3) {
932 code_ << " * " << lhs;
933 }
934 } else {
935 code_ << "\n";
936 indent() << kTab << "= " << lhs << "\n";
937 indent() << kTab << "* " << lhs;
938 if (exponent.value() == 3) {
939 code_ << "\n";
940 indent() << kTab << "* " << lhs;
941 }
942 }
943 }
944
945 code_ << ";\n";
946 return true;
947 }
948
949 void handle(const BinaryOp* bop) final {
950 // Try replacing pow with mul
951 if (genPowerWithMul(bop)) {
952 return;
953 }
954
955 const auto op_type = bop->getBinaryOpType();
956 if (print_inline_) {
957 // Inline expression: `lhs op rhs`
958 code_ << genBinaryOp(
959 op_type, bop->out()->dtype(), gen(bop->lhs()), gen(bop->rhs()));
960 } else {
961 indent() << gen(bop->out());
962 if (bop->out()->isScalar()) {
963 // Single line: `out = lhs op rhs;`
964 code_ << " = "
965 << genBinaryOp(
966 op_type,
967 bop->out()->dtype(),
968 gen(bop->lhs()),
969 gen(bop->rhs()));
970 } else {
971 // Split TensorView expressions across multiple lines:
972 //
973 // out
974 // = lhs
975 // op rhs;
976 //
977
978 auto cast = scalarCast(bop->lhs(), bop->rhs());
979 if (auto op = inline_op_str(op_type)) {
980 code_ << "\n";
981 indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "")
982 << gen(bop->lhs()) << "\n";
983 indent() << kTab;
984 if (alsoBooleanOperator(op_type) &&
985 bop->out()->dtype() == DataType::Bool) {
986 code_ << stringifyBooleanOp(op_type);
987 } else {
988 code_ << *op;
989 }
990 code_ << " " << (bop->rhs()->isScalar() ? cast : "")
991 << gen(bop->rhs());
992 } else {
993 if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) {
994 auto int_op = integer_op_str(op_type);
995 code_ << " = " << *int_op << "(\n";
996 } else if (
997 bool_op_str(op_type) && isBooleanType(bop->out()->dtype())) {
998 auto bool_op = bool_op_str(op_type);
999 code_ << " = " << *bool_op << "(\n";
1000 } else {
1001 std::stringstream op_str;
1002 op_str << op_type;
1003 if (needFloatSuffix(op_type) &&
1004 bop->out()->dtype() == DataType::Float) {
1005 op_str << "f";
1006 }
1007 code_ << " = " << op_str.str() << "(\n";
1008 }
1009 indent() << kTab << (bop->lhs()->isScalar() ? cast : "")
1010 << gen(bop->lhs()) << ",\n";
1011 indent() << kTab << (bop->rhs()->isScalar() ? cast : "")
1012 << gen(bop->rhs()) << ")";
1013 }
1014 }
1015 code_ << ";\n";
1016 }
1017 }
1018
1019 void handle(const TernaryOp* top) final {
1020 if (!print_inline_) {
1021 indent() << gen(top->out());
1022 if (!top->out()->isScalar()) {
1023 code_ << "\n";
1024 indent() << kTab;
1025 }
1026 code_ << " = ";
1027 }
1028
1029 code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", ";
1030
    // Make sure the two operands of where have the same
    // type. Note that compiling "where(0.0f, 0.0)" fails because of
    // the overloading ambiguity.
1034 if (top->getTernaryOpType() == TernaryOpType::Where) {
1035 auto cast = scalarCast(top->in2(), top->in3());
1036 code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", "
1037 << (top->in3()->isScalar() ? cast : "") << gen(top->in3()) << ")";
1038 } else {
1039 code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")";
1040 }
1041
1042 if (!print_inline_) {
1043 code_ << ";\n";
1044 }
1045 }
1046
1047 std::string genArchString(MmaOptions::MacroType macro) {
1048 std::stringstream ss;
1049 if (isVolta(macro)) {
1050 ss << "Volta";
1051 } else if (isTuring(macro)) {
1052 ss << "Turing";
1053 } else if (isAmpere(macro)) {
1054 ss << "Ampere";
1055 } else {
1056 TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
1057 }
1058 return ss.str();
1059 }
1060
1061 std::string genMmaOp(const MmaOp* mma, bool init = false) {
1062 std::stringstream ss;
1063 auto options = mma->options();
1064 ss << genArchString(options.macro) << "::";
1065 if (init) {
1066 ss << "init";
1067 }
1068 ss << toString(options.macro);
1069
1070 if (isVolta(options.macro)) {
1071 ss << toString(options.operand_layout);
1072 } else if (isTuring(options.macro) || isAmpere(options.macro)) {
      // MMAs on Turing and Ampere are TN only; transpose is handled either
      // via ldmatrix for fp16 or explicitly for other types.
1075 ss << "TN";
1076 }
1077 // TODO: additional parameter could be removed by swizzling iterdomain
1078 auto acc_stride = mma->accStride();
1079 TORCH_INTERNAL_ASSERT(acc_stride > 0);
1080 ss << "<" << acc_stride << ">";
1081 return ss.str();
1082 }
1083
1084 void genMmaOperands(const MmaOp* mma) {
1085 std::stringstream ss;
1086 auto options = mma->options();
1087 auto in_a = mma->inA()->as<kir::TensorIndex>()->view();
1088 auto dtype = in_a->getDataType().value();
1089 indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
1090 << getInputARegisterSize(options.macro) << ","
1091 << getInputARegisterSize(options.macro) << ">*>(&"
1092 << varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")["
1093 << genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])"
1094 << ",\n";
1095 indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
1096 << getInputBRegisterSize(options.macro) << ","
1097 << getInputBRegisterSize(options.macro) << ">*>(&"
1098 << varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")["
1099 << genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])";
1100 }
1101
1102 void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) {
1103 auto options = mma->options();
1104
1105 indent() << genMmaOp(mma, true) << "(reinterpret_cast<Array<"
1106 << mma->out()->getDataType().value() << ","
1107 << getOutputRegisterSize(options.macro) << ","
1108 << getOutputRegisterSize(options.macro) << ">*>"
1109 << "(&" << gen(uop->out()) << "));\n";
1110 }
1111
1112 void handle(const MmaOp* mma) final {
1113 auto options = mma->options();
1114 auto out = mma->out()->as<kir::TensorIndex>();
1115 indent() << genMmaOp(mma) << "(\n";
1116 indent() << kTab << "reinterpret_cast<Array<"
1117 << out->view()->getDataType().value() << ","
1118 << getOutputRegisterSize(options.macro) << ","
1119 << getOutputRegisterSize(options.macro) << ">*>(&"
1120 << gen(mma->out()) << "),\n";
1121 genMmaOperands(mma);
1122 code_ << ");\n";
1123 }
1124
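  // Returns a combining lambda for the reduction, e.g. (illustrative) for
  // Add on float: "[](float &a, float b) { a = a + b; }".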
1125 std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
1126 std::stringstream lambda;
1127 lambda << "[](" << data_type << " &a, " << data_type << " b) "
1128 << "{ a = " << genBinaryOp(op_type, data_type, "a", "b") << "; }";
1129 return lambda.str();
1130 }
1131
1132 void handle(const BroadcastOp* stmt) final {
1133 TORCH_INTERNAL_ASSERT(stmt->out()->isA<kir::TensorIndex>());
1134
1135 const ParallelTypeBitmap parallel_types =
1136 kernel_->summary().broadcast_parallel_types.at(stmt);
1137
1138 if (parallel_types.none()) {
1139 // Not parallelized
1140 indent() << gen(stmt->out()) << "\n";
1141 indent() << kTab << " = " << gen(stmt->in()) << ";\n";
1142 return;
1143 }
1144
1145 TORCH_INTERNAL_ASSERT(
1146 !parallel_types.hasBID(),
1147 "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node");
1148
1149 std::stringstream flags_str;
1150 for (const ParallelType pt : kParallelTypeTIDs) {
1151 const bool parallel_bcast = parallel_types.get(pt);
1152 if (pt != kParallelTypeTIDs[0]) {
1153 flags_str << ", ";
1154 }
1155 flags_str << (parallel_bcast ? "true" : "false");
1156 }
1157
1158 const auto data_type = stmt->out()->dtype();
1159 indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n";
1160 indent() << kTab << gen(stmt->out()) << ",\n";
1161 indent() << kTab << gen(stmt->in()) << ",\n";
1162 indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
1163 TORCH_INTERNAL_ASSERT(
1164 stmt->predicate() != nullptr && stmt->predicate()->hasValue());
1165 indent() << kTab << genInline(stmt->predicate()) << ");\n";
1166 }
1167
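  // Generates a serial reduction of the form: out = reduction_op(out, in).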
1168 void genSerialReduction(
1169 const kir::TensorIndex* output,
1170 const Val* input,
1171 BinaryOpType reduction_op_type) {
1172 const auto gen_out = gen(output);
1173 indent() << gen_out << " = "
1174 << genBinaryOp(
1175 reduction_op_type, output->dtype(), gen_out, gen(input))
1176 << ";\n";
1177 return;
1178 }
1179
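  // Generates a warp reduction along TIDx using the warp::warpReduceTIDX
  // runtime helper; the template flag is true when TIDx spans exactly a
  // single warp.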
1180 void genWarpReduction(
1181 const kir::TensorIndex* output,
1182 const kir::TensorIndex* input,
1183 const Val* init,
1184 BinaryOpType reduction_op_type,
1185 kir::Predicate* read_pred) {
1186 bool is_single_warp =
1187 kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp;
1188
1189 indent() << "warp::warpReduceTIDX";
1190 if (is_single_warp) {
1191 code_ << "<true>(\n";
1192 } else {
1193 code_ << "<false>(\n";
1194 }
1195 indent() << kTab << gen(output) << ",\n";
1196 indent() << kTab << gen(input) << ",\n";
1197 indent() << kTab << genReductionOp(reduction_op_type, output->dtype())
1198 << ",\n";
1199 indent() << kTab << "threadIdx,\n";
1200 indent() << kTab << "blockDim,\n";
1201 indent() << kTab << "static_cast<" << output->dtype()
1202 << "*>(shared_mem),\n";
1203 TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue());
1204 indent() << kTab << genInline(read_pred) << ",\n";
1205 indent() << kTab << output->dtype() << "(" << genInline(init) << "));\n";
1206 }
1207
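  // Generates a block reduction via the blockReduce runtime helper; the
  // template flags select which of TIDx/TIDy/TIDz are reduction dimensions.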
1208 void genBlockReduction(
1209 const kir::TensorIndex* output,
1210 const kir::TensorIndex* input,
1211 const Val* init,
1212 BinaryOpType reduction_op_type,
1213 kir::Predicate* read_pred,
1214 kir::Predicate* write_pred) {
1215 const auto par_domains = ir_utils::getParallelDomains(output);
1216 // Get parallel reduction domains
1217 const bool tidx =
1218 par_domains.find(ParallelType::TIDx) != par_domains.end() &&
1219 par_domains.at(ParallelType::TIDx)->isReduction();
1220 const bool tidy =
1221 par_domains.find(ParallelType::TIDy) != par_domains.end() &&
1222 par_domains.at(ParallelType::TIDy)->isReduction();
1223 const bool tidz =
1224 par_domains.find(ParallelType::TIDz) != par_domains.end() &&
1225 par_domains.at(ParallelType::TIDz)->isReduction();
1226
1227 const auto data_type = output->dtype();
1228
1229 indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
1230 << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
1231 << ">(\n";
1232 indent() << kTab << gen(output) << ",\n";
1233 indent() << kTab << gen(input) << ",\n";
1234 indent() << kTab << genReductionOp(reduction_op_type, output->dtype())
1235 << ",\n";
1236 indent() << kTab << "threadIdx,\n";
1237 indent() << kTab << "blockDim,\n";
1238 indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
1239 TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue());
1240 indent() << kTab << genInline(read_pred) << ",\n";
1241 // Pass the write predicate if available and different from the
1242 // default predicate. The blockReduce runtime function uses the
1243 // default predicate for both read and write when only the
1244 // default one is given.
1245 if (write_pred != nullptr) {
1246 TORCH_INTERNAL_ASSERT(write_pred->hasValue());
1247 indent() << kTab << genInline(write_pred) << ",\n";
1248 }
1249 indent() << kTab << data_type << "(" << genInline(init) << "));\n";
1250 }
1251
1252 void handle(const ReductionOp* rop) final {
1253 TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());
1254
1255 const auto output = rop->out()->as<kir::TensorIndex>();
1256 const auto input = rop->in()->as<kir::TensorIndex>();
1257 const auto domain = output->view()->domain();
1258 const auto op_type = rop->getReductionOpType();
1259
1260 const bool has_block_reduce = domain->hasBlockReduction();
1261 const bool has_grid_reduce = domain->hasGridReduction();
1262
1263 TORCH_INTERNAL_ASSERT(
1264 !has_grid_reduce,
1265 "ReductionOp does not support block parallelization. GridReductionOp must be used. ",
1266 rop->toString());
1267
1268 if (!has_block_reduce) {
1269 genSerialReduction(output, input, op_type);
1270 } else if (
1271 auto reduction_id = ir_utils::getMaybeWarpReductionDim(output, input)) {
1272 genWarpReduction(output, input, rop->init(), op_type, rop->predicate());
1273 } else {
1274 genBlockReduction(
1275 output,
1276 input,
1277 rop->init(),
1278 op_type,
1279 rop->predicate(),
1280 rop->writePredicate());
1281 }
1282 }
1283
1284 void handle(const LoadStoreOp* ldst) final {
1285 // TODO:
1286 // Need to gradually merge the code path of this
1287 // with UnaryOp::Set for vectorization.
1288 // There is quite a bit of possible clean up.
1289 bool vectorize_op = false;
1290 size_t vector_word_size = 1;
1291 auto ti = ldst->out()->as<kir::TensorIndex>();
1292
1293 // Check vectorization and set vector word size
1294 for (auto id : ti->view()->domain()->domain()) {
1295 if (!isParallelTypeVectorize(id->getParallelType())) {
1296 continue;
1297 }
1298
1299 ExpressionEvaluator expr_eval(id->fusion());
1300 auto vector_size_optional = expr_eval.evaluate(id->extent());
1301
1302 TORCH_INTERNAL_ASSERT(
1303 vector_size_optional.has_value(),
1304 "Could not evaluate constant value bound to vectorized dim.");
1305
1306 TORCH_INTERNAL_ASSERT(
1307 id->getParallelType() != ParallelType::MisalignedVectorize,
1308 "LoadStoreOp: no support yet for mis-aligned vectorization");
1309 vector_word_size = vector_size_optional->as<int64_t>();
1310 vectorize_op = true;
1311 break;
1312 }
1313
1314 // Dispatch instruction generation:
1315 switch (ldst->opType()) {
1316 case LoadStoreOpType::LdMatrix:
1317 case LoadStoreOpType::LdMatrixTranspose:
1318 TORCH_INTERNAL_ASSERT(
1319 vectorize_op, "LdMatrix: Vectorization required: ", ldst);
1320 genLdMatrix(ldst, vector_word_size);
1321 break;
1322 case LoadStoreOpType::CpAsync:
1323 genCpAsync(ldst, vector_word_size);
1324 break;
1325 default:
1326 TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
1327 }
1328 }
1329
1330 void handle(const WelfordOp* wop) final {
1331 TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());
1332
1333 const auto out = wop->out()->as<kir::TensorIndex>();
1334 const auto domain = out->view()->domain();
1335
1336 const auto out_var = wop->outVar();
1337 const auto out_avg = wop->outAvg();
1338 const auto out_N = wop->outN();
1339
1340 const auto in_var = wop->inVar();
1341 const auto in_avg = wop->inAvg();
1342 const auto in_N = wop->inN();
1343
1344 // inVar was allowed to be nullptr. Make sure it isn't.
1345 TORCH_INTERNAL_ASSERT(
1346 in_var != nullptr, "Welford var input nullptr not allowed");
1347
1348 const bool has_block_reduce = domain->hasBlockReduction();
1349 const bool has_grid_reduce = domain->hasGridReduction();
1350
1351 // Serial WelfordOp generation
1352 if (!has_block_reduce && !has_grid_reduce) {
1353 indent() << "welfordCombine ("
1354 << "\n";
1355 indent() << kTab << gen(out_avg) << ",\n";
1356 indent() << kTab << gen(out_var) << ",\n";
1357 indent() << kTab << gen(out_N) << ",\n";
1358 indent() << kTab << gen(in_avg) << ",\n";
1359 indent() << kTab << "(" << out_avg->dtype() << ")" << gen(in_var)
1360 << ",\n";
1361 indent() << kTab << "(" << out_N->dtype() << ")" << gen(in_N) << ");\n";
1362 return;
1363 }
1364
1365 const auto par_domains = ir_utils::getParallelDomains(wop->out());
1366 // Get parallel reduction domains
1367 const bool tidx =
1368 par_domains.find(ParallelType::TIDx) != par_domains.end() &&
1369 par_domains.at(ParallelType::TIDx)->isReduction();
1370 const bool tidy =
1371 par_domains.find(ParallelType::TIDy) != par_domains.end() &&
1372 par_domains.at(ParallelType::TIDy)->isReduction();
1373 const bool tidz =
1374 par_domains.find(ParallelType::TIDz) != par_domains.end() &&
1375 par_domains.at(ParallelType::TIDz)->isReduction();
1376
1377 const auto data_type = wop->out()->dtype();
1378
1379 if (has_block_reduce) {
1380 if (has_grid_reduce) {
1381 // allocate block result
1382 indent() << data_type << " "
1383 << "block_result_avg_" << block_reduce_name_ << " = "
1384 << gen(wop->initAvg()) << ";\n";
1385 indent() << data_type << " "
1386 << "block_result_var_" << block_reduce_name_ << " = "
1387 << gen(wop->initVar()) << ";\n";
1388 indent() << out_N->dtype() << " "
1389 << "block_result_n_" << block_reduce_name_ << " = "
1390 << gen(wop->initN()) << ";\n";
1391 }
1392 indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
1393 << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
1394 << ">(\n";
1395 if (has_grid_reduce) {
1396 indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n";
1397 indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n";
1398 indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
1399 } else {
1400 indent() << kTab << gen(wop->outAvg()) << ",\n";
1401 indent() << kTab << gen(wop->outVar()) << ",\n";
1402 indent() << kTab << gen(wop->outN()) << ",\n";
1403 }
1404 indent() << kTab << gen(in_avg) << ",\n";
1405 indent() << kTab << out_avg->dtype() << "(" << gen(in_var) << "),\n";
1406 indent() << kTab << out_N->dtype() << "(" << gen(in_N) << "),\n";
1407 indent() << kTab << "threadIdx,\n";
1408 indent() << kTab << "blockDim,\n";
1409 indent() << kTab << "reinterpret_cast<" << data_type
1410 << "*>(shared_mem_avg),\n";
1411 indent() << kTab << "reinterpret_cast<" << data_type
1412 << "*>(shared_mem_var),\n";
1413 indent() << kTab << "reinterpret_cast<" << out_N->dtype()
1414 << "*>(shared_mem_n),\n";
1415 TORCH_INTERNAL_ASSERT(wop->predicate() != nullptr);
1416 TORCH_INTERNAL_ASSERT(
1417 wop->predicate() != nullptr && wop->predicate()->hasValue());
1418 auto read_pred = genInline(wop->predicate());
1419 indent() << kTab << read_pred << ",\n";
1420 if (wop->writePredicate() != nullptr) {
1421 TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue());
1422 auto write_pred = genInline(wop->writePredicate());
1423 indent() << kTab << write_pred << ",\n";
1424 }
1425 indent() << kTab << data_type << "(0));\n";
1426 }
1427 }
1428
1429 // Support ReductionOp and WelfordOp
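  // Generates the boolean template flags for the grid reduction runtime
  // call: block dimensions are flagged when they are reduced, while thread
  // dimensions are flagged only when they are neither predicated out nor
  // already reduced at the block level.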
1430 template <typename REDUCTION_OP>
1431 std::string generateGridReduceTemplateFlags(
1432 const REDUCTION_OP* rop,
1433 const ParallelTypeBitmap& thread_pred) {
1434 TORCH_INTERNAL_ASSERT(
1435 !rop->isAllreduce(),
1436 "This is not for the allreduce reduction kernel\n");
1437
1438 const auto par_domains = ir_utils::getParallelDomains(rop->outputs()[0]);
1439 ArgumentBuilder flags;
1440 for (const ParallelType pt : kParallelTypeThreads) {
1441 const bool parallel_reduction =
1442 par_domains.find(pt) != par_domains.end() &&
1443 par_domains.at(pt)->isReduction();
1444 const bool pred = thread_pred.get(pt);
1445 TORCH_INTERNAL_ASSERT(
1446 !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
1447 bool flag = false;
1448 // Currently assumed that no dimensions parallelized with blocks
1449 // are predicated. This assumption may be lifted, but
1450 // gridReduction would need some changes.
1451 if (isParallelTypeBlockDim(pt)) {
1452 TORCH_INTERNAL_ASSERT(
1453 !pred, "Predication on block dimensions not allowed: ", pt);
1454 flag = parallel_reduction;
1455 } else {
1456 flag = !pred && !parallel_reduction;
1457 }
1458 flags.arg(flag);
1459 }
1460 return flags.str();
1461 }
1462
1463 // TODO: This should replace generateGridReduceTemplateFlags once
1464 // GridWelford is refactored as GridReduction.
1465 template <typename REDUCTION_OP>
1466 std::string generateGridReduceTemplateFlags2(
1467 const REDUCTION_OP* rop,
1468 const ParallelTypeBitmap& thread_pred) {
1469 TORCH_INTERNAL_ASSERT(
1470 !rop->isAllreduce(),
1471 "This is not for the allreduce reduction kernel\n");
1472
1473 const auto par_domains =
1474 ir_utils::getParallelDomains(ir_utils::getTvOutput(rop));
1475 ArgumentBuilder flags;
1476 for (const ParallelType pt : kParallelTypeThreads) {
1477 const bool parallel_reduction =
1478 par_domains.find(pt) != par_domains.end() &&
1479 par_domains.at(pt)->isReduction();
1480 const bool pred = thread_pred.get(pt);
1481 TORCH_INTERNAL_ASSERT(
1482 !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
1483 // Currently assumed that no dimensions parallelized with blocks
1484 // are predicated. This assumption may be lifted, but
1485 // gridReduction would need some changes.
1486 if (isParallelTypeBlockDim(pt)) {
1487 TORCH_INTERNAL_ASSERT(
1488 !pred, "Predication on block dimensions not allowed: ", pt);
1489 }
1490 flags.arg(parallel_reduction);
1491 }
1492 return flags.str();
1493 }
1494
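  // When kernel profiling is enabled and expr is profiled, appends the
  // corresponding entries of the profile buffer to the argument list of the
  // runtime call being generated.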
1495 void addProfileArguments(ArgumentBuilder& func_args, const Expr* expr) {
1496 if (isOptionEnabled(EnableOption::KernelProfile) &&
1497 kernel_->profile().isProfiled(expr)) {
1498 const auto& buffer_indices =
1499 kernel_->profile().getIndicesInProfileBuffer(expr);
1500 auto buffer = kernel_->profile().getBuffer();
1501 TORCH_INTERNAL_ASSERT(buffer != nullptr);
1502 for (const auto& index : buffer_indices) {
1503 func_args.arg(varName(buffer)).append("[").append(index).append("]");
1504 }
1505 }
1506 }
1507
1508 void handle(const kir::GridReduction* grop) final {
1509 TORCH_INTERNAL_ASSERT(grop->out()->isA<kir::TensorIndex>());
1510
1511 const auto out = grop->out()->as<kir::TensorIndex>();
1512 const auto domain = out->view()->domain();
1513 TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
1514
1515 const auto data_type = grop->out()->dtype();
1516 const auto op_type = grop->getReductionOpType();
1517
1518 TORCH_INTERNAL_ASSERT(
1519 grop->reduction_buffer()->buffer()->isA<TensorView>());
1520 TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
1521 const auto work_buffer =
1522 grop->reduction_buffer()->buffer()->as<TensorView>();
1523 const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();
1524
1525 if (grop->isAllreduce()) {
1526 generateGridAllreduce(grop);
1527 return;
1528 }
1529
1530 const std::string flags_str =
1531 generateGridReduceTemplateFlags2(grop, grop->threadPredicate());
1532
1533 const bool persistent_sync =
1534 kernel_->summary().has_cooperative_grid_reduction;
1535
1536 // Since block-level reduction is already done, those dimensions
1537 // with tidx/y/z being true do not participate in the grid
1538 // reduction.
1539 ArgumentBuilder template_args;
1540 template_args.arg(flags_str).arg(persistent_sync);
1541
1542 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
1543 func_args.arg(gen(grop->out()));
1544 func_args.arg(gen(grop->in()));
1545 func_args.arg(genReductionOp(op_type, out->dtype()));
1546 func_args.arg("&").append(varName(work_buffer)).append("[0]");
1547 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
1548 func_args.arg(genCall("static_cast", ptrType(data_type), "shared_mem"));
1549 // read and write predicates
1550 TORCH_INTERNAL_ASSERT(
1551 grop->predicate() != nullptr && grop->predicate()->hasValue());
1552 const auto read_pred = genInline(grop->predicate());
1553 func_args.arg(read_pred);
1554 if (grop->writePredicate() != nullptr) {
1555 TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue());
1556 func_args.arg(genInline(grop->writePredicate()));
1557 } else {
1558 func_args.arg(read_pred);
1559 }
1560 // Init val
1561 func_args.arg(genCall(data_type, genInline(grop->init())));
1562 func_args.arg(genInline(grop->entrance_index()));
1563 func_args.arg(genInline(grop->entrances()));
1564
1565 addProfileArguments(func_args, grop);
1566
1567 indent() << "reduction::gridReduce<" << template_args << ">(\n";
1568 indent() << kTab << func_args << ");\n";
1569 }
1570
1571 std::string genFusedReductionName(const TensorView* reduction_out) {
1572 return varName(reduction_out) + "_reduction";
1573 }
1574
1575 void generateGridAllreduce(const kir::GridReduction* grop) {
1576 TORCH_INTERNAL_ASSERT(grop->isAllreduce());
1577
1578 const auto out = grop->out()->as<kir::TensorIndex>();
1579
1580 const auto data_type = grop->out()->dtype();
1581 const auto op_type = grop->getReductionOpType();
1582
1583 const auto work_buffer =
1584 grop->reduction_buffer()->buffer()->as<TensorView>();
1585 const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();
1586
1587 const auto reduction_name = genFusedReductionName(out->view());
1588
1589 // template <typename Func, typename... Types>
1590 // __device__ __inline__ void reduce(
1591 // RefTuple<Types...> out,
1592 // const LocalTuple<Types...>& inp,
1593 // VolatilePtrTuple<Types...> global_work_buffer,
1594 // int64_t* global_sync_buffer, // Allocated as product of all
1595 // // non-participating Grid dimension
1596 // PtrTuple<Types...> shared_buf,
1597 // bool read_pred, // Prevent reading from out of bounds memory
1598 // bool write_pred, // Prevent from writing out of bounds
1599 // const LocalTuple<Types...>& init_val,
1600 // Func reduction_op);
1601
1602 indent() << reduction_name << ".reduce(\n";
1603
1604 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
1605 // out
1606 func_args.arg(genCall("RefTuple", data_type, gen(grop->out())));
1607 // inp
1608 func_args.arg(genCall("ConstRefTuple", data_type, gen(grop->in())));
1609 // global_work_buffer
1610 func_args.arg(genCall(
1611 "VolatilePtrTuple", data_type, "&" + varName(work_buffer) + "[0]"));
1612 // global_sync_buffer
1613 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
1614 // shared_buf
1615 func_args.arg(genCall(
1616 "PtrTuple",
1617 data_type,
1618 genCall("static_cast", ptrType(data_type), "shared_mem")));
1619 // read and write predicates
1620 TORCH_INTERNAL_ASSERT(
1621 grop->predicate() != nullptr && grop->predicate()->hasValue());
1622 const auto read_pred = genInline(grop->predicate());
1623 auto write_pred = read_pred;
1624 if (grop->writePredicate() != nullptr) {
1625 TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue());
1626 write_pred = genInline(grop->writePredicate());
1627 }
1628 func_args.arg(read_pred).arg(write_pred);
1629 // init_val
1630 func_args.arg(genCall("LocalTuple", data_type, genInline(grop->init())));
1631 // reduction_op
1632 func_args.arg(genReductionOp(op_type, out->dtype()));
1633
1634 addProfileArguments(func_args, grop);
1635
1636 indent() << kTab << func_args << ");\n";
1637 }
1638
1639 void handle(const kir::GroupedGridReduction* grouped_grop) final {
1640 const auto out = ir_utils::getTvOutput(grouped_grop);
1641 const auto domain = out->domain();
1642 TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
1643
1644 TORCH_INTERNAL_ASSERT(
1645 grouped_grop->sync_buffer()->buffer()->isA<TensorView>());
1646 const auto sync_buffer =
1647 grouped_grop->sync_buffer()->buffer()->as<TensorView>();
1648
1649 if (grouped_grop->isAllreduce()) {
1650 generateGroupedGridAllreduce(grouped_grop);
1651 return;
1652 }
1653
1654 TORCH_INTERNAL_ASSERT(
1655 grouped_grop->numExprs() == 2,
1656 "Only grouping of 2 reductions is supported. ",
1657 grouped_grop->toString());
1658
1659 const std::string flags_str = generateGridReduceTemplateFlags2(
1660 grouped_grop, grouped_grop->threadPredicate());
1661
1662 const bool persistent_sync =
1663 kernel_->summary().has_cooperative_grid_reduction;
1664
1665 // Since block-level reduction is already done, those dimensions
1666 // with tidx/y/z being true do not participate in the grid
1667 // reduction.
1668 ArgumentBuilder template_args;
1669 template_args.arg(flags_str).arg(persistent_sync);
1670
1671 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
1672
1673 // Append arguments for each reduction
1674 for (const auto i : c10::irange(grouped_grop->numExprs())) {
1675 TORCH_INTERNAL_ASSERT(
1676 grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
1677 const auto work_buffer =
1678 grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
1679
1680 func_args.arg(gen(grouped_grop->output(i)));
1681 func_args.arg(gen(grouped_grop->input(i)));
1682 func_args.arg(genCall(
1683 grouped_grop->output(i)->dtype(),
1684 genInline(grouped_grop->initVal(i))));
1685 func_args.arg(genReductionOp(
1686 grouped_grop->getReductionOpType(i),
1687 grouped_grop->output(i)->dtype()));
1688 func_args.arg("&").append(varName(work_buffer)).append("[0]");
1689 }
1690
1691 // The rest of the arguments are common between the reductions
1692 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
1693 func_args.arg("shared_mem");
1694 // read and write predicates
1695 TORCH_INTERNAL_ASSERT(
1696 grouped_grop->predicate() != nullptr &&
1697 grouped_grop->predicate()->hasValue());
1698 const auto read_pred = genInline(grouped_grop->predicate());
1699 func_args.arg(read_pred);
1700 if (grouped_grop->writePredicate() != nullptr) {
1701 TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
1702 func_args.arg(genInline(grouped_grop->writePredicate()));
1703 } else {
1704 func_args.arg(read_pred);
1705 }
1706
1707 func_args.arg(genInline(grouped_grop->entrance_index()));
1708 func_args.arg(genInline(grouped_grop->entrances()));
1709
1710 addProfileArguments(func_args, grouped_grop);
1711
1712 indent() << "reduction::gridReduceGroup<" << template_args << ">(\n";
1713 indent() << kTab << func_args << ");\n";
1714 }
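
  // For illustration only (names are hypothetical; the template flags come
  // from generateGridReduceTemplateFlags2 plus the persistent-sync flag):
  // grouping two float sum reductions emits roughly
  //
  //   reduction::gridReduceGroup</* reduction/iteration flags */, false>(
  //       T4[0], T2[i0], float(0), op0, &T_work0[0],
  //       T5[0], T3[i0], float(0), op1, &T_work1[0],
  //       &T_sync[0], shared_mem, read_pred, write_pred,
  //       i_entrance, n_entrances);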
1715
1716 void handle(const kir::GroupedGridWelford* grouped_gwop) final {
1717 if (grouped_gwop->isAllreduce()) {
1718 generateGroupedGridAllreduceWelford(grouped_gwop);
1719 return;
1720 } else {
1721 TORCH_INTERNAL_ASSERT(
1722 false, "Non-allreduce grouped grid welford is not yet supported");
1723 }
1724 }
1725
1726 // Enumerates all combinations of index values of grouped
1727 // loops. Each combination is a vector of loop index values. The
1728 // length of the vector is the number of grouped loops.
1729 //
1730 // Example 1: only one domain of extent 2 is grouped: {{0}, {1}}.
1731 // Example 2: two domains of extents 2 and 3 are grouped: {{0, 0},
1732 // {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}
1733 std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() {
1734     std::vector<std::vector<int64_t>> index_combinations;
1735
1736 // Initialize with an empty vector
1737     index_combinations.push_back(std::vector<int64_t>());
1738
1739 // Incrementally build a combinatorial set
1740 for (const auto loop : grouped_loops_) {
1741 const auto iter_count = loop->stop()->evaluateInt();
1742 std::vector<std::vector<int64_t>> new_combinations;
1743 // Append integers from 0 to iter_count to all the vectors built
1744 // so far
1745       for (const auto& index_vec : index_combinations) {
1746 for (int64_t i = 0; i < iter_count; ++i) {
1747 auto index_vec_appended = index_vec;
1748 index_vec_appended.push_back(i);
1749 new_combinations.push_back(index_vec_appended);
1750 }
1751 }
1752       index_combinations = std::move(new_combinations);
1753 }
1754
1755     return index_combinations;
1756 }
1757
1758 //! Returns all combinations of maps from index Vals of grouped loops to their
1759   //! concrete integers.
1760 std::vector<std::unordered_map<const Int*, int64_t>>
1761 getLoopIndexReplacementMaps() {
1762 std::vector<std::unordered_map<const Int*, int64_t>> maps;
1763
1764 if (grouped_loops_.empty()) {
1765 std::unordered_map<const Int*, int64_t> empty_map;
1766 return {empty_map};
1767 }
1768
1769 // Vector of indices of grouped loops
1770 std::vector<Int*> loop_indices;
1771 std::transform(
1772 grouped_loops_.begin(),
1773 grouped_loops_.end(),
1774 std::back_inserter(loop_indices),
1775 [](const kir::ForLoop* loop) { return loop->index()->as<Int>(); });
1776
1777 // All combinations of loop index integer values
1778 const auto index_val_sets = getGroupedLoopIndexConcreteIntSets();
1779
1780 // Create maps from loop index Vals to integers
1781 for (const auto& index_values : index_val_sets) {
1782 TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size());
1783 std::unordered_map<const Int*, int64_t> index_val_map;
1784 for (const auto i : c10::irange(loop_indices.size())) {
1785 auto loop_index = loop_indices.at(i);
1786 auto index_val = index_values.at(i);
1787 index_val_map.emplace(loop_index, index_val);
1788 }
1789 maps.emplace_back(std::move(index_val_map));
1790 }
1791
1792 return maps;
1793 }
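
  // Worked example (hypothetical): if two loops with indices i0 (extent 2)
  // and i1 (extent 3) are grouped, getGroupedLoopIndexConcreteIntSets()
  // returns {{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}, and this
  // function turns each vector into a map, e.g. the fourth combination
  // becomes {i0 -> 1, i1 -> 0}. Each map is then installed as
  // index_replacement_map_ while one copy of the grouped expression is
  // generated.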
1794
1795 void generateGroupedGridAllreduce(
1796 const kir::GroupedGridReduction* grouped_grop) {
1797 TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());
1798
1799 // There are two dimensions of grouping: horizontal grouping and
1800 // iteration grouping. The total number of individual reductions
1801 // is the number of horizontal reductions * the extent of grouped
1802 // iterations. All of them are packed into a single grid reduction
1803 // call. The number of reductions is limited, and currently it is
1804 // simply an error if exceeded. This could be avoided by
1805 // decomposing grouped_grop into smaller groups within the
1806 // limit. TODO: Support a larger number of reductions.
1807
1808     // First, enumerate all combinations of loop index values of the
1809     // grouped IterDomains. If only a single domain is grouped, this
1810     // is simply a 1D vector of integers from 0 to extent-1. If
1811     // two domains are grouped, combinations of two integer vectors
1812     // are returned. These loop index value vectors are returned as
1813     // maps from loop index Vals to concrete int values.
1814 const auto index_replacement_maps = getLoopIndexReplacementMaps();
1815 const auto num_grouped_iterations = index_replacement_maps.size();
1816
1817     // This is also checked at lowering validation time, so it
1818     // isn't strictly necessary here.
1819 TORCH_INTERNAL_ASSERT(
1820 num_grouped_iterations * grouped_grop->numExprs() <=
1821 kMaxNumGroupedReductions,
1822 "Too many grouped reductions: ",
1823 grouped_grop->toString(),
1824 ". Up to ",
1825 kMaxNumGroupedReductions,
1826 " reductions are allowed.");
1827
1828 ArgumentBuilder types;
1829 ArgumentBuilder outputs;
1830 ArgumentBuilder inputs;
1831 ArgumentBuilder work_bufs;
1832 ArgumentBuilder init_vals;
1833 ArgumentBuilder reduction_ops;
1834
1835 ArgumentBuilder bool_types;
1836 ArgumentBuilder read_preds;
1837 ArgumentBuilder write_preds;
1838
1839 for (const auto expr_index : c10::irange(grouped_grop->numExprs())) {
1840 const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
1841 TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers()
1842 .at(expr_index)
1843 ->buffer()
1844 ->isA<TensorView>());
1845
1846 for (const auto& group_index :
1847 c10::irange(index_replacement_maps.size())) {
1848 // Set the index replacement map with the concrete values of
1849 // indices of grouped loops.
1850 index_replacement_map_ = index_replacement_maps.at(group_index);
1851
1852 types.arg(data_type);
1853
1854 // out
1855 outputs.arg(gen(grouped_grop->outputs().at(expr_index)));
1856
1857 // inp
1858 inputs.arg(gen(grouped_grop->inputs().at(expr_index)));
1859
1860 // global_work_buffer
1861 const auto work_buffer = grouped_grop->reduction_buffers()
1862 .at(expr_index)
1863 ->buffer()
1864 ->as<TensorView>();
1865         // A separate work buffer is used for each reduction.
1866 auto work_buffer_offset = group_index == 0
1867 ? "0"
1868 : (genInline(grouped_grop->buffer_stride()) + " * " +
1869 std::to_string(group_index));
1870 work_bufs.arg("&")
1871 .append(varName(work_buffer))
1872 .append("[")
1873 .append(work_buffer_offset)
1874 .append("]");
1875 init_vals.arg(genInline(grouped_grop->initVal(expr_index)));
1876
1877 reduction_ops.arg(genReductionOp(
1878 grouped_grop->getReductionOpType(expr_index),
1879 grouped_grop->output(expr_index)->dtype()));
1880
1881 // read and write predicates
1882 bool_types.arg("bool");
1883 // Same argument for all inputs. Different predicates would be
1884 // used when grouping is done across iterations
1885 TORCH_INTERNAL_ASSERT(
1886 grouped_grop->predicate() != nullptr &&
1887 grouped_grop->predicate()->hasValue());
1888 const auto read_pred = genInline(grouped_grop->predicate());
1889 read_preds.arg(read_pred);
1890 if (grouped_grop->writePredicate() != nullptr) {
1891 TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
1892 write_preds.arg(genInline(grouped_grop->writePredicate()));
1893 } else {
1894 write_preds.arg(read_pred);
1895 }
1896
1897 index_replacement_map_.clear();
1898 }
1899 }
1900
1901 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
1902 func_args.arg(genCall("RefTuple", types, outputs));
1903 func_args.arg(genCall("ConstRefTuple", types, inputs));
1904 func_args.arg(genCall("VolatilePtrTuple", types, work_bufs));
1905 func_args.arg(genCall("LocalTuple", types, init_vals));
1906
1907 // global_sync_buffer
1908 const auto sync_buffer =
1909 grouped_grop->sync_buffer()->buffer()->as<TensorView>();
1910 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
1911
1912 // shared_buf
1913 func_args.arg("shared_mem");
1914
1915 func_args.arg(genCall("LocalTuple", bool_types, read_preds));
1916 func_args.arg(genCall("LocalTuple", bool_types, write_preds));
1917
1918 addProfileArguments(func_args, grouped_grop);
1919
1920 func_args.arg(reduction_ops);
1921
1922 indent() << genFusedReductionName(ir_utils::getTvOutput(grouped_grop))
1923 << ".reduceGroup(\n";
1924 indent() << kTab << func_args << ");\n";
1925 }
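
  // For illustration only (names are hypothetical): a single float sum
  // allreduce grouped over an iteration domain of extent 2 emits roughly
  //
  //   fused_reduction_0.reduceGroup(
  //       RefTuple<float, float>(T4[0], T4[1]),
  //       ConstRefTuple<float, float>(T2[0], T2[1]),
  //       VolatilePtrTuple<float, float>(&T_work[0], &T_work[stride * 1]),
  //       LocalTuple<float, float>(0.f, 0.f),
  //       &T_sync[0],
  //       shared_mem,
  //       LocalTuple<bool, bool>(read_pred, read_pred),
  //       LocalTuple<bool, bool>(write_pred, write_pred),
  //       op0, op1);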
1926
1927   // Mostly the same as the grouped grid reduction version.
1928 void generateGroupedGridAllreduceWelford(
1929 const kir::GroupedGridWelford* grouped_gwop) {
1930 TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());
1931
1932 const auto index_replacement_maps = getLoopIndexReplacementMaps();
1933 const auto num_grouped_iterations = index_replacement_maps.size();
1934
1935     // This is also checked at lowering validation time, so it
1936     // isn't strictly necessary here.
1937 TORCH_INTERNAL_ASSERT(
1938 num_grouped_iterations * grouped_gwop->numExprs() <=
1939 kMaxNumGroupedReductions,
1940 "Too many grouped reductions: ",
1941 grouped_gwop->toString(),
1942 ". Up to ",
1943 kMaxNumGroupedReductions,
1944 " reductions are allowed.");
1945
1946 ArgumentBuilder data_types;
1947 ArgumentBuilder index_types;
1948
1949     // Note that the data type of avg and var, and that of N, are the
1950     // same across all the grouped welford ops since we only support
1951     // grouping of iterations.
1952 const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
1953 const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();
1954
1955 std::array<ArgumentBuilder, 3> out_args;
1956 std::array<ArgumentBuilder, 3> in_args;
1957 std::array<ArgumentBuilder, 3> init_args;
1958 std::array<ArgumentBuilder, 3> work_bufs;
1959
1960 ArgumentBuilder bool_types;
1961 ArgumentBuilder read_preds;
1962 ArgumentBuilder write_preds;
1963
1964 for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
1965 const auto& output = grouped_gwop->outputVals().at(expr_index);
1966 const auto& input = grouped_gwop->inputVals().at(expr_index);
1967 const auto& init = grouped_gwop->initVals().at(expr_index);
1968
1969 for (const auto& group_index :
1970 c10::irange(index_replacement_maps.size())) {
1971 // Set the index replacement map with the concrete values of
1972 // indices of grouped loops.
1973 index_replacement_map_ = index_replacement_maps.at(group_index);
1974
1975 data_types.arg(data_type);
1976 index_types.arg(index_type);
1977
1978 auto work_buffer_offset = group_index == 0
1979 ? "0"
1980 : (genInline(grouped_gwop->buffer_stride()) + " * " +
1981 std::to_string(group_index));
1982
1983 // Setup arguments for avg, var, and N
1984 for (const auto i : c10::irange(3)) {
1985 out_args[i].arg(gen(output.get(i)));
1986 in_args[i].arg(gen(input.get(i)));
1987 init_args[i].arg(gen(init.get(i)));
1988 const auto work_buffer = grouped_gwop->reduction_buffers()[i]
1989 .at(expr_index)
1990 ->buffer()
1991 ->as<TensorView>();
1992 work_bufs[i]
1993 .arg("&")
1994 .append(varName(work_buffer))
1995 .append("[")
1996 .append(work_buffer_offset)
1997 .append("]");
1998 }
1999
2000 // read and write predicates
2001 bool_types.arg("bool");
2002 // Same argument for all inputs. Different predicates would be
2003 // used when grouping is done across iterations
2005 TORCH_INTERNAL_ASSERT(
2006 grouped_gwop->predicate() != nullptr &&
2007 grouped_gwop->predicate()->hasValue());
2008 const auto read_pred = genInline(grouped_gwop->predicate());
2009 read_preds.arg(read_pred);
2010 if (grouped_gwop->writePredicate() != nullptr) {
2011 TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
2012 write_preds.arg(genInline(grouped_gwop->writePredicate()));
2013 } else {
2014 write_preds.arg(read_pred);
2015 }
2016
2017 index_replacement_map_.clear();
2018 }
2019 }
2020
2021 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
2022 // output
2023 func_args.arg(genCall("RefTuple", data_types, out_args[0]));
2024 func_args.arg(genCall("RefTuple", data_types, out_args[1]));
2025 func_args.arg(genCall("RefTuple", index_types, out_args[2]));
2026 // input
2027 func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
2028 func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
2029 func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
2030 // init
2031 func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
2032 func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
2033 func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
2034 // work buffer
2035 func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
2036 func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
2037 func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
2038 // global_sync_buffer
2039 const auto sync_buffer =
2040 grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
2041 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
2042
2043 // shared_buf
2044 ArgumentBuilder smem_buffer_args;
2045 smem_buffer_args.arg(
2046 genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
2047 smem_buffer_args.arg(
2048 genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
2049 smem_buffer_args.arg(
2050 genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
2051 func_args.arg(genCall(
2052 "PtrTuple",
2053 ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
2054 smem_buffer_args));
2055
2056 func_args.arg(genCall("LocalTuple", bool_types, read_preds));
2057 func_args.arg(genCall("LocalTuple", bool_types, write_preds));
2058
2059 addProfileArguments(func_args, grouped_gwop);
2060
2061 ArgumentBuilder func_template_args;
2062 func_template_args.arg(
2063 grouped_gwop->numExprs() * index_replacement_maps.size());
2064 func_template_args.arg(data_type);
2065 func_template_args.arg(index_type);
2066
2067 indent() << genCall(
2068 genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
2069 ".welfordGroup",
2070 func_template_args,
2071 func_args)
2072 << ";\n";
2073 }
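
  // For illustration only (names are hypothetical): one grouped welford op
  // over two grouped iterations, with float data and nvfuser_index_t counts,
  // emits roughly
  //
  //   fused_reduction_0.welfordGroup<2, float, nvfuser_index_t>(
  //       RefTuple<float, float>(avg0, avg1),
  //       RefTuple<float, float>(var0, var1),
  //       RefTuple<nvfuser_index_t, nvfuser_index_t>(n0, n1),
  //       ConstRefTuple<float, float>(in_avg0, in_avg1),
  //       ConstRefTuple<float, float>(in_var0, in_var1),
  //       ConstRefTuple<nvfuser_index_t, nvfuser_index_t>(in_n0, in_n1),
  //       LocalTuple<float, float>(0.f, 0.f),
  //       LocalTuple<float, float>(0.f, 0.f),
  //       LocalTuple<nvfuser_index_t, nvfuser_index_t>(0, 0),
  //       VolatilePtrTuple<float, float>(&T_wa[0], &T_wa[stride * 1]),
  //       VolatilePtrTuple<float, float>(&T_wv[0], &T_wv[stride * 1]),
  //       VolatilePtrTuple<nvfuser_index_t, nvfuser_index_t>(
  //           &T_wn[0], &T_wn[stride * 1]),
  //       &T_sync[0],
  //       PtrTuple<float, float, nvfuser_index_t>(
  //           reinterpret_cast<float*>(shared_mem_avg),
  //           reinterpret_cast<float*>(shared_mem_var),
  //           reinterpret_cast<nvfuser_index_t*>(shared_mem_n)),
  //       LocalTuple<bool, bool>(read_pred, read_pred),
  //       LocalTuple<bool, bool>(write_pred, write_pred));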
2074
2075 void handle(const kir::GridBroadcast* grop) final {
2076 const auto bop = grop->broadcast_op();
2077 TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
2078
2079 const ParallelTypeBitmap parallel_types =
2080 kernel_->summary().broadcast_parallel_types.at(bop);
2081
2082 TORCH_INTERNAL_ASSERT(
2083 parallel_types.hasBID(),
2084 "GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types");
2085
2086 TORCH_INTERNAL_ASSERT(
2087 grop->broadcast_buffer()->buffer()->isA<TensorView>());
2088 TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
2089 const auto work_buffer =
2090 grop->broadcast_buffer()->buffer()->as<TensorView>();
2091 const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();
2092
2093 std::stringstream flags_str;
2094 for (const ParallelType pt : kParallelTypeThreads) {
2095 const bool parallel_bcast = parallel_types.get(pt);
2096 if (pt != kParallelTypeThreads[0]) {
2097 flags_str << ", ";
2098 }
2099 flags_str << (parallel_bcast ? "true" : "false");
2100 }
2101
2102     // Since a block-level broadcast has not necessarily been performed
2103     // before this function call, the grid broadcast may be broadcasting
2104     // across both the grid and the block level.
2105 indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n";
2106 indent() << kTab << gen(bop->out()) << ",\n";
2107 indent() << kTab << gen(bop->in()) << ",\n";
2108 indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
2109 indent() << kTab << varName(sync_buffer) << ",\n";
2110 TORCH_INTERNAL_ASSERT(
2111 grop->predicate() != nullptr && grop->predicate()->hasValue());
2112 indent() << kTab << genInline(grop->predicate()) << ");\n";
2113 }
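
  // For illustration only (one boolean per parallel type in
  // kParallelTypeThreads; the concrete values and names are hypothetical):
  // the emitted call looks roughly like
  //
  //   grid_broadcast::broadcast<false, true, false, false, false, false>(
  //       T6[0],
  //       T5[0],
  //       &T_bcast_work[0],
  //       T_bcast_sync,
  //       pred);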
2114
2115 void handle(const kir::GridWelford* gwop) final {
2116 const auto wop = gwop->welford_op();
2117 TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>());
2118
2119 const auto out = wop->out()->as<kir::TensorIndex>();
2120 const auto domain = out->view()->domain();
2121 TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
2122
2123 const auto data_type = out->dtype();
2124
2125 TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA<TensorView>());
2126 TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA<TensorView>());
2127
2128 const auto avg_buffer = gwop->avg_buffer()->buffer()->as<TensorView>();
2129 const auto var_buffer = gwop->var_buffer()->buffer()->as<TensorView>();
2130 const auto n_buffer = gwop->N_buffer()->buffer()->as<TensorView>();
2131 const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>();
2132
2133 if (wop->isAllreduce()) {
2134 generateGridAllreduce(gwop);
2135 return;
2136 }
2137
2138 const bool persistent_sync =
2139 kernel_->summary().has_cooperative_grid_reduction;
2140
2141 const std::string flags_str =
2142 generateGridReduceTemplateFlags(wop, gwop->threadPredicate());
2143
2144 // Since block-level reduction is already done, those dimensions
2145 // with tidx/y/z being true do not participate in the grid reduction.
2146 indent() << "welford::gridWelford<" << flags_str << ", "
2147 << (persistent_sync ? "true" : "false") << ">(\n";
2148 indent() << kTab << gen(wop->outAvg()) << ",\n";
2149 indent() << kTab << gen(wop->outVar()) << ",\n";
2150 indent() << kTab << gen(wop->outN()) << ",\n";
2151 if (domain->hasBlockReduction()) {
2152 indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n";
2153 indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n";
2154 indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
2155 block_reduce_name_++;
2156 } else {
2157 indent() << kTab << gen(wop->inAvg()) << ",\n";
2158 TORCH_INTERNAL_ASSERT(
2159 wop->inVar() != nullptr, "Welford var input nullptr not allowed");
2160 indent() << kTab << "(" << wop->outVar()->dtype() << ")"
2161 << gen(wop->inVar()) << ",\n";
2162 indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN())
2163 << ",\n";
2164 }
2165 indent() << kTab << "&" << varName(avg_buffer) << "[0],\n";
2166 indent() << kTab << "&" << varName(var_buffer) << "[0],\n";
2167 indent() << kTab << "&" << varName(n_buffer) << "[0],\n";
2168 indent() << kTab << varName(sync_buffer) << ",\n";
2169 indent() << kTab << "reinterpret_cast<" << data_type
2170 << "*>(shared_mem_avg),\n";
2171 indent() << kTab << "reinterpret_cast<" << data_type
2172 << "*>(shared_mem_var),\n";
2173 indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype()
2174 << "*>(shared_mem_n),\n";
2175 TORCH_INTERNAL_ASSERT(
2176 gwop->predicate() != nullptr && gwop->predicate()->hasValue());
2177 auto read_pred = genInline(gwop->predicate());
2178 indent() << kTab << read_pred << ",\n";
2179 if (gwop->writePredicate() != nullptr) {
2180 TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue());
2181 auto write_pred = genInline(gwop->writePredicate());
2182 indent() << kTab << write_pred << ",\n";
2183 } else {
2184 indent() << kTab << read_pred << ",\n";
2185 }
2186 // TODO : init value support or remove.
2187 indent() << kTab << data_type << "(0),\n";
2188 indent() << kTab << genInline(gwop->entrance_index()) << ",\n";
2189 indent() << kTab << genInline(gwop->entrances());
2190 code_ << ");\n";
2191 }
2192
2193 void generateGridAllreduce(const kir::GridWelford* gwop) {
2194 const auto wop = gwop->welford_op();
2195 TORCH_INTERNAL_ASSERT(wop->isAllreduce());
2196
2197 const auto out = wop->out()->as<kir::TensorIndex>();
2198
2199 const auto data_type = wop->outAvg()->dtype();
2200 const auto index_type = wop->outN()->dtype();
2201 TORCH_INTERNAL_ASSERT(wop->outAvg()->dtype() == wop->outVar()->dtype());
2202
2203 ArgumentBuilder data_type_args;
2204 data_type_args.arg(data_type).arg(data_type).arg(index_type);
2205
2206 const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>();
2207
2208 const auto reduction_name = genFusedReductionName(out->view());
2209
2210 // template <typename Func, typename... Types>
2211 // __device__ __inline__ void reduce(
2212 // RefTuple<Types...> out,
2213 // const LocalTuple<Types...>& inp,
2214 // VolatilePtrTuple<Types...> global_work_buffer,
2215 // int64_t* global_sync_buffer, // Allocated as product of all
2216 // // non-participating Grid dimension
2217 // PtrTuple<Types...> shared_buf,
2218 // bool read_pred, // Prevent reading from out of bounds memory
2219 // bool write_pred, // Prevent from writing out of bounds
2220 // const LocalTuple<Types...>& init_val,
2221 // Func reduction_op);
2222
2223 ArgumentBuilder out_args;
2224 out_args.arg(gen(wop->outAvg()));
2225 out_args.arg(gen(wop->outVar()));
2226 out_args.arg(gen(wop->outN()));
2227
2228 ArgumentBuilder in_args;
2229 in_args.arg(gen(wop->inAvg()));
2230 if (wop->inVar() != nullptr) {
2231 in_args.arg(gen(wop->inVar()));
2232 } else {
2233 in_args.arg("(").append(data_type).append(")0");
2234 }
2235 in_args.arg(gen(wop->inN()));
2236
2237 ArgumentBuilder init_args;
2238 init_args.arg(gen(wop->initAvg()));
2239 init_args.arg(gen(wop->initVar()));
2240 init_args.arg(gen(wop->initN()));
2241
2242 ArgumentBuilder work_buffer_args;
2243 work_buffer_args.arg("&")
2244 .append(varName(gwop->avg_buffer()->buffer()->as<TensorView>()))
2245 .append("[0]");
2246 work_buffer_args.arg("&")
2247 .append(varName(gwop->var_buffer()->buffer()->as<TensorView>()))
2248 .append("[0]");
2249 work_buffer_args.arg("&")
2250 .append(varName(gwop->N_buffer()->buffer()->as<TensorView>()))
2251 .append("[0]");
2252
2253 ArgumentBuilder smem_buffer_args;
2254 smem_buffer_args.arg(
2255 genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
2256 smem_buffer_args.arg(
2257 genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
2258 smem_buffer_args.arg(
2259 genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
2260
2261 ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
2262 // out
2263 func_args.arg(genCall("RefTuple", data_type_args, out_args));
2264 // inp
2265 func_args.arg(genCall("ConstRefTuple", data_type_args, in_args));
2266 // global_work_buffer
2267 func_args.arg(
2268 genCall("VolatilePtrTuple", data_type_args, work_buffer_args));
2269 // global_sync_buffer
2270 func_args.arg("&").append(varName(sync_buffer)).append("[0]");
2271 // shared_buf
2272 func_args.arg(genCall("PtrTuple", data_type_args, smem_buffer_args));
2273 // read and write predicates
2274 TORCH_INTERNAL_ASSERT(
2275 gwop->predicate() != nullptr && gwop->predicate()->hasValue());
2276 const auto read_pred = genInline(gwop->predicate());
2277 auto write_pred = read_pred;
2278 if (gwop->writePredicate() != nullptr) {
2279 TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue());
2280 write_pred = genInline(gwop->writePredicate());
2281 }
2282 func_args.arg(read_pred).arg(write_pred);
2283 // init_val
2284 func_args.arg(genCall("LocalTuple", data_type_args, init_args));
2285 // reduction_op
2286 func_args.arg(genTemplate(
2287 "welfordCombine", ArgumentBuilder().arg(data_type).arg(index_type)));
2288
2289 indent() << reduction_name << ".reduce(\n";
2290 indent() << kTab << func_args << ");\n";
2291 }
2292
2293 void handle(const kir::AllocateFusedReduction* alloc_fused_reduction) final {
2294 // See the runtime file of the fused reduction
2295 enum class ReductionParallelTypeState { Reduce, Iter, Pred, Inactive };
2296
2297 using ReductionParallelTypeStateArray =
2298 ParallelTypeMap<ReductionParallelTypeState>;
2299
2300 ReductionParallelTypeStateArray states(
2301 ReductionParallelTypeState::Inactive);
2302
2303 for (const ParallelType pt : kParallelTypeThreads) {
2304       // It may be better to predicate grid reductions on dimensions they
2305       // don't actively use; however, since that should generally be
2306       // discouraged (such dimensions should be part of the iter portion of
2307       // the operation, or they should be predicated out), we simply assume
2308       // they're part of the iter dimension. This causes more communication
2309       // than strictly necessary, but it should not be a common use case.
2310 auto pt_dim = kernel_->summary().parallel_dimension_map_.get(pt);
2311 if (pt_dim == nullptr || pt_dim->isOneInt()) {
2312 continue;
2313 }
2314 // Initialize pt_dim if used to an iter dimension. It may change to a
2315 // reduction or predicated dimension later.
2316 states[pt] = ReductionParallelTypeState::Iter;
2317 }
2318
2319 for (auto id : alloc_fused_reduction->out()->view()->domain()->domain()) {
2320 auto pt = id->getParallelType();
2321 if (isParallelTypeThread(pt)) {
2322 auto state = id->isReduction() ? ReductionParallelTypeState::Reduce
2323 : ReductionParallelTypeState::Iter;
2324 states[pt] = state;
2325 }
2326 }
2327
2328 for (const auto predicated_pt : alloc_fused_reduction->threadPredicate()) {
2329 auto& state = states[predicated_pt];
2330 TORCH_INTERNAL_ASSERT(
2331 state != ReductionParallelTypeState::Reduce,
2332 "Invalid thread predication: ",
2333 predicated_pt);
2334 state = ReductionParallelTypeState::Pred;
2335 }
2336
2337 ArgumentBuilder flags;
2338 for (auto pt : kParallelTypeThreads) {
2339 flags.arg(static_cast<int>(states[pt]));
2340 }
2341
2342 // Persistent
2343 flags.arg(true);
2344
2345 // Broadcast is fused
2346 flags.arg(true);
2347
2348 const auto reduction_name =
2349 genFusedReductionName(alloc_fused_reduction->out()->view());
2350
2351 indent() << genTemplate("fused_reduction::ParallelReduce", flags) << " "
2352 << reduction_name << ";\n";
2353 }
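
  // For illustration only: the six leading template arguments are the
  // ReductionParallelTypeState values (cast to int) in kParallelTypeThreads
  // order, followed by the persistent and fused-broadcast flags. A
  // hypothetical instance could be declared as
  //
  //   fused_reduction::ParallelReduce<0, 3, 3, 1, 3, 3, true, true>
  //       fused_reduction_0;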
2354
2355 void handleScope(const kir::Scope& scope) {
2356 for (auto expr : scope.exprs()) {
2357 OptOutConstDispatch::handle(expr);
2358 }
2359 }
2360
2361 void handleTrivialLoop(const kir::ForLoop* loop) {
2362 if (loop->vectorize()) {
2363 vectorize_scope_ = true;
2364 }
2365 handleScope(loop->body());
2366 if (loop->vectorize()) {
2367 vectorize_scope_ = false;
2368 }
2369 }
2370
2371 void handle(const GroupedReductionOp* grouped_rop) final {
2372 for (const auto i : c10::irange(grouped_rop->numExprs())) {
2373 TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>());
2374
2375 const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
2376 const auto input = grouped_rop->input(i)->as<kir::TensorIndex>();
2377 const auto domain = output->view()->domain();
2378 const auto op_type = grouped_rop->getReductionOpType(i);
2379
2380 const bool has_block_reduce = domain->hasBlockReduction();
2381 const bool has_grid_reduce = domain->hasGridReduction();
2382
2383 TORCH_INTERNAL_ASSERT(
2384 !has_grid_reduce,
2385 "GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ",
2386 grouped_rop->toString());
2387
2388 if (!has_block_reduce) {
2389 genSerialReduction(output, input, op_type);
2390 } else if (
2391 auto reduction_id =
2392 ir_utils::getMaybeWarpReductionDim(output, input)) {
2393 genWarpReduction(
2394 output,
2395 input,
2396 grouped_rop->initVal(i),
2397 op_type,
2398 grouped_rop->predicate());
2399 } else {
2400 genBlockReduction(
2401 output,
2402 input,
2403 grouped_rop->initVal(i),
2404 op_type,
2405 grouped_rop->predicate(),
2406 grouped_rop->writePredicate());
2407 }
2408 }
2409 }
2410
2411 void handle(const GroupedWelfordOp* grouped_wop) final {
2412 TORCH_INTERNAL_ASSERT(
2413 false,
2414 "Should not reach here as grouped welford is only enabled for grid welford,",
2415 " which is handled by its own handler");
2416 }
2417
2418   //! True if the loop is grouped. The IterDomain of the loop must have
2419   //! ParallelType::Group, but that alone isn't sufficient as the loop may be
2420   //! for an initialization expression, for which the loop should not be
2421   //! grouped. Make sure a GroupedGridReduction or GroupedGridWelford is found.
2422 bool isGroupedLoop(const kir::ForLoop* loop) {
2423 if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
2424 return false;
2425 }
2426 return ExprFinder::exists(
2427 loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
2428 }
2429
2430 void handle(const kir::ForLoop* loop) final {
2431 if (loop->isTrivial()) {
2432 handleTrivialLoop(loop);
2433 return;
2434 }
2435
2436 // If a loop is grouped, no loop is created, but it isn't
2437 // considered trivial as the loop trip count is not one.
2438 if (isGroupedLoop(loop)) {
2439 grouped_loops_.push_back(loop);
2440 handleScope(loop->body());
2441 grouped_loops_.pop_back();
2442 return;
2443 }
2444
2445 const auto gen_index = gen(loop->index());
2446 const auto gen_start = genInline(loop->start());
2447 const auto gen_stop = genInline(loop->stop());
2448 const auto gen_step = genInline(loop->step());
2449
2450 std::stringstream step_code;
2451 if (loop->step()->isOneInt()) {
2452 step_code << "++" << gen_index;
2453 } else {
2454 step_code << gen_index << " += " << gen_step;
2455 }
2456 if (loop->isUnrolled()) {
2457 indent() << "#pragma unroll\n";
2458 } else {
2459 indent() << "#pragma unroll 1\n";
2460 }
2461
2462 indent() << "for(nvfuser_index_t " << gen_index;
2463 if (loop->iter_domain()->isParallelized()) {
2464 code_ << " = " << gen_start << "; ";
2465 } else {
2466       // Do not start at the start of the ID when not parallelized. Instead,
2467       // start at 0. Predicates will protect buffers between 0 and ID->start();
2468       // however, if we started at ID->start() and extent == ID->start(), we
2469       // could have a "degenerate" loop (a loop with no iterations). It may not
2470       // be an issue to have a 0-sized loop, but all potential consequences
2471       // haven't been covered. One example is WAR analysis, which could
2472       // incorrectly think a barrier inside a 0-sized loop provides protection.
2473 code_ << " = 0; ";
2474 }
2475 code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") ";
2476 startBlock(true);
2477 handleScope(loop->body());
2478 endBlock();
2479 }
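
  // For illustration only (index and extent names are hypothetical): a
  // non-parallelized, non-unrolled loop with unit step is emitted roughly as
  //
  //   #pragma unroll 1
  //   for(nvfuser_index_t i7 = 0; i7 < T0.size[0]; ++i7) {
  //     ...
  //   }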
2480
2481 void handle(const kir::IfThenElse* ite) final {
2482 auto conditional = ite->predicate()->value();
2483 if (conditional->isConst()) {
2484 // If the conditional is a constant, then the IfThenElse is not required
2485 if (conditional->value().value()) {
2486 handleScope(ite->thenBody());
2487 } else {
2488 handleScope(ite->elseBody());
2489 }
2490 return;
2491 }
2492
2493 indent() << "if (" << genInline(conditional) << ") ";
2494
2495 // "then" block
2496 startBlock(true);
2497 handleScope(ite->thenBody());
2498
2499 // "else" block (optional)
2500 if (ite->hasElse()) {
2501 endBlock(" else ");
2502 startBlock(true);
2503 handleScope(ite->elseBody());
2504 }
2505
2506 endBlock();
2507 }
2508
2509 void handle(const kir::Allocate* alloc) final {
2510 const auto buffer_dtype = alloc->buffer()->dtype();
2511
2512 TORCH_INTERNAL_ASSERT(alloc->buffer() != nullptr);
2513 alloc_map_.emplace(alloc->buffer(), alloc);
2514
2515 if (!alloc->buffer()->isA<TensorView>()) {
2516 indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n";
2517 return;
2518 }
2519
2520 const auto tv = alloc->buffer()->as<TensorView>();
2521
2522 const auto size = alloc->size();
2523 TORCH_INTERNAL_ASSERT(size != nullptr);
2524
2525 if (alloc->alias() != nullptr) {
2526       // This Allocate aliases another Allocate stmt
2527 const auto alias_tv = alloc->alias()->buffer()->as<TensorView>();
2528 indent() << "// Alias Allocation - " << alloc->memoryType() << "\n";
2529 indent() << "auto& " << varName(tv) << " = " << varName(alias_tv)
2530 << ";\n";
2531
2532 } else {
2533 // Standard Memory Allocation
2534 switch (tv->getMemoryType()) {
2535 case MemoryType::Global:
2536 indent() << "// Allocate global tensor " << varName(tv) << "\n";
2537 break;
2538 case MemoryType::Shared:
2539 // Align Offset Position
2540 indent() << "smem_offset = alignBufferSize(smem_offset, "
2541 // Always align to 128b / 16B
2542 << 16 << ");\n";
2543 // Shared Memory Pointer
2544 indent() << buffer_dtype << "* " << varName(tv)
2545 << " = reinterpret_cast<" << buffer_dtype << "*>"
2546 << "(array + smem_offset);\n";
2547 // Increment Offset Position
2548 indent() << "smem_offset += (" << genInline(size) << " * sizeof("
2549 << buffer_dtype << "));\n";
2550 break;
2551 case MemoryType::Local: {
2552 auto va = kernel_->summary().vectorized_accesses;
2553 if (va.find(tv) != va.end()) {
2554 indent() << "Array<" << buffer_dtype << ", " << genInline(size)
2555 << ", " << va.at(tv) << "> " << varName(tv) << ";\n";
2556 } else {
2557 indent() << buffer_dtype << " " << varName(tv) << "["
2558 << genInline(size) << "];\n";
2559 }
2560 } break;
2561 default:
2562 TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
2563 }
2564 }
2565 }
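
  // For illustration only (tensor name and sizes are hypothetical): for a
  // float tensor T2 with 128 elements, the non-alias paths emit roughly
  //
  //   // shared memory:
  //   smem_offset = alignBufferSize(smem_offset, 16);
  //   float* T2 = reinterpret_cast<float*>(array + smem_offset);
  //   smem_offset += (128 * sizeof(float));
  //
  //   // local memory, vectorized by 4:
  //   Array<float, 128, 4> T2;
  //
  //   // local memory, not vectorized:
  //   float T2[128];
  //
  // For global tensors only an annotation comment is emitted here.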
2566
2567 void handle(const kir::BlockSync* sync) final {
2568 // Use a custom synchronization method if enabled
2569 if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
2570 indent() << "block_sync::sync();\n";
2571 } else {
2572 indent() << "__barrier_sync(0);\n";
2573 }
2574 }
2575
2576 void handle(const kir::CpAsyncWait* cpasync_wait) final {
2577 if (cpasync_wait->keepStages() > 0) {
2578 // Perform partial sync, see comment on kir::CpAsyncWait.
2579 indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages()
2580 << ">();\n";
2581 } else {
2582 // Perform sync all, see comment on kir::CpAsyncWait.
2583 indent() << "Ampere::cpAsyncBarrier();\n";
2584 }
2585 }
2586
2587   void handle(const kir::CpAsyncCommit* cpasync_commit) final {
2588 // Commit inflight cp.async transfers. See comment on kir::CpAsyncCommit.
2589 indent() << "Ampere::cpAsyncCommit();\n";
2590 }
2591
2592 void handle(const kir::GridSync* sync) final {
2593 // Use a custom synchronization method if enabled
2594 bool bidx = sync->syncDims().get(ParallelType::BIDx);
2595 bool bidy = sync->syncDims().get(ParallelType::BIDy);
2596 bool bidz = sync->syncDims().get(ParallelType::BIDz);
2597
2598 ArgumentBuilder sync_call_template_parms;
2599 sync_call_template_parms.arg(bidx).arg(bidy).arg(bidz).arg(true);
2600
2601 auto sync_idx = genCall(
2602 "index_utils::maskedOffset",
2603 ArgumentBuilder().arg(!bidx).arg(!bidy).arg(!bidz),
2604 ArgumentBuilder().arg("blockIdx").arg("gridDim"));
2605
2606 auto sync_segment_size = genCall(
2607 "index_utils::maskedSize",
2608 ArgumentBuilder().arg(bidx).arg(bidy).arg(bidz),
2609 ArgumentBuilder().arg("gridDim"));
2610
2611 ArgumentBuilder sync_call_args;
2612 sync_call_args.arg(varName(sync->syncBuffer()))
2613 .append("[")
2614 .append(sync_idx)
2615 .append("]");
2616 sync_call_args.arg(sync_segment_size);
2617
2618 auto sync_call =
2619 genCall("grid_sync::sync", sync_call_template_parms, sync_call_args);
2620
2621 indent() << sync_call << ";\n";
2622 }
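
  // For illustration only (buffer name is hypothetical): a grid sync over
  // BIDx only is emitted roughly as
  //
  //   grid_sync::sync<true, false, false, true>(
  //       T_sync[index_utils::maskedOffset<false, true, true>(blockIdx, gridDim)],
  //       index_utils::maskedSize<true, false, false>(gridDim));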
2623
2624 void handle(const kir::InitMagicZero*) final {
2625 indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
2626 }
2627
2628 void handle(const kir::UpdateMagicZero*) final {
2629 indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
2630 }
2631
2632 void handle(const kir::Swizzle2DInt* swizzle_2d) final {
2633 TORCH_INTERNAL_ASSERT(print_inline_);
2634 TORCH_INTERNAL_ASSERT(
2635 swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle,
2636 "Swizzle type undefined.");
2637 if (print_inline_) {
2638 code_ << swizzle_2d->swizzleType() << "({" << gen(swizzle_2d->inX())
2639 << "," << gen(swizzle_2d->inY()) << "} , "
2640 << "{" << gen(swizzle_2d->extentX()) << ","
2641 << gen(swizzle_2d->extentY()) << "})";
2642 }
2643 }
2644
2645 void handle(const kir::IntPair* int_pair) final {
2646 const auto def = int_pair->definition();
2647 TORCH_INTERNAL_ASSERT(
2648 def != nullptr, "no support for un-inlined int pair yet.");
2649 code_ << gen(def);
2650 }
2651
2652 void handle(const kir::PairSelect* pair_select) final {
2653 if (print_inline_) {
2654 code_ << gen(pair_select->in());
2655 } else {
2656 indent() << gen(pair_select->out()) << " = " << gen(pair_select->in());
2657 }
2658
2659 switch (pair_select->selection()) {
2660 case kir::PairSelect::Selection::X:
2661 code_ << ".x";
2662 break;
2663 case kir::PairSelect::Selection::Y:
2664 code_ << ".y";
2665 break;
2666 default:
2667         TORCH_INTERNAL_ASSERT(false, "unknown select");
2668 break;
2669 }
2670
2671 if (!print_inline_) {
2672 code_ << ";\n";
2673 }
2674 }
2675
2676 private:
2677 std::stringstream code_;
2678 const kir::Kernel* kernel_;
2679 int block_nest_level_ = 0;
2680 int block_reduce_name_ = 0;
2681 bool print_inline_ = false;
2682
2683 // Mark when we are inside of a vectorized for-loop
2684 bool vectorize_scope_ = false;
2685 //! Keep track of Allocate node for Val. Used to determine if Val
2686 //! should be inlined.
2687 std::unordered_map<const Val*, const kir::Allocate*> alloc_map_;
2688 //! Keep track of grouped loops
2689 std::deque<const kir::ForLoop*> grouped_loops_;
2690 //! Used to replace symbolic indices with concrete values
2691 std::unordered_map<const Int*, int64_t> index_replacement_map_;
2692};
2693
2694} // namespace
2695
2696std::string generateCudaKernel(
2697 const kir::Kernel* kernel,
2698 const std::string& kernel_name) {
2699 FUSER_PERF_SCOPE("generateCudaKernel");
2700 return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
2701}
2702
2703} // namespace codegen
2704} // namespace cuda
2705} // namespace fuser
2706} // namespace jit
2707} // namespace torch
2708