1 | #include <codegen.h> |
2 | #include <expr_evaluator.h> |
3 | #include <instrumentation.h> |
4 | #include <kernel_expr_evaluator.h> |
5 | #include <kernel_ir.h> |
6 | #include <kernel_ir_dispatch.h> |
7 | #include <scheduler/mma_utils.h> |
8 | #include <type.h> |
9 | #include <utils.h> |
10 | |
11 | #include <array> |
12 | #include <cmath> |
13 | #include <sstream> |
14 | #include <vector> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | namespace fuser { |
19 | namespace cuda { |
20 | namespace codegen { |
21 | |
22 | namespace { |
23 | |
//! Returns the string form of a pointer to the given data type,
//! e.g. "float*"
std::string ptrType(DataType dt) {
25 | std::stringstream ss; |
26 | ss << dt << "*" ; |
27 | return ss.str(); |
28 | } |
29 | |
30 | //! Utility class to build an argument list |
31 | class ArgumentBuilder { |
32 | public: |
33 | //! Build an argument list where each argument is separated with a comma |
34 | ArgumentBuilder() = default; |
35 | |
36 | //! Build an argument list where each argument has its own line |
37 | ArgumentBuilder(int indent_level, const char* tab) { |
38 | std::stringstream ss; |
39 | for (const auto i : c10::irange(indent_level)) { |
40 | (void)i; // Suppress unused variable warning |
41 | ss << tab; |
42 | } |
43 | sep_ = ",\n" + ss.str(); |
44 | } |
45 | |
46 | //! Add a new argument |
47 | template <typename T> |
48 | ArgumentBuilder& arg(const T& x) { |
49 | addSeparator(); |
50 | return append(x); |
51 | } |
52 | |
53 | //! Append to the last argument |
54 | template <typename T> |
55 | ArgumentBuilder& append(const T& arg) { |
56 | ss_ << arg; |
57 | return *this; |
58 | } |
59 | |
60 | //! Get a string of the argument list |
61 | std::string str() const { |
62 | return ss_.str(); |
63 | } |
64 | |
65 | friend std::ostream& operator<<(std::ostream& os, const ArgumentBuilder& ab) { |
66 | return os << ab.str(); |
67 | } |
68 | |
69 | private: |
70 | void addSeparator() { |
71 | if (ss_.tellp() != 0) { |
72 | ss_ << sep_; |
73 | } |
74 | } |
75 | |
76 | private: |
77 | std::string sep_ = ", " ; |
78 | std::stringstream ss_; |
79 | }; |
80 | |
//! Specialization of append for bool so that values are printed as
//! "true"/"false" rather than 1/0
82 | template <> |
83 | ArgumentBuilder& ArgumentBuilder::append<bool>(const bool& arg) { |
84 | ss_ << (arg ? "true" : "false" ); |
85 | return *this; |
86 | } |
87 | |
88 | //! Returns "template_name<template_arg>" |
89 | template <typename TemplateNameT, typename TemplateArgT> |
90 | std::string genTemplate( |
91 | const TemplateNameT& template_name, |
92 | const TemplateArgT& template_arg) { |
93 | std::stringstream ss; |
94 | ss << template_name << "<" << template_arg << ">" ; |
95 | return ss.str(); |
96 | } |
97 | |
98 | //! Returns "func_name(func_arg)" |
99 | template <typename FuncNameT, typename FuncArgT> |
100 | std::string genCall(const FuncNameT& func_name, const FuncArgT& func_arg) { |
101 | std::stringstream ss; |
102 | ss << func_name << "(" << func_arg << ")" ; |
103 | return ss.str(); |
104 | } |
105 | |
106 | //! Returns "func_name<template_arg>(func_arg)" |
107 | template <typename FuncNameT, typename TemplateArgT, typename FuncArgT> |
108 | std::string genCall( |
109 | const FuncNameT& func_name, |
110 | const TemplateArgT& template_arg, |
111 | const FuncArgT& func_arg) { |
112 | std::stringstream ss; |
113 | ss << func_name << "<" << template_arg << ">(" << func_arg << ")" ; |
114 | return ss.str(); |
115 | } |
116 | |
117 | //! A utility class to check if an expression of a particular type exists |
118 | class ExprFinder : kir::ConstIrVisitor { |
119 | public: |
//! True if the type of expr, or of any of its nested expressions, is
//! included in expr_types
122 | static bool exists( |
123 | const Expr* expr, |
124 | const std::unordered_set<ExprType>& expr_types) { |
125 | ExprFinder finder(expr_types); |
126 | finder.handle(std::vector<const Expr*>{expr}); |
127 | return finder.is_found_; |
128 | } |
129 | |
130 | private: |
131 | ExprFinder(const std::unordered_set<ExprType>& expr_types) |
132 | : expr_types_(expr_types) {} |
133 | |
134 | using kir::ConstIrVisitor::handle; |
135 | |
136 | void handle(const Expr* expr) final { |
137 | if (expr_types_.find(expr->etype()) != expr_types_.end()) { |
138 | is_found_ = true; |
139 | return; |
140 | } |
141 | kir::ConstIrVisitor::handle(expr); |
142 | } |
143 | |
144 | private: |
145 | const std::unordered_set<ExprType>& expr_types_; |
146 | bool is_found_ = false; |
147 | }; |
148 | |
149 | class CudaKernelGenerator : private OptOutConstDispatch { |
150 | static constexpr const char* kTab = " " ; |
151 | |
152 | public: |
153 | static std::string generateKernelDefinition( |
154 | const kir::Kernel* kernel, |
155 | const std::string& kernel_name) { |
156 | CudaKernelGenerator codegen(kernel); |
157 | codegen.genDeclaration(kernel_name); |
158 | codegen.startBlock(); |
159 | codegen.genPrologue(); |
160 | codegen.genBody(); |
161 | codegen.endBlock(); |
162 | TORCH_CHECK(codegen.block_nest_level_ == 0); |
163 | return codegen.code_.str(); |
164 | } |
165 | |
166 | private: |
167 | explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) { |
168 | initStringStreamFormat(code_); |
169 | } |
170 | |
// Use the "C" locale and enough precision so that floating-point
// literals are emitted unambiguously and without locale-dependent
// formatting.
void initStringStreamFormat(std::stringstream& ss) {
172 | const int digits = std::numeric_limits<Double::ScalarType>::max_digits10; |
173 | ss.imbue(std::locale("C" )); |
174 | ss << std::scientific << std::setprecision(digits); |
175 | } |
176 | |
177 | // Generates the kernel function declaration |
178 | void genDeclaration(const std::string& kernel_name) { |
179 | const auto& kernel_summary = kernel_->summary(); |
180 | |
181 | code_ << "__global__ void " << kernel_name << "(" ; |
182 | |
183 | std::unordered_set<Val*> unique_args; |
184 | |
185 | std::vector<Val*> params; |
186 | |
187 | // Inputs & Outputs |
188 | for (auto val : kernel_->inputs()) { |
189 | params.push_back(val); |
190 | } |
191 | for (auto val : kernel_->outputs()) { |
192 | TORCH_INTERNAL_ASSERT( |
193 | !val->isScalar(), "No scalar output is allowed: " , val->toString()); |
194 | params.push_back(val); |
195 | } |
196 | |
197 | // Generate parameter declarations |
198 | unsigned int duplicate_counter = 0; |
199 | for (auto i : c10::irange(params.size())) { |
200 | std::stringstream var_name_ss; |
201 | if (params[i]->isA<TensorView>()) { |
202 | var_name_ss << varName(params[i]->as<TensorView>()); |
203 | } else { |
204 | var_name_ss << gen(params[i]); |
205 | } |
206 | |
// If a value appears more than once among the arguments, rename the
// duplicates to avoid name conflicts in the generated parameter list.
209 | if (!unique_args.emplace(params[i]).second) { |
210 | var_name_ss << "_duplicate_" << duplicate_counter++; |
211 | } |
212 | |
213 | if (const auto tv = dynamic_cast<TensorView*>(params[i])) { |
214 | if (tv->isCpuScalar()) { |
215 | code_ << " CpuScalarTensor<" << params[i]->dtype() << "> " |
216 | << var_name_ss.str(); |
217 | } else { |
218 | code_ |
219 | << "Tensor<" << params[i]->dtype() << ", " |
220 | << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size() |
221 | << "> " << var_name_ss.str(); |
222 | } |
223 | } else { |
224 | TORCH_INTERNAL_ASSERT(params[i]->isScalar()); // NOLINT (LLVM bug 48525) |
225 | TORCH_INTERNAL_ASSERT(params[i]->definition() == nullptr); |
226 | code_ << params[i]->dtype() << " " << var_name_ss.str(); |
227 | } |
228 | |
229 | if (i + 1 != params.size()) { |
230 | code_ << ", " ; |
231 | } |
232 | } |
233 | |
234 | // Global buffers |
235 | for (auto allocate : kernel_summary.global_allocations) { |
236 | TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<TensorView>()); |
237 | const auto tv = allocate->buffer()->as<TensorView>(); |
238 | const auto& maybe_rfactor_domain = tv->domain()->hasRFactor() |
239 | ? tv->domain()->getRFactorDomain() |
240 | : tv->domain()->getRootDomain(); |
241 | const auto nDims = std::count_if( |
242 | maybe_rfactor_domain.begin(), |
243 | maybe_rfactor_domain.end(), |
244 | [](const IterDomain* id) { return !id->isReduction(); }); |
245 | code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> " |
246 | << varName(tv); |
247 | } |
248 | |
249 | // Kernels generating random numbers take extra (seed, offset) arguments |
250 | if (kernel_summary.max_rng_offsets >= 0) { |
251 | code_ << ", at::PhiloxCudaState philox_args" ; |
252 | } |
253 | |
254 | code_ << ") " ; |
255 | } |
256 | |
257 | // Generates setup code which is executed before the kernel body |
258 | void genPrologue() { |
259 | const auto& kernel_summary = kernel_->summary(); |
260 | |
261 | // Random number generator (optional) |
262 | if (kernel_summary.max_rng_offsets >= 0) { |
263 | indent() << "auto philox_offset = philox_args.captured_ ?\n" ; |
264 | indent() |
265 | << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n" ; |
266 | indent() << " philox_args.offset_.val;\n" ; |
267 | indent() << "uint4 rng_result;\n" ; |
268 | indent() << "nvfuser_index_t rng_subseq = -1;\n" ; |
269 | indent() << "nvfuser_index_t rng_offset = -1;\n" ; |
270 | } |
271 | |
272 | // Do we have any dynamic shared memory buffers? |
273 | const bool has_dynamic_smem = |
274 | !kernel_summary.dynamic_smem_allocations.empty(); |
275 | |
276 | // Do we have any reductions? |
277 | const bool has_reductions = kernel_summary.has_block_reductions || |
278 | kernel_summary.has_grid_reductions; |
279 | const bool has_parallel_welford = |
280 | kernel_summary.has_block_welford || kernel_summary.has_grid_welford; |
281 | |
282 | // Shared memory |
283 | if (has_dynamic_smem || has_reductions || has_parallel_welford) { |
284 | indent() << "alignas(" |
285 | #ifndef USE_ROCM |
286 | << 16 // always align to 16B for any shared mem allocation |
287 | #else |
288 | << 8 // for HIP, we want 8-aligned even for smaller datatypes |
289 | #endif |
290 | << ") extern __shared__ char array[];\n" ; |
291 | |
292 | if (has_dynamic_smem) { |
293 | indent() << "unsigned smem_offset = 0;\n" ; |
294 | } |
295 | |
296 | if (has_reductions || has_parallel_welford) { |
297 | indent() << "void* shared_mem = array;\n" ; |
298 | if (has_dynamic_smem) { |
299 | if (has_parallel_welford) { |
300 | indent() << "smem_offset += " |
301 | << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof(" |
302 | << kernel_summary.largest_smem_data_type << "));\n" ; |
303 | } else { |
304 | indent() << "smem_offset += " |
305 | << "((blockDim.x * blockDim.y * blockDim.z) * sizeof(" |
306 | << kernel_summary.largest_smem_data_type << "));\n" ; |
307 | } |
308 | } |
309 | |
310 | if (has_parallel_welford) { |
311 | // Unpack shared mem pointer |
312 | auto space_type = kernel_summary.largest_smem_data_type; |
313 | indent() |
314 | << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n" ; |
315 | indent() << space_type << " *shared_mem_var = " |
316 | << "static_cast<" << space_type << "*>(" |
317 | << "shared_mem);\n" ; |
318 | indent() << space_type |
319 | << " *shared_mem_avg = shared_mem_var + block_size;\n" ; |
320 | indent() << space_type |
321 | << " *shared_mem_n = shared_mem_avg + block_size;\n" ; |
322 | } |
323 | } |
324 | } |
325 | |
326 | // Call the initialization function if using a custom block sync |
327 | if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC" )) { |
328 | indent() << "block_sync::init();\n" ; |
329 | } |
330 | } |
331 | |
332 | void genBody() { |
333 | for (auto expr : kernel_->topLevelExprs()) { |
334 | OptOutConstDispatch::handle(expr); |
335 | } |
336 | } |
337 | |
338 | void startBlock(bool continuation = false) { |
339 | if (continuation) { |
340 | code_ << "{\n" ; |
341 | } else { |
342 | indent() << "{\n" ; |
343 | } |
344 | ++block_nest_level_; |
345 | } |
346 | |
347 | void endBlock(const char* sep = "\n" ) { |
348 | --block_nest_level_; |
349 | TORCH_CHECK(block_nest_level_ >= 0); |
350 | indent() << "}" << sep; |
351 | } |
352 | |
353 | std::ostream& indent() { |
354 | for (const auto i : c10::irange(block_nest_level_)) { |
355 | (void)i; // Suppress unused variable warning |
356 | code_ << kTab; |
357 | } |
358 | return code_; |
359 | } |
360 | |
//! Generate the code for a statement into a temporary stream and
//! return it as a string; code_ is swapped out during dispatch and
//! restored afterwards.
std::string gen(const Statement* stmt) {
362 | std::stringstream tmp_code; |
363 | initStringStreamFormat(tmp_code); |
364 | std::swap(tmp_code, code_); |
365 | OptOutConstDispatch::handle(stmt); |
366 | std::swap(tmp_code, code_); |
367 | return tmp_code.str(); |
368 | } |
369 | |
//! Returns the generated variable name for a Val: "T" plus the value's
//! name for a TensorView, "ip" plus the name for an IntPair, and a type
//! prefix plus the name for other scalars.
std::string varName(const Val* val) {
371 | std::stringstream name; |
372 | if (val->isA<TensorView>()) { |
373 | name << "T" ; |
374 | } else if (val->isA<kir::IntPair>()) { |
375 | name << "ip" ; |
376 | } else { |
377 | name << typePrefix(val->dtype()); |
378 | } |
379 | name << val->name(); |
380 | return name.str(); |
381 | } |
382 | |
//! Generate stmt with print_inline_ enabled so that scalar definitions
//! are expanded in place rather than referenced by name.
std::string genInline(const Statement* stmt) {
384 | const bool saved_inline = print_inline_; |
385 | print_inline_ = true; |
386 | auto result = gen(stmt); |
387 | print_inline_ = saved_inline; |
388 | // NOLINTNEXTLINE(performance-no-automatic-move) |
389 | return result; |
390 | } |
391 | |
392 | void handle(const kir::Predicate* pred) final { |
393 | TORCH_INTERNAL_ASSERT(pred->hasValue()); |
394 | code_ << gen(pred->value()); |
395 | } |
396 | |
397 | void handle(const Bool* pred) final { |
398 | const auto def = pred->definition(); |
399 | const bool has_alloc = alloc_map_.find(pred) != alloc_map_.end(); |
400 | if (def != nullptr && !has_alloc) { |
401 | code_ << "(" << gen(def) << ")" ; |
402 | } else if (pred->isConst()) { |
403 | code_ << (*pred->value() ? "true" : "false" ); |
404 | } else { |
405 | code_ << varName(pred); |
406 | } |
407 | } |
408 | |
409 | void handle(const Double* d) final { |
410 | const auto def = d->definition(); |
411 | const bool has_alloc = alloc_map_.find(d) != alloc_map_.end(); |
412 | if (def != nullptr && !has_alloc) { |
413 | code_ << "(" << gen(def) << ")" ; |
414 | } else if (d->isConst()) { |
415 | auto val = *d->value(); |
// Note: the default inf/nan representations are not valid C++ literals,
// so emit the macros `NAN`, `POS_INFINITY` and `NEG_INFINITY` instead.
418 | if (std::isinf(val)) { |
419 | if (val > 0) { |
420 | code_ << "POS_INFINITY" ; |
421 | } else { |
422 | code_ << "NEG_INFINITY" ; |
423 | } |
424 | } else if (std::isnan(val)) { |
425 | code_ << "NAN" ; |
426 | } else { |
427 | code_ << val; |
428 | } |
429 | } else { |
430 | code_ << varName(d); |
431 | } |
432 | } |
433 | |
434 | void handle(const Int* i) final { |
435 | // Check the replacement map first. If there's an entry for i, use |
436 | // the corresponding replacement. |
437 | auto replace_it = index_replacement_map_.find(i); |
438 | if (replace_it != index_replacement_map_.end()) { |
439 | code_ << replace_it->second; |
440 | return; |
441 | } |
442 | |
443 | const auto def = i->definition(); |
444 | const bool has_alloc = alloc_map_.find(i) != alloc_map_.end(); |
445 | if (def != nullptr && !has_alloc) { |
446 | code_ << "(" << genInline(def) << ")" ; |
447 | } else if (i->isConst()) { |
448 | code_ << *i->value(); |
449 | } else { |
450 | code_ << varName(i); |
451 | } |
452 | } |
453 | |
454 | void handle(const ComplexDouble* c) final { |
455 | const auto def = c->definition(); |
456 | const bool has_alloc = alloc_map_.find(c) != alloc_map_.end(); |
457 | if (def != nullptr && !has_alloc) { |
458 | code_ << "(" << gen(def) << ")" ; |
459 | } else if (c->isConst()) { |
460 | code_ << "std::complex<double>" << *c->value(); |
461 | } else { |
462 | code_ << varName(c); |
463 | } |
464 | } |
465 | |
466 | void handle(const NamedScalar* ns) final { |
467 | // dim3 components are unsigned int. Cast to signed integer to |
468 | // support negative indexing |
469 | if (ns->getParallelIndex().has_value() || |
470 | ns->getParallelDim().has_value()) { |
471 | code_ << "((nvfuser_index_t)" << ns->name() << ")" ; |
472 | } else { |
473 | code_ << ns->name(); |
474 | } |
475 | } |
476 | |
//! Returns the sum of all non-zero indices in a TensorIndex,
//! or "0" if there are none.
//! Used for lowering generic tensor indices as well as mma
//! fragment indices.
481 | std::string genTensorIndex(const kir::TensorIndex* ti) { |
482 | bool first = true; |
483 | std::stringstream index; |
484 | for (auto* ind : ti->indices()) { |
485 | if (!ind->isZeroInt()) { |
486 | if (!first) { |
487 | index << " + " ; |
488 | } |
489 | index << genInline(ind); |
490 | first = false; |
491 | } |
492 | } |
493 | |
494 | if (first) { |
495 | index << "0" ; |
496 | } |
497 | |
498 | return index.str(); |
499 | } |
500 | |
501 | void handle(const kir::TensorIndex* ti) final { |
502 | bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global && |
503 | kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID(); |
504 | if (is_volatile) { |
505 | code_ << "*(volatile " << ti->getDataType().value() << "*)&" ; |
506 | } |
507 | code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]" ; |
508 | } |
509 | |
510 | void handle(const ViewAsScalar* sv) final { |
511 | indent() << gen(sv->output(0)) << " = " << gen(sv->input(0)) << "[" |
512 | << gen(sv->index()) << "];\n" ; |
513 | } |
514 | |
515 | void handle(const IterDomain*) final { |
516 | TORCH_INTERNAL_ASSERT(false, "Unreachable" ); |
517 | } |
518 | |
519 | void handle(const TensorDomain*) final { |
520 | TORCH_INTERNAL_ASSERT(false, "Unreachable" ); |
521 | } |
522 | |
523 | void handle(const TensorView*) final { |
524 | TORCH_INTERNAL_ASSERT(false, "Unreachable" ); |
525 | } |
526 | |
//! Utility for generating vectorized pointer access in ldsm and
//! cpasync.
//! TODO: this access pattern could be merged with the existing
//! vectorization handling logic, but this path will be updated in
//! follow-ups to optimize the generated assembly, so it is kept
//! separate for now.
533 | std::string genVectorPointer(Val* val, DataType dtype, int vec_size) { |
534 | std::stringstream ss; |
535 | |
536 | ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << "," |
537 | << vec_size << ">*>(&" << gen(val) << ")" ; |
538 | |
539 | return ss.str(); |
540 | } |
541 | |
542 | // Utility function to emit a cp.async intrinsic |
543 | void genCpAsync(const LoadStoreOp* ldst, int vec_size) { |
544 | auto dtype = ldst->in()->getDataType().value(); |
545 | |
546 | if (ldst->predicate() == nullptr) { |
547 | // Out of line predicate variant |
548 | indent() << "Ampere::cpAsync(" |
549 | << genVectorPointer(ldst->out(), dtype, vec_size) << "," |
550 | << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n" ; |
551 | } else { |
552 | // Inline predicate variant |
553 | indent() << "Ampere::cpAsync(" |
554 | << genVectorPointer(ldst->out(), dtype, vec_size) << "," |
555 | << genVectorPointer(ldst->in(), dtype, vec_size) << "," |
556 | << genInline(ldst->predicate()) << ");\n" ; |
557 | } |
558 | } |
559 | |
560 | void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) { |
561 | auto dtype = ldst->in()->getDataType().value(); |
562 | indent() << "Turing::ldMatrix" ; |
563 | if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) { |
564 | code_ << "T" ; |
565 | } |
566 | code_ << " (" ; |
567 | code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size) |
568 | << "," |
569 | << "&" << gen(ldst->in()) << ");\n" ; |
570 | } |
571 | |
572 | void handle(const FullOp* fop) final { |
573 | indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")" |
574 | << gen(fop->getFillValue()) << ";\n" ; |
575 | } |
576 | |
577 | void handle(const ARangeOp* aop) final { |
578 | auto index = |
579 | genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>()); |
580 | indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">" ; |
581 | code_ << "(" << index << ", " << gen(aop->start()) << ", " |
582 | << gen(aop->step()) << ");\n" ; |
583 | } |
584 | |
585 | void handle(const EyeOp* aop) final { |
586 | auto index1 = gen(aop->getIndex1()); |
587 | auto index2 = gen(aop->getIndex2()); |
588 | indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")" ; |
589 | code_ << "(" << index1 << " == " << index2 << ");\n" ; |
590 | } |
591 | |
592 | void handle(const UnaryOp* uop) final { |
593 | bool is_vector_op = false; |
594 | size_t vector_word_size = 1; |
595 | |
596 | if (uop->out()->isA<kir::TensorIndex>()) { |
597 | auto out_tv = uop->out()->as<kir::TensorIndex>()->view(); |
598 | if (std::any_of( |
599 | out_tv->domain()->domain().begin(), |
600 | out_tv->domain()->domain().end(), |
601 | [&](IterDomain* id) { return id->isMma(); })) { |
602 | auto mma = dynamic_cast<MmaOp*>( |
603 | uop->out()->as<kir::TensorIndex>()->view()->definition()); |
604 | TORCH_INTERNAL_ASSERT( |
605 | mma != nullptr, "CodeGen: mma op not in mma loop" ); |
606 | genMmaInitialization(mma, uop); |
607 | return; |
608 | } |
609 | } |
610 | |
611 | if (vectorize_scope_ && uop->out()->isA<kir::TensorIndex>()) { |
612 | auto ti = uop->out()->as<kir::TensorIndex>(); |
613 | |
614 | bool vectorize_op = false; |
615 | bool misaligned_op = false; |
616 | |
617 | for (auto id : ti->view()->domain()->domain()) { |
618 | if (!isParallelTypeVectorize(id->getParallelType())) { |
619 | continue; |
620 | } |
621 | |
622 | ExpressionEvaluator expr_eval(id->fusion()); |
623 | auto vector_size_optional = expr_eval.evaluate(id->extent()); |
624 | |
625 | TORCH_INTERNAL_ASSERT( |
626 | vector_size_optional.has_value(), |
627 | "Could not evaluate constant value bound to vectorized dim." ); |
628 | |
629 | vector_word_size = vector_size_optional->as<int64_t>(); |
630 | |
631 | vectorize_op = id->getParallelType() == ParallelType::Vectorize; |
632 | misaligned_op = |
633 | id->getParallelType() == ParallelType::MisalignedVectorize; |
634 | break; |
635 | } |
636 | |
637 | if (vectorize_op) { |
638 | TORCH_INTERNAL_ASSERT( |
639 | uop->getUnaryOpType() == UnaryOpType::Set, |
640 | "Cannot vectorize operations that are not sets. " , |
641 | "Use cacheBefore and cacheAfter to store/load with vectorized reads into buffers." ); |
642 | is_vector_op = true; |
643 | } |
644 | |
645 | if (misaligned_op) { |
646 | is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set); |
647 | } |
648 | |
649 | if (is_vector_op && !uop->in()->isScalar()) { |
650 | TORCH_INTERNAL_ASSERT( |
651 | uop->out()->dtype() == uop->in()->dtype(), |
652 | "Vectorized store/load requires input and output datatypes match." ); |
653 | } |
654 | |
655 | if (is_vector_op) { |
656 | auto out_tv = uop->out()->as<kir::TensorIndex>()->view(); |
657 | if (uop->in()->isScalar()) { |
// Note:
// Double buffered local tensors need indexed initialization,
// so they need to use the `arraySet` option.
661 | if (out_tv->getMemoryType() == MemoryType::Local && |
662 | !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) { |
663 | // Vectorized initialization |
664 | indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n" ; |
665 | } else { |
// Note: the arraySet option is currently not vectorized, so it
// relies on the auto-vectorization pass of the CUDA compiler.
668 | indent() << "arraySet<" << out_tv->getDataType().value() << ", " |
669 | << vector_word_size << ">(&" << gen(uop->out()) << ", " |
670 | << "(" << out_tv->getDataType().value() << ")" |
671 | << gen(uop->in()) << ");\n" ; |
672 | } |
673 | } else { |
674 | // Vectorized load |
675 | TORCH_INTERNAL_ASSERT( |
676 | uop->in()->isA<kir::TensorIndex>(), |
677 | "Invalid input to unary op with tensor output, found: " , |
678 | uop->in()->toString()); |
679 | |
680 | auto in_tv = uop->in()->as<kir::TensorIndex>()->view(); |
681 | bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global && |
682 | in_tv->getMemoryType() == MemoryType::Local; |
683 | |
684 | bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local && |
685 | in_tv->getMemoryType() == MemoryType::Global; |
686 | |
687 | bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global && |
688 | in_tv->getMemoryType() == MemoryType::Global; |
689 | |
690 | bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global && |
691 | kernel_->summary().sync_map.needsRawSync(out_tv).hasBID(); |
692 | |
693 | bool is_volatile_from = |
694 | in_tv->getMemoryType() == MemoryType::Global && |
695 | kernel_->summary().sync_map.needsRawSync(in_tv).hasBID(); |
696 | |
697 | if (localToGlobal) { |
698 | indent() << "loadLocalToGlobal<" << uop->out()->dtype() << ", " |
699 | << vector_word_size << ", " |
700 | << (is_volatile_to ? "true" : "false" ) << ">(" ; |
701 | code_ << " &" << gen(uop->out()) << ", &" << gen(uop->in()) |
702 | << ");\n" ; |
703 | } else if (globalToLocal) { |
704 | indent() << "loadGlobalToLocal<" << uop->out()->dtype() << ", " |
705 | << vector_word_size << ", " |
706 | << (is_volatile_from ? "true" : "false" ) << ">(&" |
707 | << gen(uop->out()) << ", " ; |
708 | code_ << " &" << gen(uop->in()) << ");\n" ; |
709 | } else if (globalToGlobal) { |
710 | indent() << "loadGlobalToGlobal<" << uop->out()->dtype() << ", " |
711 | << vector_word_size << ", " |
712 | << (is_volatile_to ? "true" : "false" ) << ", " |
713 | << (is_volatile_from ? "true" : "false" ) << ">(" ; |
714 | code_ << " &" << gen(uop->out()) << ", " ; |
715 | code_ << " &" << gen(uop->in()) << ");\n" ; |
716 | } else { |
717 | indent() << "loadGeneric<" << uop->out()->dtype() << ", " |
718 | << vector_word_size << ">(" ; |
719 | code_ << " &" << gen(uop->out()) << ", " ; |
720 | code_ << " &" << gen(uop->in()) << ");\n" ; |
721 | } |
722 | } |
723 | return; |
724 | } |
725 | } |
726 | |
727 | const auto op_type = uop->getUnaryOpType(); |
728 | |
729 | if (uop->out()->isA<NamedScalar>()) { |
730 | if (auto op = inline_op_str(op_type)) { |
731 | indent() << gen(uop->out()) << " = " << *op << genInline(uop->in()) |
732 | << ";\n" ; |
733 | } |
734 | return; |
735 | } |
736 | |
737 | if (!print_inline_) { |
738 | indent() << gen(uop->out()); |
739 | if (!uop->out()->isScalar() && !uop->in()->isScalar()) { |
740 | code_ << "\n" ; |
741 | indent() << kTab; |
742 | } |
743 | code_ << " = " ; |
744 | } |
745 | |
746 | if (auto op = inline_op_str(op_type)) { |
747 | if (alsoBooleanOperator(op_type) && |
748 | uop->out()->dtype() == DataType::Bool) { |
749 | code_ << stringifyBooleanOp(op_type) << gen(uop->in()); |
750 | } else { |
751 | code_ << *op << gen(uop->in()); |
752 | } |
753 | } else { |
754 | if (op_type == UnaryOpType::Cast) { |
755 | const auto cast_str = |
756 | cast_func_str({uop->in()->dtype(), uop->out()->dtype()}); |
757 | TORCH_INTERNAL_ASSERT( |
758 | cast_str.has_value(), |
759 | "Invalid cast. Input type: " , |
760 | uop->in()->dtype(), |
761 | ", output type: " , |
762 | uop->out()->dtype()); |
763 | code_ << cast_str.value(); |
764 | } else { |
765 | code_ << op_type; |
766 | if (needFloatSuffix(op_type) && |
767 | uop->out()->dtype() == DataType::Float) { |
768 | code_ << "f" ; |
769 | } |
770 | } |
771 | |
772 | code_ << "(" << gen(uop->in()) << ")" ; |
773 | } |
774 | |
775 | if (!print_inline_) { |
776 | code_ << ";\n" ; |
777 | } |
778 | } |
779 | |
780 | void handle(const RNGOp* rop) final { |
781 | // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an |
782 | // innermost ID of size 4 (float) or size 2 (double)? |
783 | auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>()); |
784 | int multiple = rop->dtype() == DataType::Double ? 2 : 4; |
785 | indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index |
786 | << ";\n" ; |
787 | indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index" |
788 | << rop->name() << " / " << multiple << ";\n" ; |
789 | indent() << "nvfuser_index_t rng_component" << rop->name() |
790 | << " = linear_index" << rop->name() << " % " << multiple << ";\n" ; |
791 | indent() << "nvfuser_index_t rng_offset" << rop->name() << " = " |
792 | << rop->getRNGOffset() << ";\n" ; |
793 | indent() << "if (rng_subseq != rng_subseq" << rop->name() |
794 | << " || rng_offset != rng_offset" << rop->name() << ") {\n" ; |
795 | indent() << " auto seed = philox_args.captured_ ?\n" |
796 | << " static_cast<uint64_t>(*(philox_args.seed_.ptr)) : \n" |
797 | << " philox_args.seed_.val;\n" ; |
798 | indent() << " rng_result = philox(seed, rng_subseq" << rop->name() |
799 | << ", philox_offset / 4 + rng_offset" << rop->name() << ");\n" ; |
800 | indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n" ; |
801 | indent() << " rng_offset = rng_offset" << rop->name() << ";\n" ; |
802 | indent() << "}\n" ; |
803 | auto op_type = rop->getRNGOpType(); |
804 | indent() << gen(rop->output(0)) << " = " << op_type; |
805 | if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) { |
806 | code_ << "f" ; |
807 | } |
808 | code_ << "(rng_result, rng_component" << rop->name(); |
809 | switch (op_type) { |
810 | case RNGOpType::UniformRange: { |
811 | auto parameters = rop->getParameters(); |
812 | TORCH_INTERNAL_ASSERT(parameters.size() == 2); |
813 | code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]); |
814 | break; |
815 | } |
816 | default:; |
817 | } |
818 | code_ << ");\n" ; |
819 | } |
820 | |
821 | std::string genBinaryOp( |
822 | BinaryOpType op_type, |
823 | DataType data_type, |
824 | const std::string& lhs, |
825 | const std::string& rhs) { |
826 | std::stringstream expr; |
827 | if (auto op = inline_op_str(op_type)) { |
828 | expr << lhs << " " ; |
829 | if (alsoBooleanOperator(op_type) && data_type == DataType::Bool) { |
830 | expr << stringifyBooleanOp(op_type); |
831 | } else { |
832 | expr << *op; |
833 | } |
834 | expr << " " << rhs; |
835 | } else { |
836 | if (integer_op_str(op_type) && isIntegralType(data_type)) { |
837 | auto int_op = integer_op_str(op_type); |
838 | expr << *int_op; |
839 | } else if (bool_op_str(op_type) && isBooleanType(data_type)) { |
840 | auto bool_op = bool_op_str(op_type); |
841 | expr << *bool_op; |
842 | } else { |
843 | expr << op_type; |
844 | if (needFloatSuffix(op_type) && data_type == DataType::Float) { |
845 | expr << "f" ; |
846 | } |
847 | } |
848 | expr << "(" << lhs << ", " << rhs << ")" ; |
849 | } |
850 | return expr.str(); |
851 | } |
852 | |
853 | // If one argument is a tensorview and the other is a scalar, make sure we |
854 | // cast the scalar to the tensorview type |
855 | std::string scalarCast(Val* lhs, Val* rhs) { |
// Return early unless we have a mix of a scalar and a tensor operand
857 | if (!((lhs->isScalar() || rhs->isScalar()) && |
858 | (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) { |
859 | return "" ; |
860 | } |
861 | |
// Looking for mixed tensorview-scalar operands whose types don't match
// but are either both floating point or both integral. In these cases
// the scalar should be cast to the tensorview type.
865 | auto lhs_t = lhs->dtype(); |
866 | auto rhs_t = rhs->dtype(); |
867 | |
868 | // If same type, don't cast anything |
869 | if (lhs_t == rhs_t) { |
870 | return "" ; |
871 | } |
872 | |
873 | // Don't do anything when dealing with bools |
874 | if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) { |
875 | return "" ; |
876 | } |
877 | |
878 | // Mixing floating and int combination |
879 | if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) || |
880 | (isIntegralType(lhs_t) != isIntegralType(rhs_t))) { |
881 | return "" ; |
882 | } |
883 | |
884 | std::stringstream cast; |
885 | cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") " ; |
886 | return cast.str(); |
887 | } |
888 | |
889 | // If possible, replace pow with mul. Return true when successful. |
890 | bool genPowerWithMul(const BinaryOp* bop) { |
891 | if (bop->getBinaryOpType() != BinaryOpType::Pow) { |
892 | return false; |
893 | } |
894 | |
895 | auto rhs = bop->rhs(); |
896 | c10::optional<double> exponent; |
897 | if (auto val_int = dynamic_cast<Int*>(rhs)) { |
898 | if (val_int->isConst()) { |
899 | exponent = val_int->value().value(); |
900 | } |
901 | } else if (auto val_float = dynamic_cast<Double*>(rhs)) { |
902 | if (val_float->isConst()) { |
903 | auto fp_exp = val_float->value().value(); |
904 | double int_exp = 0; |
905 | if (std::modf(fp_exp, &int_exp) == 0) { |
906 | exponent = int_exp; |
907 | } |
908 | } |
909 | } |
910 | |
911 | if (!exponent.has_value()) { |
912 | return false; |
913 | } |
914 | |
915 | // Only **2 and **3 are considered |
916 | if (!(exponent.value() == 2 || exponent.value() == 3)) { |
917 | return false; |
918 | } |
919 | |
920 | auto lhs = gen(bop->lhs()); |
921 | |
922 | if (print_inline_) { |
923 | code_ << lhs << " * " << lhs; |
924 | if (exponent.value() == 3) { |
925 | code_ << " * " << lhs; |
926 | } |
927 | } else { |
928 | indent() << gen(bop->out()); |
929 | if (bop->out()->isScalar()) { |
930 | code_ << " = " << lhs << " * " << lhs; |
931 | if (exponent.value() == 3) { |
932 | code_ << " * " << lhs; |
933 | } |
934 | } else { |
935 | code_ << "\n" ; |
936 | indent() << kTab << "= " << lhs << "\n" ; |
937 | indent() << kTab << "* " << lhs; |
938 | if (exponent.value() == 3) { |
939 | code_ << "\n" ; |
940 | indent() << kTab << "* " << lhs; |
941 | } |
942 | } |
943 | } |
944 | |
945 | code_ << ";\n" ; |
946 | return true; |
947 | } |
948 | |
949 | void handle(const BinaryOp* bop) final { |
950 | // Try replacing pow with mul |
951 | if (genPowerWithMul(bop)) { |
952 | return; |
953 | } |
954 | |
955 | const auto op_type = bop->getBinaryOpType(); |
956 | if (print_inline_) { |
957 | // Inline expression: `lhs op rhs` |
958 | code_ << genBinaryOp( |
959 | op_type, bop->out()->dtype(), gen(bop->lhs()), gen(bop->rhs())); |
960 | } else { |
961 | indent() << gen(bop->out()); |
962 | if (bop->out()->isScalar()) { |
963 | // Single line: `out = lhs op rhs;` |
964 | code_ << " = " |
965 | << genBinaryOp( |
966 | op_type, |
967 | bop->out()->dtype(), |
968 | gen(bop->lhs()), |
969 | gen(bop->rhs())); |
970 | } else { |
971 | // Split TensorView expressions across multiple lines: |
972 | // |
973 | // out |
974 | // = lhs |
975 | // op rhs; |
976 | // |
977 | |
978 | auto cast = scalarCast(bop->lhs(), bop->rhs()); |
979 | if (auto op = inline_op_str(op_type)) { |
980 | code_ << "\n" ; |
981 | indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "" ) |
982 | << gen(bop->lhs()) << "\n" ; |
983 | indent() << kTab; |
984 | if (alsoBooleanOperator(op_type) && |
985 | bop->out()->dtype() == DataType::Bool) { |
986 | code_ << stringifyBooleanOp(op_type); |
987 | } else { |
988 | code_ << *op; |
989 | } |
990 | code_ << " " << (bop->rhs()->isScalar() ? cast : "" ) |
991 | << gen(bop->rhs()); |
992 | } else { |
993 | if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) { |
994 | auto int_op = integer_op_str(op_type); |
995 | code_ << " = " << *int_op << "(\n" ; |
996 | } else if ( |
997 | bool_op_str(op_type) && isBooleanType(bop->out()->dtype())) { |
998 | auto bool_op = bool_op_str(op_type); |
999 | code_ << " = " << *bool_op << "(\n" ; |
1000 | } else { |
1001 | std::stringstream op_str; |
1002 | op_str << op_type; |
1003 | if (needFloatSuffix(op_type) && |
1004 | bop->out()->dtype() == DataType::Float) { |
1005 | op_str << "f" ; |
1006 | } |
1007 | code_ << " = " << op_str.str() << "(\n" ; |
1008 | } |
1009 | indent() << kTab << (bop->lhs()->isScalar() ? cast : "" ) |
1010 | << gen(bop->lhs()) << ",\n" ; |
1011 | indent() << kTab << (bop->rhs()->isScalar() ? cast : "" ) |
1012 | << gen(bop->rhs()) << ")" ; |
1013 | } |
1014 | } |
1015 | code_ << ";\n" ; |
1016 | } |
1017 | } |
1018 | |
1019 | void handle(const TernaryOp* top) final { |
1020 | if (!print_inline_) { |
1021 | indent() << gen(top->out()); |
1022 | if (!top->out()->isScalar()) { |
1023 | code_ << "\n" ; |
1024 | indent() << kTab; |
1025 | } |
1026 | code_ << " = " ; |
1027 | } |
1028 | |
1029 | code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", " ; |
1030 | |
1031 | // Make sure the two operands of where has the same |
1032 | // type. Note that compiling "where(0.0f, 0.0)" fails because of |
1033 | // the overloading ambiguity. |
1034 | if (top->getTernaryOpType() == TernaryOpType::Where) { |
1035 | auto cast = scalarCast(top->in2(), top->in3()); |
1036 | code_ << (top->in2()->isScalar() ? cast : "" ) << gen(top->in2()) << ", " |
1037 | << (top->in3()->isScalar() ? cast : "" ) << gen(top->in3()) << ")" ; |
1038 | } else { |
1039 | code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")" ; |
1040 | } |
1041 | |
1042 | if (!print_inline_) { |
1043 | code_ << ";\n" ; |
1044 | } |
1045 | } |
1046 | |
1047 | std::string genArchString(MmaOptions::MacroType macro) { |
1048 | std::stringstream ss; |
1049 | if (isVolta(macro)) { |
1050 | ss << "Volta" ; |
1051 | } else if (isTuring(macro)) { |
1052 | ss << "Turing" ; |
1053 | } else if (isAmpere(macro)) { |
1054 | ss << "Ampere" ; |
1055 | } else { |
1056 | TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch" ); |
1057 | } |
1058 | return ss.str(); |
1059 | } |
1060 | |
1061 | std::string genMmaOp(const MmaOp* mma, bool init = false) { |
1062 | std::stringstream ss; |
1063 | auto options = mma->options(); |
1064 | ss << genArchString(options.macro) << "::" ; |
1065 | if (init) { |
1066 | ss << "init" ; |
1067 | } |
1068 | ss << toString(options.macro); |
1069 | |
1070 | if (isVolta(options.macro)) { |
1071 | ss << toString(options.operand_layout); |
1072 | } else if (isTuring(options.macro) || isAmpere(options.macro)) { |
// MMAs on Turing and Ampere are TN only; the transpose is handled
// either via ldmatrix for fp16 or explicitly for other types.
1075 | ss << "TN" ; |
1076 | } |
1077 | // TODO: additional parameter could be removed by swizzling iterdomain |
1078 | auto acc_stride = mma->accStride(); |
1079 | TORCH_INTERNAL_ASSERT(acc_stride > 0); |
1080 | ss << "<" << acc_stride << ">" ; |
1081 | return ss.str(); |
1082 | } |
1083 | |
1084 | void genMmaOperands(const MmaOp* mma) { |
1085 | std::stringstream ss; |
1086 | auto options = mma->options(); |
1087 | auto in_a = mma->inA()->as<kir::TensorIndex>()->view(); |
1088 | auto dtype = in_a->getDataType().value(); |
1089 | indent() << kTab << "&(reinterpret_cast<Array<" << dtype << "," |
1090 | << getInputARegisterSize(options.macro) << "," |
1091 | << getInputARegisterSize(options.macro) << ">*>(&" |
1092 | << varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")[" |
1093 | << genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])" |
1094 | << ",\n" ; |
1095 | indent() << kTab << "&(reinterpret_cast<Array<" << dtype << "," |
1096 | << getInputBRegisterSize(options.macro) << "," |
1097 | << getInputBRegisterSize(options.macro) << ">*>(&" |
1098 | << varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")[" |
1099 | << genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])" ; |
1100 | } |
1101 | |
1102 | void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) { |
1103 | auto options = mma->options(); |
1104 | |
1105 | indent() << genMmaOp(mma, true) << "(reinterpret_cast<Array<" |
1106 | << mma->out()->getDataType().value() << "," |
1107 | << getOutputRegisterSize(options.macro) << "," |
1108 | << getOutputRegisterSize(options.macro) << ">*>" |
1109 | << "(&" << gen(uop->out()) << "));\n" ; |
1110 | } |
1111 | |
1112 | void handle(const MmaOp* mma) final { |
1113 | auto options = mma->options(); |
1114 | auto out = mma->out()->as<kir::TensorIndex>(); |
1115 | indent() << genMmaOp(mma) << "(\n" ; |
1116 | indent() << kTab << "reinterpret_cast<Array<" |
1117 | << out->view()->getDataType().value() << "," |
1118 | << getOutputRegisterSize(options.macro) << "," |
1119 | << getOutputRegisterSize(options.macro) << ">*>(&" |
1120 | << gen(mma->out()) << "),\n" ; |
1121 | genMmaOperands(mma); |
1122 | code_ << ");\n" ; |
1123 | } |
1124 | |
1125 | std::string genReductionOp(BinaryOpType op_type, DataType data_type) { |
1126 | std::stringstream lambda; |
1127 | lambda << "[](" << data_type << " &a, " << data_type << " b) " |
1128 | << "{ a = " << genBinaryOp(op_type, data_type, "a" , "b" ) << "; }" ; |
1129 | return lambda.str(); |
1130 | } |
1131 | |
1132 | void handle(const BroadcastOp* stmt) final { |
1133 | TORCH_INTERNAL_ASSERT(stmt->out()->isA<kir::TensorIndex>()); |
1134 | |
1135 | const ParallelTypeBitmap parallel_types = |
1136 | kernel_->summary().broadcast_parallel_types.at(stmt); |
1137 | |
1138 | if (parallel_types.none()) { |
1139 | // Not parallelized |
1140 | indent() << gen(stmt->out()) << "\n" ; |
1141 | indent() << kTab << " = " << gen(stmt->in()) << ";\n" ; |
1142 | return; |
1143 | } |
1144 | |
1145 | TORCH_INTERNAL_ASSERT( |
1146 | !parallel_types.hasBID(), |
1147 | "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node" ); |
1148 | |
1149 | std::stringstream flags_str; |
1150 | for (const ParallelType pt : kParallelTypeTIDs) { |
1151 | const bool parallel_bcast = parallel_types.get(pt); |
1152 | if (pt != kParallelTypeTIDs[0]) { |
1153 | flags_str << ", " ; |
1154 | } |
1155 | flags_str << (parallel_bcast ? "true" : "false" ); |
1156 | } |
1157 | |
1158 | const auto data_type = stmt->out()->dtype(); |
1159 | indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n" ; |
1160 | indent() << kTab << gen(stmt->out()) << ",\n" ; |
1161 | indent() << kTab << gen(stmt->in()) << ",\n" ; |
1162 | indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n" ; |
1163 | TORCH_INTERNAL_ASSERT( |
1164 | stmt->predicate() != nullptr && stmt->predicate()->hasValue()); |
1165 | indent() << kTab << genInline(stmt->predicate()) << ");\n" ; |
1166 | } |
1167 | |
1168 | void genSerialReduction( |
1169 | const kir::TensorIndex* output, |
1170 | const Val* input, |
1171 | BinaryOpType reduction_op_type) { |
1172 | const auto gen_out = gen(output); |
1173 | indent() << gen_out << " = " |
1174 | << genBinaryOp( |
1175 | reduction_op_type, output->dtype(), gen_out, gen(input)) |
1176 | << ";\n" ; |
1177 | return; |
1178 | } |
1179 | |
1180 | void genWarpReduction( |
1181 | const kir::TensorIndex* output, |
1182 | const kir::TensorIndex* input, |
1183 | const Val* init, |
1184 | BinaryOpType reduction_op_type, |
1185 | kir::Predicate* read_pred) { |
1186 | bool is_single_warp = |
1187 | kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp; |
1188 | |
1189 | indent() << "warp::warpReduceTIDX" ; |
1190 | if (is_single_warp) { |
1191 | code_ << "<true>(\n" ; |
1192 | } else { |
1193 | code_ << "<false>(\n" ; |
1194 | } |
1195 | indent() << kTab << gen(output) << ",\n" ; |
1196 | indent() << kTab << gen(input) << ",\n" ; |
1197 | indent() << kTab << genReductionOp(reduction_op_type, output->dtype()) |
1198 | << ",\n" ; |
1199 | indent() << kTab << "threadIdx,\n" ; |
1200 | indent() << kTab << "blockDim,\n" ; |
1201 | indent() << kTab << "static_cast<" << output->dtype() |
1202 | << "*>(shared_mem),\n" ; |
1203 | TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue()); |
1204 | indent() << kTab << genInline(read_pred) << ",\n" ; |
1205 | indent() << kTab << output->dtype() << "(" << genInline(init) << "));\n" ; |
1206 | } |
1207 | |
1208 | void genBlockReduction( |
1209 | const kir::TensorIndex* output, |
1210 | const kir::TensorIndex* input, |
1211 | const Val* init, |
1212 | BinaryOpType reduction_op_type, |
1213 | kir::Predicate* read_pred, |
1214 | kir::Predicate* write_pred) { |
1215 | const auto par_domains = ir_utils::getParallelDomains(output); |
1216 | // Get parallel reduction domains |
1217 | const bool tidx = |
1218 | par_domains.find(ParallelType::TIDx) != par_domains.end() && |
1219 | par_domains.at(ParallelType::TIDx)->isReduction(); |
1220 | const bool tidy = |
1221 | par_domains.find(ParallelType::TIDy) != par_domains.end() && |
1222 | par_domains.at(ParallelType::TIDy)->isReduction(); |
1223 | const bool tidz = |
1224 | par_domains.find(ParallelType::TIDz) != par_domains.end() && |
1225 | par_domains.at(ParallelType::TIDz)->isReduction(); |
1226 | |
1227 | const auto data_type = output->dtype(); |
1228 | |
1229 | indent() << "blockReduce<" << (tidx ? "true" : "false" ) << ", " |
1230 | << (tidy ? "true" : "false" ) << ", " << (tidz ? "true" : "false" ) |
1231 | << ">(\n" ; |
1232 | indent() << kTab << gen(output) << ",\n" ; |
1233 | indent() << kTab << gen(input) << ",\n" ; |
1234 | indent() << kTab << genReductionOp(reduction_op_type, output->dtype()) |
1235 | << ",\n" ; |
1236 | indent() << kTab << "threadIdx,\n" ; |
1237 | indent() << kTab << "blockDim,\n" ; |
1238 | indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n" ; |
1239 | TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue()); |
1240 | indent() << kTab << genInline(read_pred) << ",\n" ; |
1241 | // Pass the write predicate if available and different from the |
1242 | // default predicate. The blockReduce runtime function uses the |
1243 | // default predicate for both read and write when only the |
1244 | // default one is given. |
1245 | if (write_pred != nullptr) { |
1246 | TORCH_INTERNAL_ASSERT(write_pred->hasValue()); |
1247 | indent() << kTab << genInline(write_pred) << ",\n" ; |
1248 | } |
1249 | indent() << kTab << data_type << "(" << genInline(init) << "));\n" ; |
1250 | } |
1251 | |
1252 | void handle(const ReductionOp* rop) final { |
1253 | TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>()); |
1254 | |
1255 | const auto output = rop->out()->as<kir::TensorIndex>(); |
1256 | const auto input = rop->in()->as<kir::TensorIndex>(); |
1257 | const auto domain = output->view()->domain(); |
1258 | const auto op_type = rop->getReductionOpType(); |
1259 | |
1260 | const bool has_block_reduce = domain->hasBlockReduction(); |
1261 | const bool has_grid_reduce = domain->hasGridReduction(); |
1262 | |
1263 | TORCH_INTERNAL_ASSERT( |
1264 | !has_grid_reduce, |
1265 | "ReductionOp does not support block parallelization. GridReductionOp must be used. " , |
1266 | rop->toString()); |
1267 | |
1268 | if (!has_block_reduce) { |
1269 | genSerialReduction(output, input, op_type); |
1270 | } else if ( |
1271 | auto reduction_id = ir_utils::getMaybeWarpReductionDim(output, input)) { |
1272 | genWarpReduction(output, input, rop->init(), op_type, rop->predicate()); |
1273 | } else { |
1274 | genBlockReduction( |
1275 | output, |
1276 | input, |
1277 | rop->init(), |
1278 | op_type, |
1279 | rop->predicate(), |
1280 | rop->writePredicate()); |
1281 | } |
1282 | } |
1283 | |
1284 | void handle(const LoadStoreOp* ldst) final { |
// TODO:
// Gradually merge this code path with UnaryOp::Set for vectorization.
// There is quite a bit of possible cleanup.
1289 | bool vectorize_op = false; |
1290 | size_t vector_word_size = 1; |
1291 | auto ti = ldst->out()->as<kir::TensorIndex>(); |
1292 | |
1293 | // Check vectorization and set vector word size |
1294 | for (auto id : ti->view()->domain()->domain()) { |
1295 | if (!isParallelTypeVectorize(id->getParallelType())) { |
1296 | continue; |
1297 | } |
1298 | |
1299 | ExpressionEvaluator expr_eval(id->fusion()); |
1300 | auto vector_size_optional = expr_eval.evaluate(id->extent()); |
1301 | |
1302 | TORCH_INTERNAL_ASSERT( |
1303 | vector_size_optional.has_value(), |
1304 | "Could not evaluate constant value bound to vectorized dim." ); |
1305 | |
1306 | TORCH_INTERNAL_ASSERT( |
1307 | id->getParallelType() != ParallelType::MisalignedVectorize, |
1308 | "LoadStoreOp: no support yet for mis-aligned vectorization" ); |
1309 | vector_word_size = vector_size_optional->as<int64_t>(); |
1310 | vectorize_op = true; |
1311 | break; |
1312 | } |
1313 | |
1314 | // Dispatch instruction generation: |
1315 | switch (ldst->opType()) { |
1316 | case LoadStoreOpType::LdMatrix: |
1317 | case LoadStoreOpType::LdMatrixTranspose: |
1318 | TORCH_INTERNAL_ASSERT( |
1319 | vectorize_op, "LdMatrix: Vectorization required: " , ldst); |
1320 | genLdMatrix(ldst, vector_word_size); |
1321 | break; |
1322 | case LoadStoreOpType::CpAsync: |
1323 | genCpAsync(ldst, vector_word_size); |
1324 | break; |
1325 | default: |
1326 | TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type" ); |
1327 | } |
1328 | } |
1329 | |
1330 | void handle(const WelfordOp* wop) final { |
1331 | TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>()); |
1332 | |
1333 | const auto out = wop->out()->as<kir::TensorIndex>(); |
1334 | const auto domain = out->view()->domain(); |
1335 | |
1336 | const auto out_var = wop->outVar(); |
1337 | const auto out_avg = wop->outAvg(); |
1338 | const auto out_N = wop->outN(); |
1339 | |
1340 | const auto in_var = wop->inVar(); |
1341 | const auto in_avg = wop->inAvg(); |
1342 | const auto in_N = wop->inN(); |
1343 | |
1344 | // inVar was allowed to be nullptr. Make sure it isn't. |
1345 | TORCH_INTERNAL_ASSERT( |
1346 | in_var != nullptr, "Welford var input nullptr not allowed" ); |
1347 | |
1348 | const bool has_block_reduce = domain->hasBlockReduction(); |
1349 | const bool has_grid_reduce = domain->hasGridReduction(); |
1350 | |
1351 | // Serial WelfordOp generation |
1352 | if (!has_block_reduce && !has_grid_reduce) { |
1353 | indent() << "welfordCombine (" |
1354 | << "\n" ; |
1355 | indent() << kTab << gen(out_avg) << ",\n" ; |
1356 | indent() << kTab << gen(out_var) << ",\n" ; |
1357 | indent() << kTab << gen(out_N) << ",\n" ; |
1358 | indent() << kTab << gen(in_avg) << ",\n" ; |
1359 | indent() << kTab << "(" << out_avg->dtype() << ")" << gen(in_var) |
1360 | << ",\n" ; |
1361 | indent() << kTab << "(" << out_N->dtype() << ")" << gen(in_N) << ");\n" ; |
1362 | return; |
1363 | } |
1364 | |
1365 | const auto par_domains = ir_utils::getParallelDomains(wop->out()); |
1366 | // Get parallel reduction domains |
1367 | const bool tidx = |
1368 | par_domains.find(ParallelType::TIDx) != par_domains.end() && |
1369 | par_domains.at(ParallelType::TIDx)->isReduction(); |
1370 | const bool tidy = |
1371 | par_domains.find(ParallelType::TIDy) != par_domains.end() && |
1372 | par_domains.at(ParallelType::TIDy)->isReduction(); |
1373 | const bool tidz = |
1374 | par_domains.find(ParallelType::TIDz) != par_domains.end() && |
1375 | par_domains.at(ParallelType::TIDz)->isReduction(); |
1376 | |
1377 | const auto data_type = wop->out()->dtype(); |
1378 | |
1379 | if (has_block_reduce) { |
1380 | if (has_grid_reduce) { |
1381 | // allocate block result |
1382 | indent() << data_type << " " |
1383 | << "block_result_avg_" << block_reduce_name_ << " = " |
1384 | << gen(wop->initAvg()) << ";\n" ; |
1385 | indent() << data_type << " " |
1386 | << "block_result_var_" << block_reduce_name_ << " = " |
1387 | << gen(wop->initVar()) << ";\n" ; |
1388 | indent() << out_N->dtype() << " " |
1389 | << "block_result_n_" << block_reduce_name_ << " = " |
1390 | << gen(wop->initN()) << ";\n" ; |
1391 | } |
1392 | indent() << "blockWelford<" << (tidx ? "true" : "false" ) << ", " |
1393 | << (tidy ? "true" : "false" ) << ", " << (tidz ? "true" : "false" ) |
1394 | << ">(\n" ; |
1395 | if (has_grid_reduce) { |
1396 | indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n" ; |
1397 | indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n" ; |
1398 | indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n" ; |
1399 | } else { |
1400 | indent() << kTab << gen(wop->outAvg()) << ",\n" ; |
1401 | indent() << kTab << gen(wop->outVar()) << ",\n" ; |
1402 | indent() << kTab << gen(wop->outN()) << ",\n" ; |
1403 | } |
1404 | indent() << kTab << gen(in_avg) << ",\n" ; |
1405 | indent() << kTab << out_avg->dtype() << "(" << gen(in_var) << "),\n" ; |
1406 | indent() << kTab << out_N->dtype() << "(" << gen(in_N) << "),\n" ; |
1407 | indent() << kTab << "threadIdx,\n" ; |
1408 | indent() << kTab << "blockDim,\n" ; |
1409 | indent() << kTab << "reinterpret_cast<" << data_type |
1410 | << "*>(shared_mem_avg),\n" ; |
1411 | indent() << kTab << "reinterpret_cast<" << data_type |
1412 | << "*>(shared_mem_var),\n" ; |
1413 | indent() << kTab << "reinterpret_cast<" << out_N->dtype() |
1414 | << "*>(shared_mem_n),\n" ; |
TORCH_INTERNAL_ASSERT(
wop->predicate() != nullptr && wop->predicate()->hasValue());
1418 | auto read_pred = genInline(wop->predicate()); |
1419 | indent() << kTab << read_pred << ",\n" ; |
1420 | if (wop->writePredicate() != nullptr) { |
1421 | TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue()); |
1422 | auto write_pred = genInline(wop->writePredicate()); |
1423 | indent() << kTab << write_pred << ",\n" ; |
1424 | } |
1425 | indent() << kTab << data_type << "(0));\n" ; |
1426 | } |
1427 | } |
1428 | |
1429 | // Support ReductionOp and WelfordOp |
1430 | template <typename REDUCTION_OP> |
1431 | std::string generateGridReduceTemplateFlags( |
1432 | const REDUCTION_OP* rop, |
1433 | const ParallelTypeBitmap& thread_pred) { |
1434 | TORCH_INTERNAL_ASSERT( |
1435 | !rop->isAllreduce(), |
1436 | "This is not for the allreduce reduction kernel\n" ); |
1437 | |
1438 | const auto par_domains = ir_utils::getParallelDomains(rop->outputs()[0]); |
1439 | ArgumentBuilder flags; |
1440 | for (const ParallelType pt : kParallelTypeThreads) { |
1441 | const bool parallel_reduction = |
1442 | par_domains.find(pt) != par_domains.end() && |
1443 | par_domains.at(pt)->isReduction(); |
1444 | const bool pred = thread_pred.get(pt); |
1445 | TORCH_INTERNAL_ASSERT( |
1446 | !(parallel_reduction && pred), "Cannot reduce predicated axis: " , pt); |
1447 | bool flag = false; |
1448 | // Currently assumed that no dimensions parallelized with blocks |
1449 | // are predicated. This assumption may be lifted, but |
1450 | // gridReduction would need some changes. |
1451 | if (isParallelTypeBlockDim(pt)) { |
1452 | TORCH_INTERNAL_ASSERT( |
1453 | !pred, "Predication on block dimensions not allowed: " , pt); |
1454 | flag = parallel_reduction; |
1455 | } else { |
1456 | flag = !pred && !parallel_reduction; |
1457 | } |
1458 | flags.arg(flag); |
1459 | } |
1460 | return flags.str(); |
1461 | } |
1462 | |
1463 | // TODO: This should replace generateGridReduceTemplateFlags once |
1464 | // GridWelford is refactored as GridReduction. |
1465 | template <typename REDUCTION_OP> |
1466 | std::string generateGridReduceTemplateFlags2( |
1467 | const REDUCTION_OP* rop, |
1468 | const ParallelTypeBitmap& thread_pred) { |
1469 | TORCH_INTERNAL_ASSERT( |
1470 | !rop->isAllreduce(), |
1471 | "This is not for the allreduce reduction kernel\n" ); |
1472 | |
1473 | const auto par_domains = |
1474 | ir_utils::getParallelDomains(ir_utils::getTvOutput(rop)); |
1475 | ArgumentBuilder flags; |
1476 | for (const ParallelType pt : kParallelTypeThreads) { |
1477 | const bool parallel_reduction = |
1478 | par_domains.find(pt) != par_domains.end() && |
1479 | par_domains.at(pt)->isReduction(); |
1480 | const bool pred = thread_pred.get(pt); |
1481 | TORCH_INTERNAL_ASSERT( |
1482 | !(parallel_reduction && pred), "Cannot reduce predicated axis: " , pt); |
1483 | // Currently assumed that no dimensions parallelized with blocks |
1484 | // are predicated. This assumption may be lifted, but |
1485 | // gridReduction would need some changes. |
1486 | if (isParallelTypeBlockDim(pt)) { |
1487 | TORCH_INTERNAL_ASSERT( |
1488 | !pred, "Predication on block dimensions not allowed: " , pt); |
1489 | } |
1490 | flags.arg(parallel_reduction); |
1491 | } |
1492 | return flags.str(); |
1493 | } |
1494 | |
1495 | void addProfileArguments(ArgumentBuilder& func_args, const Expr* expr) { |
1496 | if (isOptionEnabled(EnableOption::KernelProfile) && |
1497 | kernel_->profile().isProfiled(expr)) { |
1498 | const auto& buffer_indices = |
1499 | kernel_->profile().getIndicesInProfileBuffer(expr); |
1500 | auto buffer = kernel_->profile().getBuffer(); |
1501 | TORCH_INTERNAL_ASSERT(buffer != nullptr); |
1502 | for (const auto& index : buffer_indices) { |
1503 | func_args.arg(varName(buffer)).append("[" ).append(index).append("]" ); |
1504 | } |
1505 | } |
1506 | } |
1507 | |
1508 | void handle(const kir::GridReduction* grop) final { |
1509 | TORCH_INTERNAL_ASSERT(grop->out()->isA<kir::TensorIndex>()); |
1510 | |
1511 | const auto out = grop->out()->as<kir::TensorIndex>(); |
1512 | const auto domain = out->view()->domain(); |
1513 | TORCH_INTERNAL_ASSERT(domain->hasGridReduction()); |
1514 | |
1515 | const auto data_type = grop->out()->dtype(); |
1516 | const auto op_type = grop->getReductionOpType(); |
1517 | |
1518 | TORCH_INTERNAL_ASSERT( |
1519 | grop->reduction_buffer()->buffer()->isA<TensorView>()); |
1520 | TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>()); |
1521 | const auto work_buffer = |
1522 | grop->reduction_buffer()->buffer()->as<TensorView>(); |
1523 | const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>(); |
1524 | |
1525 | if (grop->isAllreduce()) { |
1526 | generateGridAllreduce(grop); |
1527 | return; |
1528 | } |
1529 | |
1530 | const std::string flags_str = |
1531 | generateGridReduceTemplateFlags2(grop, grop->threadPredicate()); |
1532 | |
1533 | const bool persistent_sync = |
1534 | kernel_->summary().has_cooperative_grid_reduction; |
1535 | |
1536 | // Since block-level reduction is already done, those dimensions |
1537 | // with tidx/y/z being true do not participate in the grid |
1538 | // reduction. |
1539 | ArgumentBuilder template_args; |
1540 | template_args.arg(flags_str).arg(persistent_sync); |
1541 | |
1542 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
1543 | func_args.arg(gen(grop->out())); |
1544 | func_args.arg(gen(grop->in())); |
1545 | func_args.arg(genReductionOp(op_type, out->dtype())); |
1546 | func_args.arg("&" ).append(varName(work_buffer)).append("[0]" ); |
1547 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
1548 | func_args.arg(genCall("static_cast" , ptrType(data_type), "shared_mem" )); |
1549 | // read and write predicates |
1550 | TORCH_INTERNAL_ASSERT( |
1551 | grop->predicate() != nullptr && grop->predicate()->hasValue()); |
1552 | const auto read_pred = genInline(grop->predicate()); |
1553 | func_args.arg(read_pred); |
1554 | if (grop->writePredicate() != nullptr) { |
1555 | TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue()); |
1556 | func_args.arg(genInline(grop->writePredicate())); |
1557 | } else { |
1558 | func_args.arg(read_pred); |
1559 | } |
1560 | // Init val |
1561 | func_args.arg(genCall(data_type, genInline(grop->init()))); |
1562 | func_args.arg(genInline(grop->entrance_index())); |
1563 | func_args.arg(genInline(grop->entrances())); |
1564 | |
1565 | addProfileArguments(func_args, grop); |
1566 | |
1567 | indent() << "reduction::gridReduce<" << template_args << ">(\n" ; |
1568 | indent() << kTab << func_args << ");\n" ; |
1569 | } |
1570 | |
1571 | std::string genFusedReductionName(const TensorView* reduction_out) { |
1572 | return varName(reduction_out) + "_reduction" ; |
1573 | } |
1574 | |
1575 | void generateGridAllreduce(const kir::GridReduction* grop) { |
1576 | TORCH_INTERNAL_ASSERT(grop->isAllreduce()); |
1577 | |
1578 | const auto out = grop->out()->as<kir::TensorIndex>(); |
1579 | |
1580 | const auto data_type = grop->out()->dtype(); |
1581 | const auto op_type = grop->getReductionOpType(); |
1582 | |
1583 | const auto work_buffer = |
1584 | grop->reduction_buffer()->buffer()->as<TensorView>(); |
1585 | const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>(); |
1586 | |
1587 | const auto reduction_name = genFusedReductionName(out->view()); |
1588 | |
1589 | // template <typename Func, typename... Types> |
1590 | // __device__ __inline__ void reduce( |
1591 | // RefTuple<Types...> out, |
1592 | // const LocalTuple<Types...>& inp, |
1593 | // VolatilePtrTuple<Types...> global_work_buffer, |
1594 | // int64_t* global_sync_buffer, // Allocated as product of all |
1595 | // // non-participating Grid dimension |
1596 | // PtrTuple<Types...> shared_buf, |
1597 | // bool read_pred, // Prevent reading from out of bounds memory |
1598 | // bool write_pred, // Prevent from writing out of bounds |
1599 | // const LocalTuple<Types...>& init_val, |
1600 | // Func reduction_op); |
1601 | |
1602 | indent() << reduction_name << ".reduce(\n" ; |
1603 | |
1604 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
1605 | // out |
1606 | func_args.arg(genCall("RefTuple" , data_type, gen(grop->out()))); |
1607 | // inp |
1608 | func_args.arg(genCall("ConstRefTuple" , data_type, gen(grop->in()))); |
1609 | // global_work_buffer |
1610 | func_args.arg(genCall( |
1611 | "VolatilePtrTuple" , data_type, "&" + varName(work_buffer) + "[0]" )); |
1612 | // global_sync_buffer |
1613 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
1614 | // shared_buf |
1615 | func_args.arg(genCall( |
1616 | "PtrTuple" , |
1617 | data_type, |
1618 | genCall("static_cast" , ptrType(data_type), "shared_mem" ))); |
1619 | // read and write predicates |
1620 | TORCH_INTERNAL_ASSERT( |
1621 | grop->predicate() != nullptr && grop->predicate()->hasValue()); |
1622 | const auto read_pred = genInline(grop->predicate()); |
1623 | auto write_pred = read_pred; |
1624 | if (grop->writePredicate() != nullptr) { |
1625 | TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue()); |
1626 | write_pred = genInline(grop->writePredicate()); |
1627 | } |
1628 | func_args.arg(read_pred).arg(write_pred); |
1629 | // init_val |
1630 | func_args.arg(genCall("LocalTuple" , data_type, genInline(grop->init()))); |
1631 | // reduction_op |
1632 | func_args.arg(genReductionOp(op_type, out->dtype())); |
1633 | |
1634 | addProfileArguments(func_args, grop); |
1635 | |
1636 | indent() << kTab << func_args << ");\n" ; |
1637 | } |
1638 | |
1639 | void handle(const kir::GroupedGridReduction* grouped_grop) final { |
1640 | const auto out = ir_utils::getTvOutput(grouped_grop); |
1641 | const auto domain = out->domain(); |
1642 | TORCH_INTERNAL_ASSERT(domain->hasGridReduction()); |
1643 | |
1644 | TORCH_INTERNAL_ASSERT( |
1645 | grouped_grop->sync_buffer()->buffer()->isA<TensorView>()); |
1646 | const auto sync_buffer = |
1647 | grouped_grop->sync_buffer()->buffer()->as<TensorView>(); |
1648 | |
1649 | if (grouped_grop->isAllreduce()) { |
1650 | generateGroupedGridAllreduce(grouped_grop); |
1651 | return; |
1652 | } |
1653 | |
1654 | TORCH_INTERNAL_ASSERT( |
1655 | grouped_grop->numExprs() == 2, |
1656 | "Only grouping of 2 reductions is supported. " , |
1657 | grouped_grop->toString()); |
1658 | |
1659 | const std::string flags_str = generateGridReduceTemplateFlags2( |
1660 | grouped_grop, grouped_grop->threadPredicate()); |
1661 | |
1662 | const bool persistent_sync = |
1663 | kernel_->summary().has_cooperative_grid_reduction; |
1664 | |
1665 | // Since block-level reduction is already done, those dimensions |
1666 | // with tidx/y/z being true do not participate in the grid |
1667 | // reduction. |
1668 | ArgumentBuilder template_args; |
1669 | template_args.arg(flags_str).arg(persistent_sync); |
1670 | |
1671 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
1672 | |
1673 | // Append arguments for each reduction |
1674 | for (const auto i : c10::irange(grouped_grop->numExprs())) { |
1675 | TORCH_INTERNAL_ASSERT( |
1676 | grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>()); |
1677 | const auto work_buffer = |
1678 | grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>(); |
1679 | |
1680 | func_args.arg(gen(grouped_grop->output(i))); |
1681 | func_args.arg(gen(grouped_grop->input(i))); |
1682 | func_args.arg(genCall( |
1683 | grouped_grop->output(i)->dtype(), |
1684 | genInline(grouped_grop->initVal(i)))); |
1685 | func_args.arg(genReductionOp( |
1686 | grouped_grop->getReductionOpType(i), |
1687 | grouped_grop->output(i)->dtype())); |
1688 | func_args.arg("&" ).append(varName(work_buffer)).append("[0]" ); |
1689 | } |
1690 | |
1691 | // The rest of the arguments are common between the reductions |
1692 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
1693 | func_args.arg("shared_mem" ); |
1694 | // read and write predicates |
1695 | TORCH_INTERNAL_ASSERT( |
1696 | grouped_grop->predicate() != nullptr && |
1697 | grouped_grop->predicate()->hasValue()); |
1698 | const auto read_pred = genInline(grouped_grop->predicate()); |
1699 | func_args.arg(read_pred); |
1700 | if (grouped_grop->writePredicate() != nullptr) { |
1701 | TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue()); |
1702 | func_args.arg(genInline(grouped_grop->writePredicate())); |
1703 | } else { |
1704 | func_args.arg(read_pred); |
1705 | } |
1706 | |
1707 | func_args.arg(genInline(grouped_grop->entrance_index())); |
1708 | func_args.arg(genInline(grouped_grop->entrances())); |
1709 | |
1710 | addProfileArguments(func_args, grouped_grop); |
1711 | |
1712 | indent() << "reduction::gridReduceGroup<" << template_args << ">(\n" ; |
1713 | indent() << kTab << func_args << ");\n" ; |
1714 | } |
1715 | |
1716 | void handle(const kir::GroupedGridWelford* grouped_gwop) final { |
1717 | if (grouped_gwop->isAllreduce()) { |
1718 | generateGroupedGridAllreduceWelford(grouped_gwop); |
1719 | return; |
1720 | } else { |
1721 | TORCH_INTERNAL_ASSERT( |
1722 | false, "Non-allreduce grouped grid welford is not yet supported" ); |
1723 | } |
1724 | } |
1725 | |
1726 | // Enumerates all combinations of index values of grouped |
1727 | // loops. Each combination is a vector of loop index values. The |
1728 | // length of the vector is the number of grouped loops. |
1729 | // |
1730 | // Example 1: only one domain of extent 2 is grouped: {{0}, {1}}. |
1731 | // Example 2: two domains of extents 2 and 3 are grouped: {{0, 0}, |
1732 | // {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}} |
1733 | std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() { |
    std::vector<std::vector<int64_t>> index_combinations;
1735 | |
1736 | // Initialize with an empty vector |
    index_combinations.push_back(std::vector<int64_t>());
1738 | |
1739 | // Incrementally build a combinatorial set |
1740 | for (const auto loop : grouped_loops_) { |
1741 | const auto iter_count = loop->stop()->evaluateInt(); |
1742 | std::vector<std::vector<int64_t>> new_combinations; |
1743 | // Append integers from 0 to iter_count to all the vectors built |
1744 | // so far |
      for (const auto& index_vec : index_combinations) {
1746 | for (int64_t i = 0; i < iter_count; ++i) { |
1747 | auto index_vec_appended = index_vec; |
1748 | index_vec_appended.push_back(i); |
1749 | new_combinations.push_back(index_vec_appended); |
1750 | } |
1751 | } |
      index_combinations = std::move(new_combinations);
1753 | } |
1754 | |
    return index_combinations;
1756 | } |
1757 | |
1758 | //! Returns all combinations of maps from index Vals of grouped loops to their |
  //! concrete integers.
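  //! For example (illustrative): with a single grouped loop of index i0
  //! and extent 2, this returns {{i0 -> 0}, {i0 -> 1}}.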
1760 | std::vector<std::unordered_map<const Int*, int64_t>> |
1761 | getLoopIndexReplacementMaps() { |
1762 | std::vector<std::unordered_map<const Int*, int64_t>> maps; |
1763 | |
1764 | if (grouped_loops_.empty()) { |
1765 | std::unordered_map<const Int*, int64_t> empty_map; |
1766 | return {empty_map}; |
1767 | } |
1768 | |
1769 | // Vector of indices of grouped loops |
1770 | std::vector<Int*> loop_indices; |
1771 | std::transform( |
1772 | grouped_loops_.begin(), |
1773 | grouped_loops_.end(), |
1774 | std::back_inserter(loop_indices), |
1775 | [](const kir::ForLoop* loop) { return loop->index()->as<Int>(); }); |
1776 | |
1777 | // All combinations of loop index integer values |
1778 | const auto index_val_sets = getGroupedLoopIndexConcreteIntSets(); |
1779 | |
1780 | // Create maps from loop index Vals to integers |
1781 | for (const auto& index_values : index_val_sets) { |
1782 | TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size()); |
1783 | std::unordered_map<const Int*, int64_t> index_val_map; |
1784 | for (const auto i : c10::irange(loop_indices.size())) { |
1785 | auto loop_index = loop_indices.at(i); |
1786 | auto index_val = index_values.at(i); |
1787 | index_val_map.emplace(loop_index, index_val); |
1788 | } |
1789 | maps.emplace_back(std::move(index_val_map)); |
1790 | } |
1791 | |
1792 | return maps; |
1793 | } |
1794 | |
1795 | void generateGroupedGridAllreduce( |
1796 | const kir::GroupedGridReduction* grouped_grop) { |
1797 | TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce()); |
1798 | |
1799 | // There are two dimensions of grouping: horizontal grouping and |
1800 | // iteration grouping. The total number of individual reductions |
1801 | // is the number of horizontal reductions * the extent of grouped |
1802 | // iterations. All of them are packed into a single grid reduction |
1803 | // call. The number of reductions is limited, and currently it is |
1804 | // simply an error if exceeded. This could be avoided by |
1805 | // decomposing grouped_grop into smaller groups within the |
1806 | // limit. TODO: Support a larger number of reductions. |
1807 | |
1808 | // First, enumerate all combinations of loop index values of |
1809 | // grouped IterDomains. If only a single domain is grouped, this |
    // is simply a 1D vector of integers from 0 to extent-1. If
1811 | // two domains are grouped, combinations of two integer vectors |
1812 | // are returned. These loop index value vectors are returned as a |
1813 | // map from loop index Vals to concrete int values. |
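    // For example (hypothetical sizes): with 2 horizontally grouped
    // reductions and one grouped iteration domain of extent 3,
    // index_replacement_maps holds 3 maps and 2 * 3 = 6 individual
    // reductions are packed into the single reduceGroup call below.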
1814 | const auto index_replacement_maps = getLoopIndexReplacementMaps(); |
1815 | const auto num_grouped_iterations = index_replacement_maps.size(); |
1816 | |
    // This is also checked at lowering validation time, so it
1818 | // isn't strictly necessary. |
1819 | TORCH_INTERNAL_ASSERT( |
1820 | num_grouped_iterations * grouped_grop->numExprs() <= |
1821 | kMaxNumGroupedReductions, |
1822 | "Too many grouped reductions: " , |
1823 | grouped_grop->toString(), |
1824 | ". Up to " , |
1825 | kMaxNumGroupedReductions, |
1826 | " reductions are allowed." ); |
1827 | |
1828 | ArgumentBuilder types; |
1829 | ArgumentBuilder outputs; |
1830 | ArgumentBuilder inputs; |
1831 | ArgumentBuilder work_bufs; |
1832 | ArgumentBuilder init_vals; |
1833 | ArgumentBuilder reduction_ops; |
1834 | |
1835 | ArgumentBuilder bool_types; |
1836 | ArgumentBuilder read_preds; |
1837 | ArgumentBuilder write_preds; |
1838 | |
1839 | for (const auto expr_index : c10::irange(grouped_grop->numExprs())) { |
1840 | const auto data_type = grouped_grop->outputs().at(expr_index)->dtype(); |
1841 | TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers() |
1842 | .at(expr_index) |
1843 | ->buffer() |
1844 | ->isA<TensorView>()); |
1845 | |
1846 | for (const auto& group_index : |
1847 | c10::irange(index_replacement_maps.size())) { |
1848 | // Set the index replacement map with the concrete values of |
1849 | // indices of grouped loops. |
1850 | index_replacement_map_ = index_replacement_maps.at(group_index); |
1851 | |
1852 | types.arg(data_type); |
1853 | |
1854 | // out |
1855 | outputs.arg(gen(grouped_grop->outputs().at(expr_index))); |
1856 | |
1857 | // inp |
1858 | inputs.arg(gen(grouped_grop->inputs().at(expr_index))); |
1859 | |
1860 | // global_work_buffer |
1861 | const auto work_buffer = grouped_grop->reduction_buffers() |
1862 | .at(expr_index) |
1863 | ->buffer() |
1864 | ->as<TensorView>(); |
        // A separate work buffer is used for each reduction.
1866 | auto work_buffer_offset = group_index == 0 |
1867 | ? "0" |
1868 | : (genInline(grouped_grop->buffer_stride()) + " * " + |
1869 | std::to_string(group_index)); |
1870 | work_bufs.arg("&" ) |
1871 | .append(varName(work_buffer)) |
1872 | .append("[" ) |
1873 | .append(work_buffer_offset) |
1874 | .append("]" ); |
1875 | init_vals.arg(genInline(grouped_grop->initVal(expr_index))); |
1876 | |
1877 | reduction_ops.arg(genReductionOp( |
1878 | grouped_grop->getReductionOpType(expr_index), |
1879 | grouped_grop->output(expr_index)->dtype())); |
1880 | |
1881 | // read and write predicates |
1882 | bool_types.arg("bool" ); |
1883 | // Same argument for all inputs. Different predicates would be |
1884 | // used when grouping is done across iterations |
1885 | TORCH_INTERNAL_ASSERT( |
1886 | grouped_grop->predicate() != nullptr && |
1887 | grouped_grop->predicate()->hasValue()); |
1888 | const auto read_pred = genInline(grouped_grop->predicate()); |
1889 | read_preds.arg(read_pred); |
1890 | if (grouped_grop->writePredicate() != nullptr) { |
1891 | TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue()); |
1892 | write_preds.arg(genInline(grouped_grop->writePredicate())); |
1893 | } else { |
1894 | write_preds.arg(read_pred); |
1895 | } |
1896 | |
1897 | index_replacement_map_.clear(); |
1898 | } |
1899 | } |
1900 | |
1901 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
1902 | func_args.arg(genCall("RefTuple" , types, outputs)); |
1903 | func_args.arg(genCall("ConstRefTuple" , types, inputs)); |
1904 | func_args.arg(genCall("VolatilePtrTuple" , types, work_bufs)); |
1905 | func_args.arg(genCall("LocalTuple" , types, init_vals)); |
1906 | |
1907 | // global_sync_buffer |
1908 | const auto sync_buffer = |
1909 | grouped_grop->sync_buffer()->buffer()->as<TensorView>(); |
1910 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
1911 | |
1912 | // shared_buf |
1913 | func_args.arg("shared_mem" ); |
1914 | |
1915 | func_args.arg(genCall("LocalTuple" , bool_types, read_preds)); |
1916 | func_args.arg(genCall("LocalTuple" , bool_types, write_preds)); |
1917 | |
1918 | addProfileArguments(func_args, grouped_grop); |
1919 | |
1920 | func_args.arg(reduction_ops); |
1921 | |
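    // A rough sketch of the emitted call, with hypothetical names, for a
    // single float reduction grouped over two iterations:
    //
    //   T5_reduction.reduceGroup(
    //     RefTuple<float, float>(T5[0], T5[1]),
    //     ConstRefTuple<float, float>(T4[0], T4[1]),
    //     VolatilePtrTuple<float, float>(&T_work[0], &T_work[stride * 1]),
    //     LocalTuple<float, float>(float(0), float(0)),
    //     &T_sync[0], shared_mem,
    //     LocalTuple<bool, bool>(read_pred, read_pred),
    //     LocalTuple<bool, bool>(write_pred, write_pred),
    //     <reduction functors>);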
1922 | indent() << genFusedReductionName(ir_utils::getTvOutput(grouped_grop)) |
1923 | << ".reduceGroup(\n" ; |
1924 | indent() << kTab << func_args << ");\n" ; |
1925 | } |
1926 | |
  // Mostly the same as the grouped grid reduction version
1928 | void generateGroupedGridAllreduceWelford( |
1929 | const kir::GroupedGridWelford* grouped_gwop) { |
1930 | TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce()); |
1931 | |
1932 | const auto index_replacement_maps = getLoopIndexReplacementMaps(); |
1933 | const auto num_grouped_iterations = index_replacement_maps.size(); |
1934 | |
    // This is also checked at lowering validation time, so it
1936 | // isn't strictly necessary. |
1937 | TORCH_INTERNAL_ASSERT( |
1938 | num_grouped_iterations * grouped_gwop->numExprs() <= |
1939 | kMaxNumGroupedReductions, |
1940 | "Too many grouped reductions: " , |
1941 | grouped_gwop->toString(), |
1942 | ". Up to " , |
1943 | kMaxNumGroupedReductions, |
1944 | " reductions are allowed." ); |
1945 | |
1946 | ArgumentBuilder data_types; |
1947 | ArgumentBuilder index_types; |
1948 | |
    // Note that the data type of avg and var, and that of N, are the
    // same across all the welford ops since we only support
    // grouping of iterations.
1952 | const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype(); |
1953 | const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype(); |
1954 | |
1955 | std::array<ArgumentBuilder, 3> out_args; |
1956 | std::array<ArgumentBuilder, 3> in_args; |
1957 | std::array<ArgumentBuilder, 3> init_args; |
1958 | std::array<ArgumentBuilder, 3> work_bufs; |
1959 | |
1960 | ArgumentBuilder bool_types; |
1961 | ArgumentBuilder read_preds; |
1962 | ArgumentBuilder write_preds; |
1963 | |
1964 | for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) { |
1965 | const auto& output = grouped_gwop->outputVals().at(expr_index); |
1966 | const auto& input = grouped_gwop->inputVals().at(expr_index); |
1967 | const auto& init = grouped_gwop->initVals().at(expr_index); |
1968 | |
1969 | for (const auto& group_index : |
1970 | c10::irange(index_replacement_maps.size())) { |
1971 | // Set the index replacement map with the concrete values of |
1972 | // indices of grouped loops. |
1973 | index_replacement_map_ = index_replacement_maps.at(group_index); |
1974 | |
1975 | data_types.arg(data_type); |
1976 | index_types.arg(index_type); |
1977 | |
1978 | auto work_buffer_offset = group_index == 0 |
1979 | ? "0" |
1980 | : (genInline(grouped_gwop->buffer_stride()) + " * " + |
1981 | std::to_string(group_index)); |
1982 | |
1983 | // Setup arguments for avg, var, and N |
1984 | for (const auto i : c10::irange(3)) { |
1985 | out_args[i].arg(gen(output.get(i))); |
1986 | in_args[i].arg(gen(input.get(i))); |
1987 | init_args[i].arg(gen(init.get(i))); |
1988 | const auto work_buffer = grouped_gwop->reduction_buffers()[i] |
1989 | .at(expr_index) |
1990 | ->buffer() |
1991 | ->as<TensorView>(); |
1992 | work_bufs[i] |
1993 | .arg("&" ) |
1994 | .append(varName(work_buffer)) |
1995 | .append("[" ) |
1996 | .append(work_buffer_offset) |
1997 | .append("]" ); |
1998 | } |
1999 | |
2000 | // read and write predicates |
2001 | bool_types.arg("bool" ); |
2002 | // Same argument for all inputs. Different predicates would be |
2003 | // used when grouping is done across iterations |
        TORCH_INTERNAL_ASSERT(
            grouped_gwop->predicate() != nullptr &&
            grouped_gwop->predicate()->hasValue());
2008 | const auto read_pred = genInline(grouped_gwop->predicate()); |
2009 | read_preds.arg(read_pred); |
2010 | if (grouped_gwop->writePredicate() != nullptr) { |
2011 | TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue()); |
2012 | write_preds.arg(genInline(grouped_gwop->writePredicate())); |
2013 | } else { |
2014 | write_preds.arg(read_pred); |
2015 | } |
2016 | |
2017 | index_replacement_map_.clear(); |
2018 | } |
2019 | } |
2020 | |
2021 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
2022 | // output |
2023 | func_args.arg(genCall("RefTuple" , data_types, out_args[0])); |
2024 | func_args.arg(genCall("RefTuple" , data_types, out_args[1])); |
2025 | func_args.arg(genCall("RefTuple" , index_types, out_args[2])); |
2026 | // input |
2027 | func_args.arg(genCall("ConstRefTuple" , data_types, in_args[0])); |
2028 | func_args.arg(genCall("ConstRefTuple" , data_types, in_args[1])); |
2029 | func_args.arg(genCall("ConstRefTuple" , index_types, in_args[2])); |
2030 | // init |
2031 | func_args.arg(genCall("LocalTuple" , data_types, init_args[0])); |
2032 | func_args.arg(genCall("LocalTuple" , data_types, init_args[1])); |
2033 | func_args.arg(genCall("LocalTuple" , index_types, init_args[2])); |
2034 | // work buffer |
2035 | func_args.arg(genCall("VolatilePtrTuple" , data_types, work_bufs[0])); |
2036 | func_args.arg(genCall("VolatilePtrTuple" , data_types, work_bufs[1])); |
2037 | func_args.arg(genCall("VolatilePtrTuple" , index_types, work_bufs[2])); |
2038 | // global_sync_buffer |
2039 | const auto sync_buffer = |
2040 | grouped_gwop->sync_buffer()->buffer()->as<TensorView>(); |
2041 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
2042 | |
2043 | // shared_buf |
2044 | ArgumentBuilder smem_buffer_args; |
2045 | smem_buffer_args.arg( |
2046 | genCall("reinterpret_cast" , ptrType(data_type), "shared_mem_avg" )); |
2047 | smem_buffer_args.arg( |
2048 | genCall("reinterpret_cast" , ptrType(data_type), "shared_mem_var" )); |
2049 | smem_buffer_args.arg( |
2050 | genCall("reinterpret_cast" , ptrType(index_type), "shared_mem_n" )); |
2051 | func_args.arg(genCall( |
2052 | "PtrTuple" , |
2053 | ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type), |
2054 | smem_buffer_args)); |
2055 | |
2056 | func_args.arg(genCall("LocalTuple" , bool_types, read_preds)); |
2057 | func_args.arg(genCall("LocalTuple" , bool_types, write_preds)); |
2058 | |
2059 | addProfileArguments(func_args, grouped_gwop); |
2060 | |
2061 | ArgumentBuilder func_template_args; |
2062 | func_template_args.arg( |
2063 | grouped_gwop->numExprs() * index_replacement_maps.size()); |
2064 | func_template_args.arg(data_type); |
2065 | func_template_args.arg(index_type); |
2066 | |
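    // A rough sketch of the emitted call, with hypothetical names, for 2
    // grouped welford expressions over 3 grouped iterations:
    //
    //   T8_reduction.welfordGroup<6, float, nvfuser_index_t>(
    //     <avg/var/N output RefTuples>, <avg/var/N input ConstRefTuples>,
    //     <avg/var/N init LocalTuples>, <avg/var/N work VolatilePtrTuples>,
    //     &T_sync[0], PtrTuple<float, float, nvfuser_index_t>(...),
    //     <read/write predicate LocalTuples>);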
2067 | indent() << genCall( |
2068 | genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) + |
2069 | ".welfordGroup" , |
2070 | func_template_args, |
2071 | func_args) |
2072 | << ";\n" ; |
2073 | } |
2074 | |
2075 | void handle(const kir::GridBroadcast* grop) final { |
2076 | const auto bop = grop->broadcast_op(); |
2077 | TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>()); |
2078 | |
2079 | const ParallelTypeBitmap parallel_types = |
2080 | kernel_->summary().broadcast_parallel_types.at(bop); |
2081 | |
2082 | TORCH_INTERNAL_ASSERT( |
2083 | parallel_types.hasBID(), |
2084 | "GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types" ); |
2085 | |
2086 | TORCH_INTERNAL_ASSERT( |
2087 | grop->broadcast_buffer()->buffer()->isA<TensorView>()); |
2088 | TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>()); |
2089 | const auto work_buffer = |
2090 | grop->broadcast_buffer()->buffer()->as<TensorView>(); |
2091 | const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>(); |
2092 | |
2093 | std::stringstream flags_str; |
2094 | for (const ParallelType pt : kParallelTypeThreads) { |
2095 | const bool parallel_bcast = parallel_types.get(pt); |
2096 | if (pt != kParallelTypeThreads[0]) { |
2097 | flags_str << ", " ; |
2098 | } |
2099 | flags_str << (parallel_bcast ? "true" : "false" ); |
2100 | } |
2101 | |
    // Since block-level broadcast has not necessarily been performed before
    // this function call, grid broadcast may be broadcasting across both
    // the grid and the block level.
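    // A rough sketch of the emitted call, with hypothetical names, for a
    // broadcast parallelized on a single block dimension:
    //
    //   grid_broadcast::broadcast<true, false, false, false, false, false>(
    //     T6[i0], T5[0], &T_work[0], T_sync, read_pred);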
2105 | indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n" ; |
2106 | indent() << kTab << gen(bop->out()) << ",\n" ; |
2107 | indent() << kTab << gen(bop->in()) << ",\n" ; |
2108 | indent() << kTab << "&" << varName(work_buffer) << "[0],\n" ; |
2109 | indent() << kTab << varName(sync_buffer) << ",\n" ; |
2110 | TORCH_INTERNAL_ASSERT( |
2111 | grop->predicate() != nullptr && grop->predicate()->hasValue()); |
2112 | indent() << kTab << genInline(grop->predicate()) << ");\n" ; |
2113 | } |
2114 | |
2115 | void handle(const kir::GridWelford* gwop) final { |
2116 | const auto wop = gwop->welford_op(); |
2117 | TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>()); |
2118 | |
2119 | const auto out = wop->out()->as<kir::TensorIndex>(); |
2120 | const auto domain = out->view()->domain(); |
2121 | TORCH_INTERNAL_ASSERT(domain->hasGridReduction()); |
2122 | |
2123 | const auto data_type = out->dtype(); |
2124 | |
2125 | TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA<TensorView>()); |
2126 | TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA<TensorView>()); |
2127 | |
2128 | const auto avg_buffer = gwop->avg_buffer()->buffer()->as<TensorView>(); |
2129 | const auto var_buffer = gwop->var_buffer()->buffer()->as<TensorView>(); |
2130 | const auto n_buffer = gwop->N_buffer()->buffer()->as<TensorView>(); |
2131 | const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>(); |
2132 | |
2133 | if (wop->isAllreduce()) { |
2134 | generateGridAllreduce(gwop); |
2135 | return; |
2136 | } |
2137 | |
2138 | const bool persistent_sync = |
2139 | kernel_->summary().has_cooperative_grid_reduction; |
2140 | |
2141 | const std::string flags_str = |
2142 | generateGridReduceTemplateFlags(wop, gwop->threadPredicate()); |
2143 | |
2144 | // Since block-level reduction is already done, those dimensions |
2145 | // with tidx/y/z being true do not participate in the grid reduction. |
2146 | indent() << "welford::gridWelford<" << flags_str << ", " |
2147 | << (persistent_sync ? "true" : "false" ) << ">(\n" ; |
2148 | indent() << kTab << gen(wop->outAvg()) << ",\n" ; |
2149 | indent() << kTab << gen(wop->outVar()) << ",\n" ; |
2150 | indent() << kTab << gen(wop->outN()) << ",\n" ; |
2151 | if (domain->hasBlockReduction()) { |
2152 | indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n" ; |
2153 | indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n" ; |
2154 | indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n" ; |
2155 | block_reduce_name_++; |
2156 | } else { |
2157 | indent() << kTab << gen(wop->inAvg()) << ",\n" ; |
2158 | TORCH_INTERNAL_ASSERT( |
2159 | wop->inVar() != nullptr, "Welford var input nullptr not allowed" ); |
2160 | indent() << kTab << "(" << wop->outVar()->dtype() << ")" |
2161 | << gen(wop->inVar()) << ",\n" ; |
2162 | indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN()) |
2163 | << ",\n" ; |
2164 | } |
2165 | indent() << kTab << "&" << varName(avg_buffer) << "[0],\n" ; |
2166 | indent() << kTab << "&" << varName(var_buffer) << "[0],\n" ; |
2167 | indent() << kTab << "&" << varName(n_buffer) << "[0],\n" ; |
2168 | indent() << kTab << varName(sync_buffer) << ",\n" ; |
2169 | indent() << kTab << "reinterpret_cast<" << data_type |
2170 | << "*>(shared_mem_avg),\n" ; |
2171 | indent() << kTab << "reinterpret_cast<" << data_type |
2172 | << "*>(shared_mem_var),\n" ; |
2173 | indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype() |
2174 | << "*>(shared_mem_n),\n" ; |
2175 | TORCH_INTERNAL_ASSERT( |
2176 | gwop->predicate() != nullptr && gwop->predicate()->hasValue()); |
2177 | auto read_pred = genInline(gwop->predicate()); |
2178 | indent() << kTab << read_pred << ",\n" ; |
2179 | if (gwop->writePredicate() != nullptr) { |
2180 | TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue()); |
2181 | auto write_pred = genInline(gwop->writePredicate()); |
2182 | indent() << kTab << write_pred << ",\n" ; |
2183 | } else { |
2184 | indent() << kTab << read_pred << ",\n" ; |
2185 | } |
2186 | // TODO : init value support or remove. |
2187 | indent() << kTab << data_type << "(0),\n" ; |
2188 | indent() << kTab << genInline(gwop->entrance_index()) << ",\n" ; |
2189 | indent() << kTab << genInline(gwop->entrances()); |
2190 | code_ << ");\n" ; |
2191 | } |
2192 | |
2193 | void generateGridAllreduce(const kir::GridWelford* gwop) { |
2194 | const auto wop = gwop->welford_op(); |
2195 | TORCH_INTERNAL_ASSERT(wop->isAllreduce()); |
2196 | |
2197 | const auto out = wop->out()->as<kir::TensorIndex>(); |
2198 | |
2199 | const auto data_type = wop->outAvg()->dtype(); |
2200 | const auto index_type = wop->outN()->dtype(); |
2201 | TORCH_INTERNAL_ASSERT(wop->outAvg()->dtype() == wop->outVar()->dtype()); |
2202 | |
2203 | ArgumentBuilder data_type_args; |
2204 | data_type_args.arg(data_type).arg(data_type).arg(index_type); |
2205 | |
2206 | const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>(); |
2207 | |
2208 | const auto reduction_name = genFusedReductionName(out->view()); |
2209 | |
2210 | // template <typename Func, typename... Types> |
2211 | // __device__ __inline__ void reduce( |
2212 | // RefTuple<Types...> out, |
2213 | // const LocalTuple<Types...>& inp, |
2214 | // VolatilePtrTuple<Types...> global_work_buffer, |
2215 | // int64_t* global_sync_buffer, // Allocated as product of all |
2216 | // // non-participating Grid dimension |
2217 | // PtrTuple<Types...> shared_buf, |
2218 | // bool read_pred, // Prevent reading from out of bounds memory |
2219 | // bool write_pred, // Prevent from writing out of bounds |
2220 | // const LocalTuple<Types...>& init_val, |
2221 | // Func reduction_op); |
2222 | |
2223 | ArgumentBuilder out_args; |
2224 | out_args.arg(gen(wop->outAvg())); |
2225 | out_args.arg(gen(wop->outVar())); |
2226 | out_args.arg(gen(wop->outN())); |
2227 | |
2228 | ArgumentBuilder in_args; |
2229 | in_args.arg(gen(wop->inAvg())); |
2230 | if (wop->inVar() != nullptr) { |
2231 | in_args.arg(gen(wop->inVar())); |
2232 | } else { |
2233 | in_args.arg("(" ).append(data_type).append(")0" ); |
2234 | } |
2235 | in_args.arg(gen(wop->inN())); |
2236 | |
2237 | ArgumentBuilder init_args; |
2238 | init_args.arg(gen(wop->initAvg())); |
2239 | init_args.arg(gen(wop->initVar())); |
2240 | init_args.arg(gen(wop->initN())); |
2241 | |
2242 | ArgumentBuilder work_buffer_args; |
2243 | work_buffer_args.arg("&" ) |
2244 | .append(varName(gwop->avg_buffer()->buffer()->as<TensorView>())) |
2245 | .append("[0]" ); |
2246 | work_buffer_args.arg("&" ) |
2247 | .append(varName(gwop->var_buffer()->buffer()->as<TensorView>())) |
2248 | .append("[0]" ); |
2249 | work_buffer_args.arg("&" ) |
2250 | .append(varName(gwop->N_buffer()->buffer()->as<TensorView>())) |
2251 | .append("[0]" ); |
2252 | |
2253 | ArgumentBuilder smem_buffer_args; |
2254 | smem_buffer_args.arg( |
2255 | genCall("reinterpret_cast" , ptrType(data_type), "shared_mem_avg" )); |
2256 | smem_buffer_args.arg( |
2257 | genCall("reinterpret_cast" , ptrType(data_type), "shared_mem_var" )); |
2258 | smem_buffer_args.arg( |
2259 | genCall("reinterpret_cast" , ptrType(index_type), "shared_mem_n" )); |
2260 | |
2261 | ArgumentBuilder func_args(block_nest_level_ + 1, kTab); |
2262 | // out |
2263 | func_args.arg(genCall("RefTuple" , data_type_args, out_args)); |
2264 | // inp |
2265 | func_args.arg(genCall("ConstRefTuple" , data_type_args, in_args)); |
2266 | // global_work_buffer |
2267 | func_args.arg( |
2268 | genCall("VolatilePtrTuple" , data_type_args, work_buffer_args)); |
2269 | // global_sync_buffer |
2270 | func_args.arg("&" ).append(varName(sync_buffer)).append("[0]" ); |
2271 | // shared_buf |
2272 | func_args.arg(genCall("PtrTuple" , data_type_args, smem_buffer_args)); |
2273 | // read and write predicates |
2274 | TORCH_INTERNAL_ASSERT( |
2275 | gwop->predicate() != nullptr && gwop->predicate()->hasValue()); |
2276 | const auto read_pred = genInline(gwop->predicate()); |
2277 | auto write_pred = read_pred; |
2278 | if (gwop->writePredicate() != nullptr) { |
2279 | TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue()); |
2280 | write_pred = genInline(gwop->writePredicate()); |
2281 | } |
2282 | func_args.arg(read_pred).arg(write_pred); |
2283 | // init_val |
2284 | func_args.arg(genCall("LocalTuple" , data_type_args, init_args)); |
2285 | // reduction_op |
2286 | func_args.arg(genTemplate( |
2287 | "welfordCombine" , ArgumentBuilder().arg(data_type).arg(index_type))); |
2288 | |
2289 | indent() << reduction_name << ".reduce(\n" ; |
2290 | indent() << kTab << func_args << ");\n" ; |
2291 | } |
2292 | |
2293 | void handle(const kir::AllocateFusedReduction* alloc_fused_reduction) final { |
2294 | // See the runtime file of the fused reduction |
2295 | enum class ReductionParallelTypeState { Reduce, Iter, Pred, Inactive }; |
2296 | |
2297 | using ReductionParallelTypeStateArray = |
2298 | ParallelTypeMap<ReductionParallelTypeState>; |
2299 | |
2300 | ReductionParallelTypeStateArray states( |
2301 | ReductionParallelTypeState::Inactive); |
2302 | |
2303 | for (const ParallelType pt : kParallelTypeThreads) { |
      // It may be better to predicate grid reductions on dimensions they
      // don't actively use; however, since that should generally be
      // discouraged (they should be part of the iter portion of the
      // operation, or they should be predicated out), we're just going to
      // assume they're part of the iter dimension. This would cause more
      // communication than strictly necessary but should not be a common
      // use case.
2310 | auto pt_dim = kernel_->summary().parallel_dimension_map_.get(pt); |
2311 | if (pt_dim == nullptr || pt_dim->isOneInt()) { |
2312 | continue; |
2313 | } |
      // Initialize pt_dim, if used, to an iter dimension. It may change to a
      // reduction or predicated dimension later.
2316 | states[pt] = ReductionParallelTypeState::Iter; |
2317 | } |
2318 | |
2319 | for (auto id : alloc_fused_reduction->out()->view()->domain()->domain()) { |
2320 | auto pt = id->getParallelType(); |
2321 | if (isParallelTypeThread(pt)) { |
2322 | auto state = id->isReduction() ? ReductionParallelTypeState::Reduce |
2323 | : ReductionParallelTypeState::Iter; |
2324 | states[pt] = state; |
2325 | } |
2326 | } |
2327 | |
2328 | for (const auto predicated_pt : alloc_fused_reduction->threadPredicate()) { |
2329 | auto& state = states[predicated_pt]; |
2330 | TORCH_INTERNAL_ASSERT( |
2331 | state != ReductionParallelTypeState::Reduce, |
2332 | "Invalid thread predication: " , |
2333 | predicated_pt); |
2334 | state = ReductionParallelTypeState::Pred; |
2335 | } |
2336 | |
2337 | ArgumentBuilder flags; |
2338 | for (auto pt : kParallelTypeThreads) { |
2339 | flags.arg(static_cast<int>(states[pt])); |
2340 | } |
2341 | |
2342 | // Persistent |
2343 | flags.arg(true); |
2344 | |
2345 | // Broadcast is fused |
2346 | flags.arg(true); |
2347 | |
2348 | const auto reduction_name = |
2349 | genFusedReductionName(alloc_fused_reduction->out()->view()); |
2350 | |
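    // A rough sketch of the emitted declaration, with hypothetical values.
    // The integers are ReductionParallelTypeState values (Reduce = 0,
    // Iter = 1, Pred = 2, Inactive = 3), one per entry of
    // kParallelTypeThreads, followed by the persistent and fused-broadcast
    // flags:
    //
    //   fused_reduction::ParallelReduce<0, 3, 3, 1, 3, 3, true, true>
    //       T7_reduction;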
2351 | indent() << genTemplate("fused_reduction::ParallelReduce" , flags) << " " |
2352 | << reduction_name << ";\n" ; |
2353 | } |
2354 | |
2355 | void handleScope(const kir::Scope& scope) { |
2356 | for (auto expr : scope.exprs()) { |
2357 | OptOutConstDispatch::handle(expr); |
2358 | } |
2359 | } |
2360 | |
2361 | void handleTrivialLoop(const kir::ForLoop* loop) { |
2362 | if (loop->vectorize()) { |
2363 | vectorize_scope_ = true; |
2364 | } |
2365 | handleScope(loop->body()); |
2366 | if (loop->vectorize()) { |
2367 | vectorize_scope_ = false; |
2368 | } |
2369 | } |
2370 | |
2371 | void handle(const GroupedReductionOp* grouped_rop) final { |
2372 | for (const auto i : c10::irange(grouped_rop->numExprs())) { |
2373 | TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>()); |
2374 | |
2375 | const auto output = grouped_rop->output(i)->as<kir::TensorIndex>(); |
2376 | const auto input = grouped_rop->input(i)->as<kir::TensorIndex>(); |
2377 | const auto domain = output->view()->domain(); |
2378 | const auto op_type = grouped_rop->getReductionOpType(i); |
2379 | |
2380 | const bool has_block_reduce = domain->hasBlockReduction(); |
2381 | const bool has_grid_reduce = domain->hasGridReduction(); |
2382 | |
2383 | TORCH_INTERNAL_ASSERT( |
2384 | !has_grid_reduce, |
2385 | "GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. " , |
2386 | grouped_rop->toString()); |
2387 | |
2388 | if (!has_block_reduce) { |
2389 | genSerialReduction(output, input, op_type); |
2390 | } else if ( |
2391 | auto reduction_id = |
2392 | ir_utils::getMaybeWarpReductionDim(output, input)) { |
2393 | genWarpReduction( |
2394 | output, |
2395 | input, |
2396 | grouped_rop->initVal(i), |
2397 | op_type, |
2398 | grouped_rop->predicate()); |
2399 | } else { |
2400 | genBlockReduction( |
2401 | output, |
2402 | input, |
2403 | grouped_rop->initVal(i), |
2404 | op_type, |
2405 | grouped_rop->predicate(), |
2406 | grouped_rop->writePredicate()); |
2407 | } |
2408 | } |
2409 | } |
2410 | |
2411 | void handle(const GroupedWelfordOp* grouped_wop) final { |
2412 | TORCH_INTERNAL_ASSERT( |
2413 | false, |
2414 | "Should not reach here as grouped welford is only enabled for grid welford," , |
2415 | " which is handled by its own handler" ); |
2416 | } |
2417 | |
  //! True if loop is grouped. The IterDomain of the loop must have
  //! ParallelType::Group, but that alone isn't sufficient as the loop may
  //! be for an initialization expression, for which the loop should not
  //! be grouped. Make sure a GroupedGridReduction is found.
2422 | bool isGroupedLoop(const kir::ForLoop* loop) { |
2423 | if (loop->iter_domain()->getParallelType() != ParallelType::Group) { |
2424 | return false; |
2425 | } |
2426 | return ExprFinder::exists( |
2427 | loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford}); |
2428 | } |
2429 | |
2430 | void handle(const kir::ForLoop* loop) final { |
2431 | if (loop->isTrivial()) { |
2432 | handleTrivialLoop(loop); |
2433 | return; |
2434 | } |
2435 | |
2436 | // If a loop is grouped, no loop is created, but it isn't |
2437 | // considered trivial as the loop trip count is not one. |
2438 | if (isGroupedLoop(loop)) { |
2439 | grouped_loops_.push_back(loop); |
2440 | handleScope(loop->body()); |
2441 | grouped_loops_.pop_back(); |
2442 | return; |
2443 | } |
2444 | |
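    // A rough sketch of the emitted loop, with hypothetical names, for an
    // unrolled, non-parallelized loop with a unit step:
    //
    //   #pragma unroll
    //   for(nvfuser_index_t i15 = 0; i15 < T0.size[0]; ++i15) {
    //     ...
    //   }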
2445 | const auto gen_index = gen(loop->index()); |
2446 | const auto gen_start = genInline(loop->start()); |
2447 | const auto gen_stop = genInline(loop->stop()); |
2448 | const auto gen_step = genInline(loop->step()); |
2449 | |
2450 | std::stringstream step_code; |
2451 | if (loop->step()->isOneInt()) { |
2452 | step_code << "++" << gen_index; |
2453 | } else { |
2454 | step_code << gen_index << " += " << gen_step; |
2455 | } |
2456 | if (loop->isUnrolled()) { |
2457 | indent() << "#pragma unroll\n" ; |
2458 | } else { |
2459 | indent() << "#pragma unroll 1\n" ; |
2460 | } |
2461 | |
2462 | indent() << "for(nvfuser_index_t " << gen_index; |
2463 | if (loop->iter_domain()->isParallelized()) { |
2464 | code_ << " = " << gen_start << "; " ; |
2465 | } else { |
2466 | // Do not start at the start of the ID when not parallelized. Instead, |
2467 | // start at 0. Predicates will protect buffers between 0 and ID->start(), |
2468 | // however if we started at ID->start and extent == ID->start, we could |
2469 | // have a "degenerate" loop (loop with no iterations). It may not be an |
2470 | // issue to have a 0-sized loop, but all potential consequences haven't |
2471 | // been covered. One example is WAR analysis which could incorrectly think |
2472 | // a barrier inside a 0-sized loop actually provides protection. |
2473 | code_ << " = 0; " ; |
2474 | } |
2475 | code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") " ; |
2476 | startBlock(true); |
2477 | handleScope(loop->body()); |
2478 | endBlock(); |
2479 | } |
2480 | |
2481 | void handle(const kir::IfThenElse* ite) final { |
2482 | auto conditional = ite->predicate()->value(); |
2483 | if (conditional->isConst()) { |
2484 | // If the conditional is a constant, then the IfThenElse is not required |
2485 | if (conditional->value().value()) { |
2486 | handleScope(ite->thenBody()); |
2487 | } else { |
2488 | handleScope(ite->elseBody()); |
2489 | } |
2490 | return; |
2491 | } |
2492 | |
2493 | indent() << "if (" << genInline(conditional) << ") " ; |
2494 | |
2495 | // "then" block |
2496 | startBlock(true); |
2497 | handleScope(ite->thenBody()); |
2498 | |
2499 | // "else" block (optional) |
2500 | if (ite->hasElse()) { |
2501 | endBlock(" else " ); |
2502 | startBlock(true); |
2503 | handleScope(ite->elseBody()); |
2504 | } |
2505 | |
2506 | endBlock(); |
2507 | } |
2508 | |
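  // Emits buffer allocations. Rough sketches of the generated code, with
  // hypothetical names and sizes:
  //   Global:             // Allocate global tensor T2
  //   Shared:             float* T3 = reinterpret_cast<float*>(array + smem_offset);
  //   Local (vectorized): Array<float, 8, 4> T4;
  //   Local (plain):      float T5[8];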
2509 | void handle(const kir::Allocate* alloc) final { |
2510 | const auto buffer_dtype = alloc->buffer()->dtype(); |
2511 | |
2512 | TORCH_INTERNAL_ASSERT(alloc->buffer() != nullptr); |
2513 | alloc_map_.emplace(alloc->buffer(), alloc); |
2514 | |
2515 | if (!alloc->buffer()->isA<TensorView>()) { |
2516 | indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n" ; |
2517 | return; |
2518 | } |
2519 | |
2520 | const auto tv = alloc->buffer()->as<TensorView>(); |
2521 | |
2522 | const auto size = alloc->size(); |
2523 | TORCH_INTERNAL_ASSERT(size != nullptr); |
2524 | |
2525 | if (alloc->alias() != nullptr) { |
2526 | // Allocate alias another Allocate stmt |
2527 | const auto alias_tv = alloc->alias()->buffer()->as<TensorView>(); |
2528 | indent() << "// Alias Allocation - " << alloc->memoryType() << "\n" ; |
2529 | indent() << "auto& " << varName(tv) << " = " << varName(alias_tv) |
2530 | << ";\n" ; |
2531 | |
2532 | } else { |
2533 | // Standard Memory Allocation |
2534 | switch (tv->getMemoryType()) { |
2535 | case MemoryType::Global: |
2536 | indent() << "// Allocate global tensor " << varName(tv) << "\n" ; |
2537 | break; |
2538 | case MemoryType::Shared: |
2539 | // Align Offset Position |
2540 | indent() << "smem_offset = alignBufferSize(smem_offset, " |
2541 | // Always align to 128b / 16B |
2542 | << 16 << ");\n" ; |
2543 | // Shared Memory Pointer |
2544 | indent() << buffer_dtype << "* " << varName(tv) |
2545 | << " = reinterpret_cast<" << buffer_dtype << "*>" |
2546 | << "(array + smem_offset);\n" ; |
2547 | // Increment Offset Position |
2548 | indent() << "smem_offset += (" << genInline(size) << " * sizeof(" |
2549 | << buffer_dtype << "));\n" ; |
2550 | break; |
2551 | case MemoryType::Local: { |
2552 | auto va = kernel_->summary().vectorized_accesses; |
2553 | if (va.find(tv) != va.end()) { |
2554 | indent() << "Array<" << buffer_dtype << ", " << genInline(size) |
2555 | << ", " << va.at(tv) << "> " << varName(tv) << ";\n" ; |
2556 | } else { |
2557 | indent() << buffer_dtype << " " << varName(tv) << "[" |
2558 | << genInline(size) << "];\n" ; |
2559 | } |
2560 | } break; |
2561 | default: |
2562 | TORCH_INTERNAL_ASSERT(false, "Unexpected memory type" ); |
2563 | } |
2564 | } |
2565 | } |
2566 | |
2567 | void handle(const kir::BlockSync* sync) final { |
2568 | // Use a custom synchronization method if enabled |
2569 | if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC" )) { |
2570 | indent() << "block_sync::sync();\n" ; |
2571 | } else { |
2572 | indent() << "__barrier_sync(0);\n" ; |
2573 | } |
2574 | } |
2575 | |
2576 | void handle(const kir::CpAsyncWait* cpasync_wait) final { |
2577 | if (cpasync_wait->keepStages() > 0) { |
2578 | // Perform partial sync, see comment on kir::CpAsyncWait. |
2579 | indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages() |
2580 | << ">();\n" ; |
2581 | } else { |
2582 | // Perform sync all, see comment on kir::CpAsyncWait. |
2583 | indent() << "Ampere::cpAsyncBarrier();\n" ; |
2584 | } |
2585 | } |
2586 | |
  void handle(const kir::CpAsyncCommit* cpasync_commit) final {
2588 | // Commit inflight cp.async transfers. See comment on kir::CpAsyncCommit. |
2589 | indent() << "Ampere::cpAsyncCommit();\n" ; |
2590 | } |
2591 | |
2592 | void handle(const kir::GridSync* sync) final { |
2593 | // Use a custom synchronization method if enabled |
2594 | bool bidx = sync->syncDims().get(ParallelType::BIDx); |
2595 | bool bidy = sync->syncDims().get(ParallelType::BIDy); |
2596 | bool bidz = sync->syncDims().get(ParallelType::BIDz); |
2597 | |
2598 | ArgumentBuilder sync_call_template_parms; |
2599 | sync_call_template_parms.arg(bidx).arg(bidy).arg(bidz).arg(true); |
2600 | |
2601 | auto sync_idx = genCall( |
2602 | "index_utils::maskedOffset" , |
2603 | ArgumentBuilder().arg(!bidx).arg(!bidy).arg(!bidz), |
2604 | ArgumentBuilder().arg("blockIdx" ).arg("gridDim" )); |
2605 | |
2606 | auto sync_segment_size = genCall( |
2607 | "index_utils::maskedSize" , |
2608 | ArgumentBuilder().arg(bidx).arg(bidy).arg(bidz), |
2609 | ArgumentBuilder().arg("gridDim" )); |
2610 | |
2611 | ArgumentBuilder sync_call_args; |
2612 | sync_call_args.arg(varName(sync->syncBuffer())) |
2613 | .append("[" ) |
2614 | .append(sync_idx) |
2615 | .append("]" ); |
2616 | sync_call_args.arg(sync_segment_size); |
2617 | |
2618 | auto sync_call = |
2619 | genCall("grid_sync::sync" , sync_call_template_parms, sync_call_args); |
2620 | |
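    // A rough sketch of the emitted call, with a hypothetical sync buffer
    // name, for a sync over BIDx only:
    //
    //   grid_sync::sync<true, false, false, true>(
    //     T_sync[index_utils::maskedOffset<false, true, true>(blockIdx, gridDim)],
    //     index_utils::maskedSize<true, false, false>(gridDim));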
2621 | indent() << sync_call << ";\n" ; |
2622 | } |
2623 | |
2624 | void handle(const kir::InitMagicZero*) final { |
2625 | indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n" ; |
2626 | } |
2627 | |
2628 | void handle(const kir::UpdateMagicZero*) final { |
2629 | indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n" ; |
2630 | } |
2631 | |
2632 | void handle(const kir::Swizzle2DInt* swizzle_2d) final { |
2633 | TORCH_INTERNAL_ASSERT(print_inline_); |
2634 | TORCH_INTERNAL_ASSERT( |
2635 | swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle, |
2636 | "Swizzle type undefined." ); |
2637 | if (print_inline_) { |
2638 | code_ << swizzle_2d->swizzleType() << "({" << gen(swizzle_2d->inX()) |
2639 | << "," << gen(swizzle_2d->inY()) << "} , " |
2640 | << "{" << gen(swizzle_2d->extentX()) << "," |
2641 | << gen(swizzle_2d->extentY()) << "})" ; |
2642 | } |
2643 | } |
2644 | |
2645 | void handle(const kir::IntPair* int_pair) final { |
2646 | const auto def = int_pair->definition(); |
2647 | TORCH_INTERNAL_ASSERT( |
2648 | def != nullptr, "no support for un-inlined int pair yet." ); |
2649 | code_ << gen(def); |
2650 | } |
2651 | |
2652 | void handle(const kir::PairSelect* pair_select) final { |
2653 | if (print_inline_) { |
2654 | code_ << gen(pair_select->in()); |
2655 | } else { |
2656 | indent() << gen(pair_select->out()) << " = " << gen(pair_select->in()); |
2657 | } |
2658 | |
2659 | switch (pair_select->selection()) { |
2660 | case kir::PairSelect::Selection::X: |
2661 | code_ << ".x" ; |
2662 | break; |
2663 | case kir::PairSelect::Selection::Y: |
2664 | code_ << ".y" ; |
2665 | break; |
2666 | default: |
        TORCH_INTERNAL_ASSERT(false, "unknown select" );
2668 | break; |
2669 | } |
2670 | |
2671 | if (!print_inline_) { |
2672 | code_ << ";\n" ; |
2673 | } |
2674 | } |
2675 | |
2676 | private: |
2677 | std::stringstream code_; |
2678 | const kir::Kernel* kernel_; |
2679 | int block_nest_level_ = 0; |
2680 | int block_reduce_name_ = 0; |
2681 | bool print_inline_ = false; |
2682 | |
2683 | // Mark when we are inside of a vectorized for-loop |
2684 | bool vectorize_scope_ = false; |
2685 | //! Keep track of Allocate node for Val. Used to determine if Val |
2686 | //! should be inlined. |
2687 | std::unordered_map<const Val*, const kir::Allocate*> alloc_map_; |
2688 | //! Keep track of grouped loops |
2689 | std::deque<const kir::ForLoop*> grouped_loops_; |
2690 | //! Used to replace symbolic indices with concrete values |
2691 | std::unordered_map<const Int*, int64_t> index_replacement_map_; |
2692 | }; |
2693 | |
2694 | } // namespace |
2695 | |
2696 | std::string generateCudaKernel( |
2697 | const kir::Kernel* kernel, |
2698 | const std::string& kernel_name) { |
2699 | FUSER_PERF_SCOPE("generateCudaKernel" ); |
2700 | return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name); |
2701 | } |
2702 | |
2703 | } // namespace codegen |
2704 | } // namespace cuda |
2705 | } // namespace fuser |
2706 | } // namespace jit |
2707 | } // namespace torch |
2708 | |