#pragma once

#include <c10/macros/Export.h>

#include <instrumentation.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <root_domain_map.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// TODO: Replace with a mutator, as IndexLowering replaces exprs with
// versions that perform the actual indexing.
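//
// Typical invocation from the lowering pipeline (illustrative sketch only;
// where the incoming exprs come from is assumed, not shown here):
//
//   std::vector<Expr*> indexed_exprs =
//       IndexLowering::getIndexedExprs(std::move(exprs));
//
// The returned exprs are equivalent to the inputs, but tensor accesses are
// rewritten with explicit indexing.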
class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
 public:
  static std::vector<Expr*> getIndexedExprs(std::vector<Expr*> incoming_exprs) {
    FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs");
    IndexLowering il;
    il.generate(incoming_exprs);
    return il.lowered_exprs_;
  }

 private:
  IndexLowering() = default;

  // Append an expression to the currently active scope, or to the top-level
  // lowered exprs when no scope is active.
  void pushBack(Expr*);

  // Return the most recently inserted expression in the current active
  // scope or the global scope.
  Expr* back() const;

  // Insert an expression before the current top-level expression.
  void insertAtTopLevel(Expr* expr);

  void handle(const FullOp*) final;
  void handle(const ARangeOp*) final;
  void handle(const EyeOp*) final;
  void handle(const ViewAsScalar*) final;
  void handle(const UnaryOp*) final;

  void handle(const BinaryOp*) final;
  void handle(const TernaryOp*) final;
  void handle(const RNGOp*) final;
  void handle(const ReductionOp*) final;
  void handle(const GroupedReductionOp*) final;
  void handle(const WelfordOp*) final;
  void handle(const GroupedWelfordOp*) final;
  void handle(const LoadStoreOp*) final;
  void handle(const MmaOp*) final;
  void handle(const BroadcastOp*) final;

  void handle(const kir::ForLoop*) final;
  void handle(const kir::IfThenElse*) final;
  void handle(const kir::Allocate*) final;
  void handle(const kir::BlockSync*) final;
  void handle(const kir::GridSync*) final;
  void handle(const kir::CpAsyncWait*) final;
  void handle(const kir::CpAsyncCommit*) final;

  void generate(const std::vector<Expr*>& exprs);

  // Lower a producer value `val` of expression output `dst` to an indexed
  // access (e.g., a kir::TensorIndex when `val` is a TensorView).
  Val* lowerSrcIndex(Val* val, Val* dst) const;

  // Lower an expression output `dst` to an indexed access as a consumer.
  Val* lowerDstIndex(Val* dst) const;

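  // Illustrative sketch (assumed shape of a handler, not the actual
  // definition of any of them): a pointwise handler would use the two
  // helpers above roughly as
  //
  //   const auto in = lowerSrcIndex(uop->in(), uop->out());
  //   const auto out = lowerDstIndex(uop->out());
  //   pushBack(IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, in));
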
  void handleBlockReduction(const ReductionOp* rop, Val* out, Val* in);
  void handleGridReduction(const ReductionOp* rop, Val* out, Val* in);

  void handleBlockReduction(
      const GroupedReductionOp* rop,
      const std::vector<Val*>& outputs,
      const std::vector<Val*>& inputs);
  void handleGridReduction(
      const GroupedReductionOp* rop,
      const std::vector<Val*>& outputs,
      const std::vector<Val*>& inputs);

  void handleGridWelford(WelfordOp* new_wop);

  void handleGroupedBlockWelford(
      const GroupedWelfordOp* wop,
      const std::vector<WelfordTriplet>& output_vals,
      const std::vector<WelfordTriplet>& input_vals,
      const std::vector<WelfordTriplet>& init_vals);
  void handleGroupedGridWelford(
      const GroupedWelfordOp* wop,
      const std::vector<WelfordTriplet>& output_vals,
      const std::vector<WelfordTriplet>& input_vals,
      const std::vector<WelfordTriplet>& init_vals);

  // Allocate a unique buffer for grid reductions and broadcasts. A buffer is
  // allocated at most once for each output tensor of an expression.
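  //
  // For example (assumed call shape, shown only for illustration), a grid
  // reduction handler would allocate its synchronization buffer roughly as
  //
  //   const auto sync_buffer = allocateUniqueBuffer(
  //       sync_buffer_size, DataType::Int, /*zero_init=*/true, out_tv,
  //       sync_buffer_map_);
  //
  // so that expressions writing to the same out_tv share one allocation.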
  kir::Allocate* allocateUniqueBuffer(
      Val* buffer_size,
      DataType dtype,
      bool zero_init,
      TensorView* out_tv,
      std::unordered_map<TensorView*, kir::Allocate*>& alloc_map);

  std::vector<kir::Allocate*> allocateWelfordWorkBuffer(
      const std::vector<WelfordTriplet>& triplets,
      WelfordTriplet::ValName name,
      Val* buffer_size);

  // Allocate a fused reduction object uniquely for a given
  // TensorView. Parameter expr is the expression corresponding to the
  // fused reduction.
  void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv);

 private:
  std::vector<Expr*> lowered_exprs_;

  // This is a slight workaround, as "scope" has a couple of definitions: there
  // is the Scope inside ForLoop/IfThenElse, which is really just a wrapper
  // around std::vector<Expr*>, and then there is the actual ForLoop/IfThenElse
  // node. We want to carry both around, because when we push back to a scope
  // it could be either the then-body or the else-body of an IfThenElse, and we
  // also want to understand the nesting of IfThenElse/ForLoop nodes.
  kir::Scope* active_scope_ = nullptr;

  // Track for loops to send to indexing. Similar to what's done in
  // kir::IrVisitor.
  std::vector<kir::ForLoop*> for_loops_;

  // Maps to keep track of allocated buffers and objects that must be
  // allocated only once.
  std::unordered_map<TensorView*, kir::Allocate*> sync_buffer_map_;
  std::unordered_map<TensorView*, kir::Allocate*> work_buffer_map_;
  std::unordered_map<TensorView*, kir::AllocateFusedReduction*>
      fused_reduction_map_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch