1 | #include <lower_magic_zero.h> |
2 | |
3 | #include <dispatch.h> |
4 | #include <instrumentation.h> |
5 | #include <ir_utils.h> |
6 | #include <kernel_ir_dispatch.h> |
7 | #include <lower2device.h> |
8 | #include <lower_index_compute.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | namespace { |
16 | |
17 | class MagicZeroInserter : public kir::ExprMutator { |
18 | public: |
19 | static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) { |
20 | MagicZeroInserter inserter(exprs); |
21 | return inserter.exprs_; |
22 | } |
23 | |
24 | private: |
25 | struct InsertionInfo { |
26 | kir::Scope* scope = nullptr; |
27 | kir::ForLoop* fl = nullptr; |
28 | }; |
29 | |
30 | MagicZeroInserter(const std::vector<Expr*>& exprs) { |
31 | TORCH_INTERNAL_ASSERT(exprs.size()); |
32 | kir::ExprMutator::registerInsertBefore( |
33 | exprs.front(), IrBuilder::create<kir::InitMagicZero>(), nullptr); |
34 | kir::ExprMutator::traverseAndInsert(exprs); |
35 | } |
36 | |
37 | void handle(kir::ForLoop* fl) final { |
38 | if (fl->isUnrolled()) { |
39 | if (scope_.empty()) { |
40 | kir::ExprMutator::registerInsertAfter( |
41 | fl, IrBuilder::create<kir::UpdateMagicZero>()); |
42 | } else { |
43 | TORCH_INTERNAL_ASSERT( |
44 | scope_.back()->exprs().size(), "Not expecting an empty loop." ); |
45 | kir::ExprMutator::registerInsertAfter( |
46 | fl, IrBuilder::create<kir::UpdateMagicZero>(), scope_.back()); |
47 | } |
48 | } else { |
49 | kir::ExprMutator::handle(fl); |
50 | } |
51 | } |
52 | |
53 | std::vector<InsertionInfo> insertion_list_; |
54 | }; |
55 | |
56 | } // namespace |
57 | |
58 | std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs) { |
59 | FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero" ); |
60 | // Check if magic zero was even used, if not we don't have to define it or |
61 | // update it. |
62 | const auto gpu_lower = GpuLower::current(); |
63 | auto kernel = gpu_lower->kernel(); |
64 | const bool has_magic_zero = |
65 | std::any_of(kernel->vals().begin(), kernel->vals().end(), [](Val* val) { |
66 | return isMagicZero(val); |
67 | }); |
68 | |
69 | if (!has_magic_zero) { |
70 | return exprs; |
71 | } |
72 | |
73 | return MagicZeroInserter::insert(exprs); |
74 | } |
75 | |
76 | bool isMagicZero(const Val* val) { |
77 | if (!val->isA<NamedScalar>()) { |
78 | return false; |
79 | } |
80 | auto ns = val->as<NamedScalar>(); |
81 | return ns->dtype() == DataType::Int && |
82 | ns->name() == std::string(kMagicZeroName); |
83 | } |
84 | |
85 | bool isProtectedWithMagicZero(const Val* val) { |
86 | if (val->definition() == nullptr || !val->definition()->isA<BinaryOp>()) { |
87 | return false; |
88 | } |
89 | auto bop = val->definition()->as<BinaryOp>(); |
90 | return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs()); |
91 | } |
92 | |
93 | bool needsMagicZero( |
94 | kir::ForLoop* loop, |
95 | IterDomain* reference_domain, |
96 | Val* ind) { |
97 | if (ind->isConstScalar()) { |
98 | return false; |
99 | } |
100 | |
101 | bool ref_dom_simple = |
102 | reference_domain == nullptr || reference_domain->definition() != nullptr; |
103 | bool ind_simple = |
104 | ind == nullptr || (ind->definition() != nullptr && !ind->isZeroInt()); |
105 | return loop->isUnrolled() && (!ref_dom_simple || !ind_simple); |
106 | } |
107 | |
108 | void protectNonPredicateIndexWithMagicZero( |
109 | const std::vector<kir::ForLoop*>& loops, |
110 | const std::vector<IterDomain*>& loop_domains, |
111 | std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) { |
112 | // Find magic zero insertion point |
113 | IterDomain* magic_zero_loop = nullptr; |
114 | |
115 | // Search for proper magic zero insertion point, |
116 | // prefer innermost. |
117 | for (auto idx : c10::irange(loops.size())) { |
118 | auto loop = loops[idx]; |
119 | auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( |
120 | loop_domains[idx], IdMappingMode::EXACT); |
121 | auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id); |
122 | |
123 | // Save the concrete id if this loop id is decided to |
124 | // be the insertion point by the magic zero util. |
125 | if (needsMagicZero(loop, concrete_loop_id, loop_ind)) { |
126 | magic_zero_loop = concrete_loop_id; |
127 | } |
128 | } |
129 | |
130 | // Insert magic zero if insertion point found |
131 | if (magic_zero_loop != nullptr && |
132 | concrete_loop_idx_map.count(magic_zero_loop)) { |
133 | auto& ind = concrete_loop_idx_map.at(magic_zero_loop); |
134 | ind = SimplifyingIrBuilder::addExpr( |
135 | ind, GpuLower::current()->kernel()->magicZeroVal()); |
136 | } |
137 | } |
138 | |
139 | namespace { |
140 | |
141 | //! Protect loop_index_to_protect appearing in overall_index_val |
142 | IndexMagicZeroInfo protectIndexByReplacingLoopIndex( |
143 | IterDomain* loop_id, |
144 | Val* overall_index_val, |
145 | Val* loop_index_to_protect) { |
146 | auto protected_loop_index = SimplifyingIrBuilder::addExpr( |
147 | loop_index_to_protect, GpuLower::current()->kernel()->magicZeroVal()); |
148 | |
149 | std::unordered_map<Val*, Val*> replacement_map; |
150 | replacement_map[loop_index_to_protect] = protected_loop_index; |
151 | |
152 | auto protected_index = |
153 | ir_utils::replaceValInIndexVal(overall_index_val, replacement_map); |
154 | |
155 | IndexMagicZeroInfo info; |
156 | info.index = protected_index; |
157 | info.original_loop_index = loop_index_to_protect; |
158 | info.protected_loop_index = protected_loop_index; |
159 | info.loop_id = loop_id; |
160 | return info; |
161 | } |
162 | |
163 | } // namespace |
164 | |
165 | IndexMagicZeroInfo protectPredicateIndexWithMagicZero( |
166 | Val* index, |
167 | const IndexFromIdGraph& id_graph, |
168 | const std::vector<kir::ForLoop*>& loops) { |
169 | // Gather the loop indices |
170 | std::unordered_set<Val*> loop_indices; |
171 | for (auto loop_id : id_graph.resolved_loop_domains) { |
172 | auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( |
173 | loop_id, IdMappingMode::EXACT); |
174 | auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id); |
175 | TORCH_INTERNAL_ASSERT( |
176 | index_it != id_graph.initial_concrete_index_map.end(), |
177 | "Index not found for loop: " , |
178 | concrete_loop_id->toString()); |
179 | auto loop_index = index_it->second; |
180 | loop_indices.insert(loop_index); |
181 | } |
182 | |
183 | // Figure out which loop indices are used in index |
184 | const auto vals = DependencyCheck::getAllValsBetween(loop_indices, {index}); |
185 | |
186 | // Traverser from the inner-most loop and apply the magic-zero |
187 | // prorection if needed |
188 | for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) { |
189 | auto loop = loops.at(i); |
190 | auto loop_id = id_graph.resolved_loop_domains.at(i); |
191 | TORCH_INTERNAL_ASSERT(GpuLower::current()->caMap()->areMapped( |
192 | loop_id, loop->iter_domain(), IdMappingMode::PERMISSIVE)); |
193 | IterDomain* concrete_loop_id = |
194 | GpuLower::current()->caMap()->getConcreteMappedID( |
195 | loop_id, IdMappingMode::EXACT); |
196 | auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id); |
197 | TORCH_INTERNAL_ASSERT( |
198 | index_it != id_graph.initial_concrete_index_map.end()); |
199 | auto loop_index = index_it->second; |
200 | |
201 | const auto is_loop_index_used = |
202 | std::find(vals.begin(), vals.end(), loop_index) != vals.end(); |
203 | |
204 | if (!is_loop_index_used) { |
205 | continue; |
206 | } |
207 | |
208 | if (needsMagicZero(loop, concrete_loop_id, loop_index)) { |
209 | return protectIndexByReplacingLoopIndex(loop_id, index, loop_index); |
210 | } |
211 | } |
212 | |
213 | // No loop is identified to require protection with magic zero. Just |
214 | // return the index argument as is |
215 | IndexMagicZeroInfo not_proteced; |
216 | not_proteced.index = index; |
217 | return not_proteced; |
218 | } |
219 | |
220 | } // namespace cuda |
221 | } // namespace fuser |
222 | } // namespace jit |
223 | } // namespace torch |
224 | |