1#include <lower_magic_zero.h>
2
3#include <dispatch.h>
4#include <instrumentation.h>
5#include <ir_utils.h>
6#include <kernel_ir_dispatch.h>
7#include <lower2device.h>
8#include <lower_index_compute.h>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15namespace {
16
17class MagicZeroInserter : public kir::ExprMutator {
18 public:
19 static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
20 MagicZeroInserter inserter(exprs);
21 return inserter.exprs_;
22 }
23
24 private:
25 struct InsertionInfo {
26 kir::Scope* scope = nullptr;
27 kir::ForLoop* fl = nullptr;
28 };
29
30 MagicZeroInserter(const std::vector<Expr*>& exprs) {
31 TORCH_INTERNAL_ASSERT(exprs.size());
32 kir::ExprMutator::registerInsertBefore(
33 exprs.front(), IrBuilder::create<kir::InitMagicZero>(), nullptr);
34 kir::ExprMutator::traverseAndInsert(exprs);
35 }
36
37 void handle(kir::ForLoop* fl) final {
38 if (fl->isUnrolled()) {
39 if (scope_.empty()) {
40 kir::ExprMutator::registerInsertAfter(
41 fl, IrBuilder::create<kir::UpdateMagicZero>());
42 } else {
43 TORCH_INTERNAL_ASSERT(
44 scope_.back()->exprs().size(), "Not expecting an empty loop.");
45 kir::ExprMutator::registerInsertAfter(
46 fl, IrBuilder::create<kir::UpdateMagicZero>(), scope_.back());
47 }
48 } else {
49 kir::ExprMutator::handle(fl);
50 }
51 }
52
53 std::vector<InsertionInfo> insertion_list_;
54};
55
56} // namespace
57
58std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs) {
59 FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero");
60 // Check if magic zero was even used, if not we don't have to define it or
61 // update it.
62 const auto gpu_lower = GpuLower::current();
63 auto kernel = gpu_lower->kernel();
64 const bool has_magic_zero =
65 std::any_of(kernel->vals().begin(), kernel->vals().end(), [](Val* val) {
66 return isMagicZero(val);
67 });
68
69 if (!has_magic_zero) {
70 return exprs;
71 }
72
73 return MagicZeroInserter::insert(exprs);
74}
75
76bool isMagicZero(const Val* val) {
77 if (!val->isA<NamedScalar>()) {
78 return false;
79 }
80 auto ns = val->as<NamedScalar>();
81 return ns->dtype() == DataType::Int &&
82 ns->name() == std::string(kMagicZeroName);
83}
84
85bool isProtectedWithMagicZero(const Val* val) {
86 if (val->definition() == nullptr || !val->definition()->isA<BinaryOp>()) {
87 return false;
88 }
89 auto bop = val->definition()->as<BinaryOp>();
90 return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs());
91}
92
93bool needsMagicZero(
94 kir::ForLoop* loop,
95 IterDomain* reference_domain,
96 Val* ind) {
97 if (ind->isConstScalar()) {
98 return false;
99 }
100
101 bool ref_dom_simple =
102 reference_domain == nullptr || reference_domain->definition() != nullptr;
103 bool ind_simple =
104 ind == nullptr || (ind->definition() != nullptr && !ind->isZeroInt());
105 return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
106}
107
108void protectNonPredicateIndexWithMagicZero(
109 const std::vector<kir::ForLoop*>& loops,
110 const std::vector<IterDomain*>& loop_domains,
111 std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) {
112 // Find magic zero insertion point
113 IterDomain* magic_zero_loop = nullptr;
114
115 // Search for proper magic zero insertion point,
116 // prefer innermost.
117 for (auto idx : c10::irange(loops.size())) {
118 auto loop = loops[idx];
119 auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
120 loop_domains[idx], IdMappingMode::EXACT);
121 auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id);
122
123 // Save the concrete id if this loop id is decided to
124 // be the insertion point by the magic zero util.
125 if (needsMagicZero(loop, concrete_loop_id, loop_ind)) {
126 magic_zero_loop = concrete_loop_id;
127 }
128 }
129
130 // Insert magic zero if insertion point found
131 if (magic_zero_loop != nullptr &&
132 concrete_loop_idx_map.count(magic_zero_loop)) {
133 auto& ind = concrete_loop_idx_map.at(magic_zero_loop);
134 ind = SimplifyingIrBuilder::addExpr(
135 ind, GpuLower::current()->kernel()->magicZeroVal());
136 }
137}
138
139namespace {
140
141//! Protect loop_index_to_protect appearing in overall_index_val
142IndexMagicZeroInfo protectIndexByReplacingLoopIndex(
143 IterDomain* loop_id,
144 Val* overall_index_val,
145 Val* loop_index_to_protect) {
146 auto protected_loop_index = SimplifyingIrBuilder::addExpr(
147 loop_index_to_protect, GpuLower::current()->kernel()->magicZeroVal());
148
149 std::unordered_map<Val*, Val*> replacement_map;
150 replacement_map[loop_index_to_protect] = protected_loop_index;
151
152 auto protected_index =
153 ir_utils::replaceValInIndexVal(overall_index_val, replacement_map);
154
155 IndexMagicZeroInfo info;
156 info.index = protected_index;
157 info.original_loop_index = loop_index_to_protect;
158 info.protected_loop_index = protected_loop_index;
159 info.loop_id = loop_id;
160 return info;
161}
162
163} // namespace
164
165IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
166 Val* index,
167 const IndexFromIdGraph& id_graph,
168 const std::vector<kir::ForLoop*>& loops) {
169 // Gather the loop indices
170 std::unordered_set<Val*> loop_indices;
171 for (auto loop_id : id_graph.resolved_loop_domains) {
172 auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
173 loop_id, IdMappingMode::EXACT);
174 auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
175 TORCH_INTERNAL_ASSERT(
176 index_it != id_graph.initial_concrete_index_map.end(),
177 "Index not found for loop: ",
178 concrete_loop_id->toString());
179 auto loop_index = index_it->second;
180 loop_indices.insert(loop_index);
181 }
182
183 // Figure out which loop indices are used in index
184 const auto vals = DependencyCheck::getAllValsBetween(loop_indices, {index});
185
186 // Traverser from the inner-most loop and apply the magic-zero
187 // prorection if needed
188 for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
189 auto loop = loops.at(i);
190 auto loop_id = id_graph.resolved_loop_domains.at(i);
191 TORCH_INTERNAL_ASSERT(GpuLower::current()->caMap()->areMapped(
192 loop_id, loop->iter_domain(), IdMappingMode::PERMISSIVE));
193 IterDomain* concrete_loop_id =
194 GpuLower::current()->caMap()->getConcreteMappedID(
195 loop_id, IdMappingMode::EXACT);
196 auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
197 TORCH_INTERNAL_ASSERT(
198 index_it != id_graph.initial_concrete_index_map.end());
199 auto loop_index = index_it->second;
200
201 const auto is_loop_index_used =
202 std::find(vals.begin(), vals.end(), loop_index) != vals.end();
203
204 if (!is_loop_index_used) {
205 continue;
206 }
207
208 if (needsMagicZero(loop, concrete_loop_id, loop_index)) {
209 return protectIndexByReplacingLoopIndex(loop_id, index, loop_index);
210 }
211 }
212
213 // No loop is identified to require protection with magic zero. Just
214 // return the index argument as is
215 IndexMagicZeroInfo not_proteced;
216 not_proteced.index = index;
217 return not_proteced;
218}
219
220} // namespace cuda
221} // namespace fuser
222} // namespace jit
223} // namespace torch
224