1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! |
25 | * \brief A helper class to create a new scope that contains decomposed init body |
26 | * and replaced old reduction block. |
27 | */ |
28 | class DecomposeReductionBlockReplacer : public StmtMutator { |
29 | public: |
30 | /*! |
31 | * \brief The open interface to users to call the helper class |
32 | * \param old_scope_root The original block scope before decomposition |
33 | * \param target_loop The loop we insert the decomposed init body before |
34 | * \param decompose_body The decomposed init body |
35 | * \param old_reduction_block The reduction block we want to decompose |
36 | * \return The new block scope and the updated reduction block |
37 | */ |
38 | static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop, |
39 | Stmt decomposed_body, Block old_reduction_block) { |
40 | DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body), |
41 | std::move(old_reduction_block)); |
42 | return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))), |
43 | replacer.new_reduction_block_); |
44 | } |
45 | |
46 | private: |
47 | explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body, |
48 | Block old_reduction_block) |
49 | : target_loop_(std::move(target_loop)), |
50 | decomposed_body_(std::move(decomposed_body)), |
51 | old_reduction_block_(std::move(old_reduction_block)) {} |
52 | |
53 | Stmt VisitStmt_(const ForNode* loop) final { |
54 | Stmt mutated_stmt = StmtMutator::VisitStmt_(loop); |
55 | if (loop == target_loop_.get()) { |
56 | return SeqStmt({decomposed_body_, mutated_stmt}); |
57 | } else { |
58 | return mutated_stmt; |
59 | } |
60 | } |
61 | |
62 | Stmt VisitStmt_(const BlockNode* block) final { |
63 | if (block == old_reduction_block_.get()) { |
64 | ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block); |
65 | p_new_block->name_hint = p_new_block->name_hint + "_update" ; |
66 | p_new_block->init = NullOpt; |
67 | // Add write regions back to read regions in update block. |
68 | Array<BufferRegion> new_reads; |
69 | std::unordered_set<const BufferNode*> read_bufs; |
70 | for (const BufferRegion& read_access : block->reads) { |
71 | read_bufs.insert(read_access->buffer.get()); |
72 | } |
73 | for (const BufferRegion& write_access : block->writes) { |
74 | if (read_bufs.find(write_access->buffer.get()) == read_bufs.end()) { |
75 | new_reads.push_back(write_access); |
76 | } |
77 | } |
78 | for (const BufferRegion& read_access : block->reads) { |
79 | new_reads.push_back(read_access); |
80 | } |
81 | p_new_block->reads = new_reads; |
82 | new_reduction_block_ = Block(p_new_block); |
83 | return new_reduction_block_; |
84 | } else { |
85 | return StmtMutator::VisitStmt_(block); |
86 | } |
87 | } |
88 | |
89 | Stmt VisitStmt_(const SeqStmtNode* seq) final { |
90 | Array<Stmt> new_stmts; |
91 | new_stmts.reserve(seq->seq.size()); |
92 | for (const Stmt& old_stmt : seq->seq) { |
93 | new_stmts.push_back(VisitStmt(old_stmt)); |
94 | } |
95 | return SeqStmt::Flatten(new_stmts); |
96 | } |
97 | |
98 | private: |
99 | For target_loop_; |
100 | Stmt decomposed_body_; |
101 | Block old_reduction_block_; |
102 | Block new_reduction_block_; |
103 | }; |
104 | |
105 | class LoopHeightError : public ScheduleError { |
106 | public: |
107 | static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block, |
108 | const BlockRealizeNode* realize, |
109 | const Array<StmtSRef>& loops, |
110 | const StmtSRef& loop_sref) { |
111 | for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { |
112 | // For each block var of type kCommReduce, check its binding |
113 | const IterVar& iter_var = block->iter_vars[i]; |
114 | const PrimExpr& binding = realize->iter_values[i]; |
115 | if (iter_var->iter_type != IterVarType::kCommReduce) { |
116 | continue; |
117 | } |
118 | for (const StmtSRef& higher_loop : loops) { |
119 | // Only check loops not lower than the target loop |
120 | if (higher_loop.same_as(loop_sref)) { |
121 | break; |
122 | } |
123 | // loop_var of a higher loop shouldn't contain loop var |
124 | const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var; |
125 | if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { |
126 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
127 | throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block)); |
128 | } |
129 | } |
130 | } |
131 | } |
132 | |
133 | explicit LoopHeightError(IRModule mod, For loop, Block block) |
134 | : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {} |
135 | |
136 | String FastErrorString() const final { |
137 | return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops " |
138 | "related to reduce block var" ; |
139 | } |
140 | |
141 | String DetailRenderTemplate() const final { |
142 | std::ostringstream os; |
143 | os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops " |
144 | "related to reduce block var of block {1}" ; |
145 | return os.str(); |
146 | } |
147 | |
148 | IRModule mod() const final { return mod_; } |
149 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; } |
150 | |
151 | IRModule mod_; |
152 | For loop_; |
153 | Block block_; |
154 | }; |
155 | |
156 | PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) { |
157 | if (is_one(pred)) return Bool(true); |
158 | PrimExpr new_pred = Bool(true); |
159 | auto f = [&](const VarNode* var) { return discarded_loops.count(var); }; |
160 | arith::PVar<PrimExpr> lhs, rhs, rest; |
161 | for (;;) { |
162 | if ((rest && (lhs < rhs)).Match(pred)) { |
163 | if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval()); |
164 | pred = rest.Eval(); |
165 | } else if ((lhs < rhs).Match(pred)) { |
166 | if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval()); |
167 | break; |
168 | } else { |
169 | ICHECK(false) << "Unexpected predicate for reduction block" ; |
170 | } |
171 | } |
172 | return new_pred; |
173 | } |
174 | |
175 | StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, |
176 | const StmtSRef& loop_sref) { |
177 | /*! |
178 | * Check |
179 | * - block is a reduction block |
180 | * - loop is not lower than all the loops related to reduce block var |
181 | * Mutate |
182 | * - generate loops related to data par block vars |
183 | * - generate corresponding init block and update block |
184 | */ |
185 | // Condition Checks and Information Collection |
186 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
187 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
188 | // Get the outer loops from high to low |
189 | Array<StmtSRef> loops = GetLoops(block_sref); |
190 | const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); |
191 | // Cond 0. Check loop_sref is an ancestor of block_sref |
192 | if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { |
193 | throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block), |
194 | "decompose_reduction" ); |
195 | } |
196 | // Cond 1. Check block is reduction |
197 | StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, |
198 | /*require_stage_pipeline=*/false); |
199 | CheckReductionBlock(self, block_sref, scope_root_sref); |
200 | // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction |
201 | LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); |
202 | // IR Manipulation |
203 | ObjectPtr<BlockNode> init_block = make_object<BlockNode>(); |
204 | ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>(); |
205 | init_block->name_hint = block->name_hint + "_init" ; |
206 | init_block->annotations = block->annotations; |
207 | init_realize->iter_values = {}; |
208 | init_realize->block = Block(init_block); |
209 | // Step 1. Create new block vars and their bindings |
210 | // Maps an old block var to the new corresponding block var |
211 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map; |
212 | block_var_map.reserve(block->iter_vars.size()); |
213 | for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { |
214 | const IterVar& iter_var = block->iter_vars[i]; |
215 | const PrimExpr& binding = realize->iter_values[i]; |
216 | // Only process data parallel block vars |
217 | if (iter_var->iter_type != IterVarType::kDataPar) { |
218 | continue; |
219 | } |
220 | // Create a new block var |
221 | IterVar new_iter_var(/*dom=*/iter_var->dom, |
222 | /*var=*/iter_var->var.copy_with_suffix("" ), |
223 | /*iter_type=*/iter_var->iter_type, |
224 | /*thread_tag=*/iter_var->thread_tag); |
225 | // Add a block var and its binding |
226 | init_block->iter_vars.push_back(new_iter_var); |
227 | init_realize->iter_values.push_back(binding); |
228 | // Add a mapping from old block vars to new block vars |
229 | block_var_map[iter_var->var] = new_iter_var->var; |
230 | } |
231 | // Step 2. After copying block vars, substitute them in init block |
232 | init_block->body = Substitute(block->init.value(), block_var_map); |
233 | for (const BufferRegion& write : block->writes) { |
234 | init_block->writes.push_back( |
235 | BufferRegion(write->buffer, Substitute(write->region, block_var_map))); |
236 | } |
237 | // Step 3. Scan loops not higher than the specified loop above the reduction block. |
238 | // If the loop is used in the init block binding, then it is chosen. |
239 | // Otherwise, it is discarded. |
240 | std::unordered_set<const VarNode*> discarded_loops; |
241 | std::vector<int> chosen_loops; |
242 | for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) { |
243 | const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get(); |
244 | bool discarded = true; |
245 | for (const PrimExpr& expr : init_realize->iter_values) { |
246 | if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) { |
247 | continue; |
248 | } |
249 | // The loop is related to init block bindings; |
250 | chosen_loops.push_back(i); |
251 | discarded = false; |
252 | break; |
253 | } |
254 | if (discarded) discarded_loops.insert(loop_var); |
255 | // Only scan loops not higher than the given loop |
256 | if (loops[i].same_as(loop_sref)) { |
257 | break; |
258 | } |
259 | } |
260 | // Step 4. After scanning loops, make a new predicate in the init block realize |
261 | // We discard predicate that is related to discarded loops |
262 | init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops); |
263 | // Step 5. Create new loops above init block |
264 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map; |
265 | Stmt body = BlockRealize(init_realize); |
266 | for (int i : chosen_loops) { |
267 | const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); |
268 | // Create a new equivalent to the chosen loop |
269 | Var old_loop_var = old_loop->loop_var; |
270 | Var new_loop_var = old_loop_var.copy_with_suffix("_init" ); |
271 | loop_var_map[old_loop_var] = new_loop_var; |
272 | body = For(/*loop_var=*/new_loop_var, |
273 | /*min=*/old_loop->min, |
274 | /*extent=*/old_loop->extent, |
275 | /*kind=*/old_loop->kind, |
276 | /*body=*/body); |
277 | } |
278 | body = Substitute(body, loop_var_map); |
279 | // Step 6. Mutate IR |
280 | const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); |
281 | auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace( |
282 | GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block)); |
283 | self->Replace(scope_root_sref, new_scope_root, |
284 | {{GetRef<Block>(old_scope_root), new_scope_root}, |
285 | {GetRef<Block>(block), new_reduction_block}}); |
286 | self->UpdateScopeBlockInfo(new_scope_root); |
287 | return self->stmt2ref.at(init_block.get()); |
288 | } |
289 | |
290 | /******** Commutative Reducer ********/ |
291 | |
292 | /*! |
293 | * \brief A structure used for registering new commutative reducers, and store all the registered |
294 | * reducers. The reducers are preserved in a list, in the form of "reducer-getter function". When |
295 | * invoking a reducer-getter function with a specific datatype, the reducer-getter will return the |
296 | * CommReducer of the corresponding reduction pattern and the specific datatype |
297 | */ |
298 | struct ReducerRegistry { |
299 | ReducerRegistry() |
300 | : reducer_getters{ |
301 | CreateReducerGetter( |
302 | /*n_buffers=*/1, |
303 | [](const Array<Var>& x, const Array<Var>& y) { |
304 | return Array<PrimExpr>{x[0] + y[0]}; |
305 | }, |
306 | [](const Array<PrimExpr>& values) { |
307 | return Array<PrimExpr>{make_const(values[0]->dtype, 0)}; |
308 | }), |
309 | CreateReducerGetter( |
310 | /*n_buffers=*/1, |
311 | [](const Array<Var>& x, const Array<Var>& y) { |
312 | return Array<PrimExpr>{x[0] * y[0]}; |
313 | }, |
314 | [](const Array<PrimExpr>& values) { |
315 | return Array<PrimExpr>{make_const(values[0]->dtype, 1)}; |
316 | }), |
317 | CreateReducerGetter( |
318 | /*n_buffers=*/1, |
319 | [](const Array<Var>& x, const Array<Var>& y) { |
320 | return Array<PrimExpr>{min(x[0], y[0])}; |
321 | }, |
322 | [](const Array<PrimExpr>& values) { |
323 | return Array<PrimExpr>{max_value(values[0]->dtype)}; |
324 | }), |
325 | CreateReducerGetter( |
326 | /*n_buffers=*/1, |
327 | [](const Array<Var>& x, const Array<Var>& y) { |
328 | return Array<PrimExpr>{max(x[0], y[0])}; |
329 | }, |
330 | [](const Array<PrimExpr>& values) { |
331 | return Array<PrimExpr>{min_value(values[0]->dtype)}; |
332 | }), |
333 | CreateReducerGetter( |
334 | /*n_buffers=*/2, |
335 | [](const Array<Var>& x, const Array<Var>& y) { |
336 | return Array<PrimExpr>{x[0] + y[0], x[1] + y[1]}; |
337 | }, |
338 | [](const Array<PrimExpr>& values) { |
339 | return Array<PrimExpr>{make_const(values[0]->dtype, 0), |
340 | make_const(values[1]->dtype, 0)}; |
341 | }), |
342 | CreateReducerGetter( |
343 | /*n_buffers=*/2, |
344 | [](const Array<Var>& x, const Array<Var>& y) { |
345 | PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); |
346 | PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); |
347 | return Array<PrimExpr>{idx, val}; |
348 | }, |
349 | [](const Array<PrimExpr>& values) { |
350 | return Array<PrimExpr>{make_const(values[0]->dtype, -1), |
351 | min_value(values[1]->dtype)}; |
352 | }), |
353 | CreateReducerGetter( |
354 | /*n_buffers=*/2, |
355 | [](const Array<Var>& x, const Array<Var>& y) { |
356 | PrimExpr idx = |
357 | Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), |
358 | x[0], y[0]); |
359 | PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]); |
360 | return Array<PrimExpr>{idx, val}; |
361 | }, |
362 | [](const Array<PrimExpr>& values) { |
363 | return Array<PrimExpr>{make_const(values[0]->dtype, -1), |
364 | min_value(values[1]->dtype)}; |
365 | }), |
366 | CreateReducerGetter( |
367 | /*n_buffers=*/2, |
368 | [](const Array<Var>& x, const Array<Var>& y) { |
369 | PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); |
370 | PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); |
371 | return Array<PrimExpr>{idx, val}; |
372 | }, |
373 | [](const Array<PrimExpr>& values) { |
374 | return Array<PrimExpr>{make_const(values[0]->dtype, -1), |
375 | max_value(values[1]->dtype)}; |
376 | }), |
377 | CreateReducerGetter( |
378 | /*n_buffers=*/2, |
379 | [](const Array<Var>& x, const Array<Var>& y) { |
380 | PrimExpr idx = Select( |
381 | Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); |
382 | PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]); |
383 | return Array<PrimExpr>{idx, val}; |
384 | }, |
385 | [](const Array<PrimExpr>& values) { |
386 | return Array<PrimExpr>{make_const(values[0]->dtype, -1), |
387 | max_value(values[1]->dtype)}; |
388 | })} {} |
389 | |
390 | static void RegisterReducer( |
391 | int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter, |
392 | TypedPackedFunc<Array<PrimExpr>(Array<PrimExpr>)> identity_getter) { |
393 | ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( |
394 | n_buffers, std::move(combiner_getter), std::move(identity_getter))); |
395 | } |
396 | |
397 | static TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)> CreateReducerGetter( |
398 | int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter, |
399 | TypedPackedFunc<Array<PrimExpr>(Array<PrimExpr>)> identity_getter) { |
400 | return [n_buffers, // |
401 | combiner_getter = std::move(combiner_getter), // |
402 | identity_getter = std::move(identity_getter) // |
403 | ](Array<PrimExpr> values) -> Optional<CommReducer> { |
404 | if (static_cast<int>(values.size()) != n_buffers) { |
405 | return NullOpt; |
406 | } |
407 | Array<Var> lhs; |
408 | Array<Var> rhs; |
409 | for (int i = 0; i < n_buffers; ++i) { |
410 | lhs.push_back(Var("x" + std::to_string(i), values[i]->dtype)); |
411 | rhs.push_back(Var("y" + std::to_string(i), values[i]->dtype)); |
412 | } |
413 | return CommReducer(lhs, rhs, combiner_getter(lhs, rhs), identity_getter(values)); |
414 | }; |
415 | } |
416 | |
417 | static ReducerRegistry* Global() { |
418 | static ReducerRegistry instance; |
419 | return &instance; |
420 | } |
421 | |
422 | std::vector<TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> reducer_getters; |
423 | }; |
424 | |
425 | std::vector<TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> GetReducerGetters() { |
426 | return ReducerRegistry::Global()->reducer_getters; |
427 | } |
428 | |
429 | class NotSerialLoopKindError : public ScheduleError { |
430 | public: |
431 | explicit NotSerialLoopKindError(IRModule mod, For loop) |
432 | : mod_(std::move(mod)), loop_(std::move(loop)) {} |
433 | |
434 | String FastErrorString() const final { |
435 | return "ScheduleError: The input loop of rfactor is required to be `kSerial`" ; |
436 | } |
437 | |
438 | String DetailRenderTemplate() const final { |
439 | String str_kind = ForKind2String(loop_->kind); |
440 | std::ostringstream os; |
441 | os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. However, the " |
442 | "kind of {0} is `" |
443 | << str_kind << "`" ; |
444 | return os.str(); |
445 | } |
446 | |
447 | IRModule mod() const final { return mod_; } |
448 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
449 | |
450 | IRModule mod_; |
451 | For loop_; |
452 | }; |
453 | |
454 | class FactorAxisOutOfRangeError : public ScheduleError { |
455 | public: |
456 | explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) |
457 | : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {} |
458 | |
459 | String FastErrorString() const final { |
460 | return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range " |
461 | "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer" ; |
462 | } |
463 | |
464 | String DetailRenderTemplate() const final { |
465 | std::ostringstream os; |
466 | int ndim = static_cast<int>(buffer_->shape.size()); |
467 | os << "The write buffer " << buffer_->name << " has " << ndim |
468 | << " dimension(s), so `factor_axis` is required to be in [" << -(ndim + 1) << ", " << ndim |
469 | << "] for rfactor. However, the input `factor_axis` is " << factor_axis_ |
470 | << ", which is out of the expected range" ; |
471 | return os.str(); |
472 | } |
473 | |
474 | IRModule mod() const final { return mod_; } |
475 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
476 | |
477 | static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) { |
478 | int ndim = static_cast<int>(buffer->shape.size()); |
479 | if (factor_axis < -(ndim + 1) || factor_axis > ndim) { |
480 | throw FactorAxisOutOfRangeError(mod, buffer, factor_axis); |
481 | } |
482 | // If factor_axis is negative, convert it to a non-negative one. |
483 | if (factor_axis < 0) { |
484 | factor_axis += ndim + 1; |
485 | } |
486 | return factor_axis; |
487 | } |
488 | |
489 | IRModule mod_; |
490 | Buffer buffer_; |
491 | int factor_axis_; |
492 | }; |
493 | |
494 | class LoopPropertyError : public ScheduleError { |
495 | public: |
496 | enum ErrorType { |
497 | kDataParIterTouchRFactorLoop = 0, |
498 | kLoopTouchedByBothKindsOfBlockIters = 1, |
499 | kNotFirstChildBlockOfOutermostLoop = 2, |
500 | kUnboundLoopUnderReductionLoop = 3 |
501 | }; |
502 | |
503 | explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type) |
504 | : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {} |
505 | |
506 | String FastErrorString() const final { |
507 | switch (error_type_) { |
508 | case kDataParIterTouchRFactorLoop: |
509 | return "ScheduleError: The loop to be applied rfactor is required not to be touched by any " |
510 | "data parallel block iter of the block" ; |
511 | case kLoopTouchedByBothKindsOfBlockIters: |
512 | return "ScheduleError: The loops outside of the reduction block are required not to be " |
513 | "touched by both data parallel block iters and reduction block iters" ; |
514 | case kNotFirstChildBlockOfOutermostLoop: |
515 | return "ScheduleError: The reduction block should be the first child block of the " |
516 | "outermost loop outside of it" ; |
517 | case kUnboundLoopUnderReductionLoop: |
518 | return "ScheduleError: A loop who has extent greater than one and is not bound to any " |
519 | "block iter should not appear under a reduction loop" ; |
520 | } |
521 | ICHECK(false) << "Unreachable" ; |
522 | throw; |
523 | } |
524 | |
525 | String DetailRenderTemplate() const final { |
526 | switch (error_type_) { |
527 | case kDataParIterTouchRFactorLoop: |
528 | return "The loop to be applied rfactor is {0}, which is required not to be touched by any " |
529 | "data parallel block iter of the block below. However, some of the block's data " |
530 | "parallel block iters touch this loop" ; |
531 | case kLoopTouchedByBothKindsOfBlockIters: |
532 | return "It is not allowed that the loop {0} is touched by both some data parallel block " |
533 | "iters and some reduction block iters" ; |
534 | case kNotFirstChildBlockOfOutermostLoop: |
535 | return "The first child block of the outermost loop {0} is not the reduction block." ; |
536 | case kUnboundLoopUnderReductionLoop: |
537 | return "The loop {0} has extent greater than one, and is not bound to any block iter. " |
538 | "Therefore it shouldn't appear under a reduction loop" ; |
539 | } |
540 | ICHECK(false) << "Unreachable" ; |
541 | throw; |
542 | } |
543 | |
544 | IRModule mod() const final { return mod_; } |
545 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
546 | |
547 | static void CheckLoopProperty(const ScheduleState& self, const Array<For>& loops, |
548 | const ForNode* rf_loop, const Block& block, |
549 | const std::unordered_set<const VarNode*>& data_par_loop_vars, |
550 | const std::unordered_set<const VarNode*>& reduce_loop_vars) { |
551 | Array<BlockRealize> children_of_outermost_loop = |
552 | GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); |
553 | if (!children_of_outermost_loop[0]->block.same_as(block)) { |
554 | throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); |
555 | } |
556 | |
557 | bool meet_reduction_loop = false; |
558 | for (const For& loop : loops) { |
559 | bool data_par_touched = data_par_loop_vars.count(loop->loop_var.get()); |
560 | bool reduction_touched = reduce_loop_vars.count(loop->loop_var.get()); |
561 | |
562 | if (data_par_touched && reduction_touched) { |
563 | throw LoopPropertyError(self->mod, loop, kLoopTouchedByBothKindsOfBlockIters); |
564 | } else if (data_par_touched) { |
565 | if (loop.get() == rf_loop) { |
566 | throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); |
567 | } |
568 | continue; |
569 | } else if (reduction_touched) { |
570 | if (!meet_reduction_loop) { |
571 | CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); |
572 | meet_reduction_loop = true; |
573 | } |
574 | continue; |
575 | } else if (meet_reduction_loop && !is_one(loop->extent)) { |
576 | throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); |
577 | } |
578 | } |
579 | } |
580 | |
581 | IRModule mod_; |
582 | For loop_; |
583 | ErrorType error_type_; |
584 | }; |
585 | |
586 | /*! |
587 | * \brief For each loop in the given array of loop, associate its loop var with the loop itself |
588 | * using a mapping |
589 | * \param loops The loops to be analyzed |
590 | * \return A mapping from loops to their corresponding loop vars |
591 | */ |
592 | std::unordered_map<const VarNode*, For> GetLoopVar2LoopMap(const Array<For>& loops) { |
593 | std::unordered_map<const VarNode*, For> loop_vars2loop; |
594 | loop_vars2loop.reserve(loops.size()); |
595 | for (const For& loop : loops) { |
596 | loop_vars2loop[loop->loop_var.get()] = loop; |
597 | } |
598 | return loop_vars2loop; |
599 | } |
600 | |
601 | /*! |
602 | * \brief Create the intermediate rfactor buffers, which the rfactor block writes to and the |
603 | * write-back block reads from |
604 | * \param buf_stores The BufferStores of the original block, where the rfactor buffers will be |
605 | * created from |
606 | * \param factor_axis The `factor_axis` parameter of rfactor |
607 | * \param rf_loop The rfactor loop |
608 | * \return The new created intermediate rfactor buffer |
609 | */ |
610 | Array<Buffer> CreateRFactorBuffers(const Array<BufferStore>& buf_stores, int factor_axis, |
611 | const ForNode* rf_loop) { |
612 | Array<Buffer> rf_buffers; |
613 | rf_buffers.reserve(buf_stores.size()); |
614 | for (const BufferStore& buf_store : buf_stores) { |
615 | Buffer buffer = buf_store->buffer; |
616 | Array<PrimExpr> rf_shape = buffer->shape; |
617 | rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); |
618 | |
619 | ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get()); |
620 | n->shape = rf_shape; |
621 | n->name = buffer->name + ".rf" ; |
622 | n->data = buffer->data.copy_with_suffix(".rf" ); |
623 | rf_buffers.push_back(Buffer(n)); |
624 | } |
625 | return rf_buffers; |
626 | } |
627 | |
628 | /*! |
629 | * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four |
630 | * steps: |
631 | * 1) Create the new block iters and the their iter bindings |
632 | * 2) Create the body and init of the new block |
633 | * 3) Create the read/write regions of the new block |
634 | * 4) Create the new block and the new block-realize |
635 | */ |
636 | class BaseBlockCreator { |
637 | public: |
638 | explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, |
639 | Array<BufferStore> old_reduction_updates, CommReducer reducer, |
640 | Array<Buffer> rf_buffers, bool is_rf_block) |
641 | : old_block_realize_(std::move(old_block_realize)), |
642 | rf_loop_(std::move(rf_loop)), |
643 | old_reduction_updates_(std::move(old_reduction_updates)), |
644 | reducer_(std::move(reducer)), |
645 | rf_buffers_(std::move(rf_buffers)), |
646 | n_buffers_(static_cast<int>(rf_buffers_.size())), |
647 | is_rf_block_(is_rf_block) { |
648 | n_block_iters_ = static_cast<int>(old_block_realize_->iter_values.size()); |
649 | update_buffers_.reserve(n_buffers_); |
650 | update_indices_.reserve(n_buffers_); |
651 | update_lhs_.reserve(n_buffers_); |
652 | update_rhs_.reserve(n_buffers_); |
653 | } |
654 | |
655 | void CreateBlock() { |
656 | CreateAdditionalIter(); |
657 | for (int i = 0; i < n_block_iters_; ++i) { |
658 | CreateNormalIters(i); |
659 | } |
660 | bool has_reduce_iter = false; |
661 | for (const IterVar& iter_var : iter_vars_) { |
662 | if (iter_var->iter_type == IterVarType::kCommReduce) { |
663 | has_reduce_iter = true; |
664 | break; |
665 | } |
666 | } |
667 | |
668 | // The pre-processing finds out the buffers written in the block, the indices of the buffer |
669 | // accesses, and the reduction LHS and RHS of the stored values. |
670 | PreProcess(); |
671 | Stmt block_body = Substitute(CreateBlockBody(has_reduce_iter), var_map_); |
672 | Optional<Stmt> block_init = CreateBlockInit(has_reduce_iter); |
673 | if (block_init.defined()) { |
674 | block_init = Substitute(block_init.value(), var_map_); |
675 | } |
676 | CreateReadWriteRegions(); |
677 | |
678 | String new_block_name = old_block_realize_->block->name_hint; |
679 | PrimExpr predicate = const_true(); |
680 | if (is_rf_block_) { |
681 | new_block_name = new_block_name + "_rf" ; |
682 | predicate = old_block_realize_->predicate; |
683 | } |
684 | new_block_ = Block( |
685 | /*iter_vars=*/iter_vars_, |
686 | /*reads=*/read_regions_, |
687 | /*writes=*/write_regions_, |
688 | /*name_hint=*/new_block_name, |
689 | /*body=*/std::move(block_body), |
690 | /*init=*/std::move(block_init), |
691 | /*alloc_buffers=*/{}, |
692 | /*match_buffers=*/{}, |
693 | /*annotations=*/old_block_realize_->block->annotations); |
694 | new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_); |
695 | } |
696 | |
697 | private: |
698 | virtual void CreateAdditionalIter() = 0; |
699 | virtual void CreateNormalIters(int idx) = 0; |
700 | virtual void PreProcess() = 0; |
701 | virtual void CreateReadWriteRegions() = 0; |
702 | |
703 | Stmt CreateBlockBody(bool has_reduce_iter) { |
704 | Array<Stmt> buf_stores; |
705 | buf_stores.reserve(n_buffers_); |
706 | |
707 | // Case 1. If the block has no reduction iterator, we just store the RHS values into the |
708 | // buffers. |
709 | if (!has_reduce_iter) { |
710 | for (int i = 0; i < n_buffers_; ++i) { |
711 | buf_stores.push_back(BufferStore(update_buffers_[i], update_rhs_[i], update_indices_[i])); |
712 | } |
713 | return n_buffers_ > 1 ? SeqStmt(buf_stores) : buf_stores[0]; |
714 | } |
715 | |
716 | // Case 2. If the reduction is for single buffer, the block body is a single BufferStore. |
717 | Array<PrimExpr> stored_values = (*reducer_.get())(update_lhs_, update_rhs_); |
718 | if (n_buffers_ == 1) { |
719 | return BufferStore(update_buffers_[0], stored_values[0], update_indices_[0]); |
720 | } |
721 | |
722 | // Case 3. In case the reduction is for multiple buffers, we should create the reduction with |
723 | // LetStmt so that the reduction execution generates correct results. |
724 | Array<Var> let_vars; |
725 | let_vars.reserve(n_buffers_); |
726 | for (int i = 0; i < n_buffers_; ++i) { |
727 | Var var("v_" + update_buffers_[i]->name, PrimType(stored_values[i]->dtype)); |
728 | let_vars.push_back(var); |
729 | buf_stores.push_back(BufferStore(update_buffers_[i], var, update_indices_[i])); |
730 | } |
731 | Stmt body = SeqStmt(buf_stores); |
732 | for (int i = n_buffers_ - 1; i >= 0; --i) { |
733 | body = LetStmt(let_vars[i], stored_values[i], std::move(body)); |
734 | } |
735 | return body; |
736 | } |
737 | |
738 | Optional<Stmt> CreateBlockInit(bool has_reduce_iter) { |
739 | if (!has_reduce_iter) { |
740 | return NullOpt; |
741 | } |
742 | |
743 | Array<Stmt> inits; |
744 | inits.reserve(n_buffers_); |
745 | for (int i = 0; i < n_buffers_; ++i) { |
746 | inits.push_back( |
747 | BufferStore(update_buffers_[i], reducer_->identity_element[i], update_indices_[i])); |
748 | } |
749 | return n_buffers_ > 1 ? SeqStmt(inits) : inits[0]; |
750 | } |
751 | |
public:
  /*! \brief The newly created block */
  Block new_block_;
  /*! \brief The newly created block-realize */
  BlockRealize new_block_realize_;
  /*!
   * \brief The indices used to access the intermediate rfactor buffer. They are computed while
   * creating the rfactor block and handed over to the write-back block creator.
   */
  Array<PrimExpr> rf_buf_access_indices_;

protected:
  /*! \brief The old block-realize */
  BlockRealize old_block_realize_;
  /*! \brief The number of block iters in the old block */
  int n_block_iters_;
  /*! \brief The rfactor loop */
  For rf_loop_;
  /*! \brief The update BufferStores of the old block */
  Array<BufferStore> old_reduction_updates_;
  /*! \brief The matched commutative reducer */
  CommReducer reducer_;
  /*! \brief The intermediate rfactor buffers */
  Array<Buffer> rf_buffers_;
  /*! \brief The number of rfactor buffers. */
  const int n_buffers_;
  /*!
   * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced
   * by the expressions in future substitution for the two blocks. When creating the write-back
   * block, it also maps the rfactor block's additional iter to the write-back block's new iter.
   */
  Map<Var, PrimExpr> var_map_;

  /*! \brief Whether we are creating the rfactor block or the write-back block */
  bool is_rf_block_;
  /*! \brief The new block iters of the newly created block */
  std::vector<IterVar> iter_vars_;
  /*! \brief The new block iter bindings of the newly created block-realize */
  std::vector<PrimExpr> iter_values_;
  /*! \brief The buffers updated in this block */
  Array<Buffer> update_buffers_;
  /*! \brief The indices of the buffers updated in this block, respectively */
  Array<Array<PrimExpr>> update_indices_;
  /*! \brief The LHS values of the reduction in this block */
  Array<PrimExpr> update_lhs_;
  /*! \brief The RHS values of the reduction in this block */
  Array<PrimExpr> update_rhs_;
  /*! \brief The read regions of the newly created block */
  Array<BufferRegion> read_regions_;
  /*! \brief The write regions of the newly created block */
  Array<BufferRegion> write_regions_;
799 | }; |
800 | |
801 | /*! |
802 | * \brief The derived class of the rfactor block creator, which implements all virtual methods in |
803 | * the base creator |
804 | * \details Start constructing the rfactor block. The main difficulty to construct the rfactor block |
805 | * is to create its block iters. So here we introduce the algorithm to create the block iters. |
806 | * 1. Create a block iter for the rfactor loop. The block binding of this iter is the loop var, and |
807 | * the block iter is data parallel. |
808 | * 2. For all the old block's block iters, there are two cases: |
809 | * (a) If it is data parallel block iter, or a reduction block iter which doesn't touch the |
810 | * rfactor loop, we keep it and its block binding in the rfactor block. |
811 | * (b) Otherwise it is a reduction block iter which touches the rfactor loop. In this case, we |
812 | * "split" the block iter into one or more new block iters and do not keep the old block |
813 | * var. More specifically, we create a new reduction block iter for each loop var that |
814 | * appears in the reduction block iter's binding (except for the rfactor loop), and the |
815 | * binding of the new block iter is exactly the loop var. (Note that for each loop var, we |
816 | * create at most one block iter, even if there are multiple old block iters which touch |
817 | * both this loop and the rfactor loop). |
818 | * Then we substitute the appearances of the old block iter with the new created block |
819 | * iters by recording two mappings: one maps loops vars to new created block iters which |
820 | * is used for binding substitution, and another maps old block iters to new expressions |
821 | * which is used for substitutions of the old block iters. |
822 | */ |
823 | class RFactorBlockCreator : public BaseBlockCreator { |
824 | public: |
825 | explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, |
826 | Array<BufferStore> old_reduction_updates, CommReducer reducer, |
827 | Array<Buffer> rf_buffers, |
828 | std::unordered_map<const VarNode*, For> loop_vars2loop, |
829 | int factor_axis, Array<PrimExpr> combiner_rhs) |
830 | : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), |
831 | std::move(old_reduction_updates), std::move(reducer), |
832 | std::move(rf_buffers), true), |
833 | loop_vars2loop_(std::move(loop_vars2loop)), |
834 | factor_axis_(factor_axis), |
835 | combiner_rhs_(std::move(combiner_rhs)) {} |
836 | |
837 | private: |
838 | void CreateAdditionalIter() final { |
839 | // Create a new data parallel block iter for the rfactor loop. |
840 | additional_iter_ = |
841 | IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar); |
842 | loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var; |
843 | iter_vars_.push_back(additional_iter_); |
844 | iter_values_.push_back(rf_loop_->loop_var); |
845 | } |
846 | |
847 | void CreateNormalIters(int idx) final { |
848 | IterVar old_iter = old_block_realize_->block->iter_vars[idx]; |
849 | PrimExpr old_binding = old_block_realize_->iter_values[idx]; |
850 | if (old_iter->iter_type == IterVarType::kDataPar || |
851 | !UsesVar(old_binding, |
852 | [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) { |
853 | // The old block iter is either a data parallel block iter, or a reduction block iter that |
854 | // doesn't touch the rfactor loop. In this case reuse the old reduction block iter and its |
855 | // corresponding binding. |
856 | iter_vars_.push_back(old_iter); |
857 | iter_values_.push_back(old_binding); |
858 | return; |
859 | } |
860 | ICHECK(old_iter->iter_type == kCommReduce); |
861 | // This block iter is a reduction block iter that touches the rfactor loop. So next we try to |
862 | // create a new block iter for all loop vars that appear in the old binding. |
863 | Array<Var> vars_in_old_binding = UndefinedVars(old_binding); |
864 | for (const Var& var : vars_in_old_binding) { |
865 | auto it = loop_vars2loop_.find(var.get()); |
866 | if (it == loop_vars2loop_.end()) { |
867 | // `var` is not a loop var. So skip. |
868 | continue; |
869 | } |
870 | const For& loop = it->second; |
871 | if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) { |
872 | // We haven't created the new block iter for `var`. So here we create it, append it |
873 | // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. |
874 | IterVar new_iter_var = |
875 | IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce); |
876 | loop_var2block_binding_[var.get()] = new_iter_var->var; |
877 | iter_vars_.push_back(new_iter_var); |
878 | iter_values_.push_back(var); |
879 | } |
880 | } |
881 | // Substitute the original binding with new block iters. Store the result expression |
882 | // in `rf_var_map` for future substitution. |
883 | var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); |
884 | } |
885 | |
886 | void PreProcess() final { |
887 | // The accessed indices for all reduction buffers are the same. |
888 | rf_buf_access_indices_ = old_reduction_updates_[0]->indices; |
889 | rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, |
890 | additional_iter_->var); |
891 | for (int i = 0; i < n_buffers_; ++i) { |
892 | update_buffers_.push_back(rf_buffers_[i]); |
893 | update_indices_.push_back(rf_buf_access_indices_); |
894 | update_lhs_.push_back(BufferLoad(update_buffers_[i], rf_buf_access_indices_)); |
895 | update_rhs_.push_back(combiner_rhs_[i]); |
896 | } |
897 | } |
898 | |
899 | void CreateReadWriteRegions() final { |
900 | Map<Buffer, Buffer> buffer_map; |
901 | for (int i = 0; i < n_buffers_; ++i) { |
902 | buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]); |
903 | } |
904 | const Block& old_block = old_block_realize_->block; |
905 | read_regions_.reserve(old_block->reads.size()); |
906 | for (const BufferRegion& read_region : old_block->reads) { |
907 | read_regions_.push_back( |
908 | BufferRegion(read_region->buffer, Substitute(read_region->region, var_map_))); |
909 | } |
910 | write_regions_.reserve(old_block->writes.size()); |
911 | for (const BufferRegion& write_region : old_block->writes) { |
912 | Array<Range> region = write_region->region; |
913 | region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1)); |
914 | Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer); |
915 | ICHECK(rf_buffer.defined()); |
916 | write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); |
917 | } |
918 | } |
919 | |
920 | public: |
921 | /*! \brief The generated additional block iter in rfactor block for the rfactor loop */ |
922 | IterVar additional_iter_; |
923 | |
924 | private: |
925 | /*! |
926 | * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction |
927 | * block's outer loops |
928 | */ |
929 | std::unordered_map<const VarNode*, For> loop_vars2loop_; |
930 | /*! \brief The factor_axis specified for rfactor */ |
931 | int factor_axis_; |
932 | /*! \brief The RHS values of the reduction in the old block */ |
933 | Array<PrimExpr> combiner_rhs_; |
934 | /*! |
935 | * \brief A mapping which maps loop vars to new created block iters. This map is used to |
936 | * substitute the loop vars which appear in the bindings of some old block iters with the new |
937 | * created block iters |
938 | */ |
939 | std::unordered_map<const VarNode*, PrimExpr> loop_var2block_binding_; |
940 | }; |
941 | |
942 | /*! |
943 | * \brief The derived class of the write-back block creator, which implements all virtual methods in |
944 | * the base creator |
945 | */ |
946 | class WriteBackBlockCreator : public BaseBlockCreator { |
947 | public: |
948 | explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, |
949 | Array<BufferStore> old_reduction_updates, CommReducer reducer, |
950 | Array<Buffer> rf_buffers, IterVar rf_additional_iter, |
951 | Array<PrimExpr> combiner_lhs, |
952 | Array<PrimExpr> rf_buf_access_indices) |
953 | : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), |
954 | std::move(old_reduction_updates), std::move(reducer), |
955 | std::move(rf_buffers), false), |
956 | rf_additional_iter_(std::move(rf_additional_iter)), |
957 | combiner_lhs_(std::move(combiner_lhs)) { |
958 | iter_vars_.reserve(n_block_iters_); |
959 | iter_values_.reserve(n_block_iters_); |
960 | rf_buf_access_indices_ = std::move(rf_buf_access_indices); |
961 | } |
962 | |
963 | private: |
964 | void CreateAdditionalIter() final { |
965 | // Create a new reduction block iter for the rfactor loop. |
966 | IterVar wb_new_block_iter = |
967 | IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); |
968 | iter_vars_.push_back(wb_new_block_iter); |
969 | iter_values_.push_back(rf_loop_->loop_var); |
970 | var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); |
971 | } |
972 | |
973 | void CreateNormalIters(int idx) final { |
974 | IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; |
975 | if (old_block_iter->iter_type == IterVarType::kDataPar) { |
976 | iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix("" ), |
977 | kDataPar); |
978 | iter_values_.push_back(old_block_realize_->iter_values[idx]); |
979 | var_map_.Set(old_block_iter->var, iter_vars_.back()); |
980 | } |
981 | } |
982 | |
983 | void PreProcess() final { |
984 | for (int i = 0; i < n_buffers_; ++i) { |
985 | PrimExpr rhs = BufferLoad(rf_buffers_[i], rf_buf_access_indices_); |
986 | update_buffers_.push_back(old_reduction_updates_[i]->buffer); |
987 | update_indices_.push_back(old_reduction_updates_[i]->indices); |
988 | update_lhs_.push_back(Substitute(combiner_lhs_[i], var_map_)); |
989 | update_rhs_.push_back(Substitute(std::move(rhs), var_map_)); |
990 | } |
991 | } |
992 | |
993 | void CreateReadWriteRegions() final { |
994 | CreateRegion(update_rhs_, true); |
995 | CreateRegion(update_lhs_, false); |
996 | } |
997 | |
998 | void CreateRegion(const Array<PrimExpr>& buf_loads, bool is_read) { |
999 | Array<BufferRegion>& buf_regions = is_read ? read_regions_ : write_regions_; |
1000 | for (const PrimExpr& expr : buf_loads) { |
1001 | const auto* buf_load = expr.as<BufferLoadNode>(); |
1002 | ICHECK(buf_load != nullptr); |
1003 | Array<Range> region; |
1004 | region.reserve(buf_load->indices.size()); |
1005 | for (const PrimExpr& index : buf_load->indices) { |
1006 | region.push_back(Range::FromMinExtent(index, 1)); |
1007 | } |
1008 | buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region))); |
1009 | } |
1010 | } |
1011 | |
1012 | private: |
1013 | /*! \brief The new created additional block iter of the rfactor block */ |
1014 | IterVar rf_additional_iter_; |
1015 | /*! \brief The LHS values of the reduction in the old block */ |
1016 | Array<PrimExpr> combiner_lhs_; |
1017 | }; |
1018 | |
1019 | /*! |
1020 | * \brief Create new outer loops for the rfactor block, meanwhile update the rfactor block's iter |
1021 | * bindings to use the new created loop vars |
1022 | * \param rf_block_realize The BlockRealize of the rfactor block |
1023 | * \param loops The loops to be wrapped over the rfactor block |
1024 | * \return A Stmt which is the wrapping result |
1025 | */ |
1026 | Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array<For>& loops) { |
1027 | int n_loops = static_cast<int>(loops.size()); |
1028 | |
1029 | // Step 1. Create new loop vars. |
1030 | Array<For> new_loops; |
1031 | std::unordered_map<const VarNode*, PrimExpr> new_loop_var_map; |
1032 | new_loops.reserve(n_loops); |
1033 | new_loop_var_map.reserve(n_loops); |
1034 | for (const For& old_loop : loops) { |
1035 | Var new_loop_var = old_loop->loop_var.copy_with_suffix("" ); |
1036 | new_loop_var_map[old_loop->loop_var.get()] = new_loop_var; |
1037 | } |
1038 | |
1039 | // Step 2. Update the iter bindings and predicate of the rfactor block. |
1040 | Array<PrimExpr> new_bindings; |
1041 | new_bindings.reserve(rf_block_realize->iter_values.size()); |
1042 | for (const PrimExpr& old_binding : rf_block_realize->iter_values) { |
1043 | new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); |
1044 | } |
1045 | { |
1046 | BlockRealizeNode* p_rf_block_realize = rf_block_realize.CopyOnWrite(); |
1047 | p_rf_block_realize->iter_values = new_bindings; |
1048 | p_rf_block_realize->predicate = Substitute(rf_block_realize->predicate, new_loop_var_map); |
1049 | } |
1050 | |
1051 | // Step 3. Wrap `rf_block_realize` with outer loops. |
1052 | Stmt rf_body = rf_block_realize; |
1053 | for (int i = n_loops - 1; i >= 0; --i) { |
1054 | ObjectPtr<ForNode> p_loop = make_object<ForNode>(*loops[i].get()); |
1055 | p_loop->loop_var = Downcast<Var>(new_loop_var_map[loops[i]->loop_var.get()]); |
1056 | p_loop->body = rf_body; |
1057 | rf_body = For(std::move(p_loop)); |
1058 | } |
1059 | |
1060 | return rf_body; |
1061 | } |
1062 | |
1063 | class BlockReplacer : public StmtMutator { |
1064 | public: |
1065 | /*! |
1066 | * \brief The replace takes the old scope root block as input, and does four things: |
1067 | * 1) replace the reduction block with the write-back block, |
1068 | * 2) remove loops outside the write-back block that are touched by reduction block iters, except |
1069 | * for the rfactor loop |
1070 | * 3) combine the rfactor block (wrapped with outer loops) and the transformed outermost loop |
1071 | * into a SeqStmt, and |
1072 | * 4) insert the rfactor buffer into the scope root block's `alloc_buffers` |
1073 | * After transformation, the function returns the new scope root block |
1074 | * \param scope_root_block The old scope root block |
1075 | * \param rf_body The rfactor block, which is already wrapped with outer loops |
1076 | * \param outermost_loop The loop that is outermost among all loops outside the reduction block |
1077 | * \param wb_block_realize The new created BlockRealize of the write-back block |
1078 | * \param old_block_realize The BlockRealize of the reduction block |
1079 | * \param rf_loop The rfactor loop, which should be kept outside the write-back block |
1080 | * \param reduce_loop_vars The loops that are touched by reduction block iters, used to remove |
1081 | * loops outside the write-back block |
1082 | * \param loop_vars2loop The mapping from loop vars to loops that are outside the reduction block, |
1083 | * which is used to reduce redundant recursive visits |
1084 | * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers` |
1085 | * \return The transformed new scope root block |
1086 | */ |
1087 | static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop, |
1088 | BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, |
1089 | std::unordered_set<const VarNode*> reduce_loop_vars, |
1090 | std::unordered_map<const VarNode*, For> loop_vars2loop, |
1091 | const Array<Buffer>& rf_buffers) { |
1092 | BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), |
1093 | std::move(wb_block_realize), std::move(old_block_realize), |
1094 | std::move(rf_loop), std::move(reduce_loop_vars), |
1095 | std::move(loop_vars2loop)); |
1096 | Block new_scope_root = Downcast<Block>(replacer(std::move(scope_root_block))); |
1097 | BlockNode* p = new_scope_root.CopyOnWrite(); |
1098 | for (const Buffer& rf_buffer : rf_buffers) { |
1099 | p->alloc_buffers.push_back(rf_buffer); |
1100 | } |
1101 | return new_scope_root; |
1102 | } |
1103 | |
1104 | private: |
1105 | explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize, |
1106 | BlockRealize old_block_realize, For rf_loop, |
1107 | std::unordered_set<const VarNode*> reduce_loop_vars, |
1108 | std::unordered_map<const VarNode*, For> loop_vars2loop) |
1109 | : rf_body_(std::move(rf_body)), |
1110 | outermost_loop_(std::move(outermost_loop)), |
1111 | wb_block_realize_(std::move(wb_block_realize)), |
1112 | old_block_realize_(std::move(old_block_realize)), |
1113 | rf_loop_(std::move(rf_loop)), |
1114 | reduce_loop_vars_(std::move(reduce_loop_vars)), |
1115 | loop_vars2loop_(std::move(loop_vars2loop)) {} |
1116 | |
1117 | Stmt VisitStmt_(const ForNode* loop) final { |
1118 | // Step 1. Check whether this loop is outside the reduction block. Given that we've made sure |
1119 | // that the scope root block has stage-pipeline property, if this loop is not outside the |
1120 | // reduction block, there's no need to recursively mutate. |
1121 | if (!loop_vars2loop_.count(loop->loop_var.get())) { |
1122 | return GetRef<For>(loop); |
1123 | } |
1124 | |
1125 | // Step 2. Recursively mutate. |
1126 | Stmt body = StmtMutator::VisitStmt(loop->body); |
1127 | |
1128 | // Step 3. If this loop is the rfactor loop and isn't touched by any reduction block iter, it |
1129 | // should be kept outside the write-back block. Otherwise it shouldn't. |
1130 | if (loop == rf_loop_.get() || !reduce_loop_vars_.count(loop->loop_var.get())) { |
1131 | ObjectPtr<ForNode> p_loop = CopyOnWrite(loop); |
1132 | p_loop->body = body; |
1133 | body = Stmt(p_loop); |
1134 | } |
1135 | |
1136 | // Step 4. If this loop is the outermost loop of the reduction block, return the combination of |
1137 | // `rf_body_` and the mutation result `body`. Otherwise return the mutation result. |
1138 | return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body; |
1139 | } |
1140 | |
1141 | Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { |
1142 | // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's |
1143 | // block-realize. And we directly return the new `wb_block_realize`. |
1144 | ICHECK_EQ(block_realize, old_block_realize_.get()); |
1145 | return wb_block_realize_; |
1146 | } |
1147 | |
1148 | Stmt VisitStmt_(const SeqStmtNode* seq) final { |
1149 | Array<Stmt> new_stmts; |
1150 | new_stmts.reserve(static_cast<int>(seq->seq.size())); |
1151 | |
1152 | for (const Stmt old_stmt : seq->seq) { |
1153 | new_stmts.push_back(VisitStmt(old_stmt)); |
1154 | } |
1155 | return SeqStmt::Flatten(new_stmts); |
1156 | } |
1157 | |
1158 | private: |
1159 | Stmt rf_body_; |
1160 | For outermost_loop_; |
1161 | BlockRealize wb_block_realize_; |
1162 | BlockRealize old_block_realize_; |
1163 | For rf_loop_; |
1164 | std::unordered_set<const VarNode*> reduce_loop_vars_; |
1165 | std::unordered_map<const VarNode*, For> loop_vars2loop_; |
1166 | }; |
1167 | |
/*!
 * \brief Apply rfactor to the reduction block under the loop `rf_loop_sref`. The block is split
 * into an rfactor block, which computes partial results into newly allocated rfactor buffers, and
 * a write-back block, which combines those partial results into the original buffers.
 * \param self The schedule state
 * \param rf_loop_sref The sref of the loop to be factored out; it must be a serial loop, otherwise
 * NotSerialLoopKindError is thrown
 * \param factor_axis The position of the new dimension in the rfactor buffers; a negative value is
 * converted to non-negative (see Step 6 below)
 * \return The sref of the newly created rfactor block
 */
StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_axis) {
  // *****************************************************
  // *    Condition Checks and Information Collection    *
  // *****************************************************

  // Step 1. Check some basic conditions for rfactor. Get the block and block-realize.
  BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref);
  const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get());
  const Block& block = block_realize->block;
  StmtSRef scope_root = GetScopeRoot(self, block_sref,  //
                                     /*require_stage_pipeline=*/true);
  CheckReductionBlock(self, block_sref, scope_root);
  const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
  if (rf_loop->kind != ForKind::kSerial) {
    throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
  }

  // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block
  // iters, respectively.
  std::unordered_set<const VarNode*> data_par_loop_vars;
  std::unordered_set<const VarNode*> reduce_loop_vars;
  GetVarsTouchedByBlockIters(block_realize, &data_par_loop_vars, &reduce_loop_vars);

  // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to
  // corresponding loop vars.
  Array<For> loops = LoopSRefs2Loops(GetLoops(block_sref));
  std::unordered_map<const VarNode*, For> loop_vars2loop = GetLoopVar2LoopMap(loops);

  // Step 4. Check four properties that the loops should have:
  // - the rfactor loop cannot be touched by any data parallel block iter;
  // - all the loops cannot be touched by both data parallel block iters and reduction block iters;
  // - the outermost loop should have the reduction block as its first child block;
  // - the outermost loop that is touched by some reduction block iters can only have one child
  // block.
  LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
                                       reduce_loop_vars);

  // Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
  // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the
  // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs
  // will be used when constructing the rfactor block.
  Array<PrimExpr> init_values{nullptr};
  Array<BufferStore> updates{nullptr};
  CommReducer reducer{nullptr};
  Array<PrimExpr> combiner_lhs{nullptr};
  Array<PrimExpr> combiner_rhs{nullptr};
  std::tie(init_values, updates) = GetInitValuesAndUpdatesFromReductionBlock(self, block);
  std::tie(reducer, combiner_lhs, combiner_rhs) =
      GetReducerAndCombinerLhsRhs(self, init_values, updates);

  // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it
  // is negative.
  factor_axis =
      FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, updates[0]->buffer, factor_axis);

  // *****************************************************
  // *                 IR Manipulation                   *
  // *****************************************************
  // Since rfactor splits the reduction block into two, we call the first one "rfactor block", and
  // the latter one "write-back block", and the intermediate buffer is called "rfactor buffer".

  // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional
  // dimension specified by `factor_axis` and `rf_loop`.
  Array<Buffer> rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop);

  // Step 2. Create the rfactor block.
  RFactorBlockCreator rf_block_creator(block_realize, GetRef<For>(rf_loop), updates, reducer,
                                       rf_buffers, loop_vars2loop, factor_axis,
                                       std::move(combiner_rhs));
  rf_block_creator.CreateBlock();

  // Step 3. Create the write-back block. The rfactor block's additional iter and buffer access
  // indices are moved out of `rf_block_creator`; neither is read from it afterwards.
  WriteBackBlockCreator wb_block_creator(block_realize, GetRef<For>(rf_loop), updates, reducer,
                                         rf_buffers, std::move(rf_block_creator.additional_iter_),
                                         std::move(combiner_lhs),
                                         std::move(rf_block_creator.rf_buf_access_indices_));
  wb_block_creator.CreateBlock();

  // Step 4. Wrap the rfactor block with loops.
  Stmt rf_body = CreateLoopOutsideRfactorBlock(rf_block_creator.new_block_realize_, loops);

  // *****************************************************
  // *           Schedule Replacement & Update           *
  // *****************************************************

  // Step 1. Substitute the old scope root block with the new scope root block. The block reuse map
  // records that the reduction block is replaced by the write-back block.
  Block old_scope_root_block = GetRef<Block>(scope_root->StmtAs<BlockNode>());
  Block new_scope_root_block = BlockReplacer::Replace(
      old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize,
      GetRef<For>(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers);
  self->Replace(
      scope_root, new_scope_root_block,
      {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}});

  // Step 2. Update scope information.
  std::vector<StmtSRef> new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()),
                                        self->stmt2ref.at(wb_block_creator.new_block_.get())};
  for (const StmtSRef& new_block_sref : new_block_srefs) {
    BlockInfo& info = self->block_info[new_block_sref];
    info.affine_binding = true;
    info.region_cover = true;
    info.scope->stage_pipeline = true;
  }
  return new_block_srefs[0];
}
1273 | |
1274 | /******** InstructionKind Registration ********/ |
1275 | |
1276 | struct DecomposeReductionTraits : public UnpackedInstTraits<DecomposeReductionTraits> { |
1277 | static constexpr const char* kName = "DecomposeReduction" ; |
1278 | static constexpr bool kIsPure = false; |
1279 | |
1280 | private: |
1281 | static constexpr size_t kNumInputs = 2; |
1282 | static constexpr size_t kNumAttrs = 0; |
1283 | static constexpr size_t kNumDecisions = 0; |
1284 | |
1285 | static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) { |
1286 | return sch->DecomposeReduction(block_rv, loop_rv); |
1287 | } |
1288 | |
1289 | static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv) { |
1290 | PythonAPICall py("decompose_reduction" ); |
1291 | py.Input("block" , block_rv); |
1292 | py.Input("loop" , loop_rv); |
1293 | py.SingleOutput(outputs); |
1294 | return py.Str(); |
1295 | } |
1296 | |
1297 | template <typename> |
1298 | friend struct ::tvm::tir::UnpackedInstTraits; |
1299 | }; |
1300 | |
1301 | struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> { |
1302 | static constexpr const char* kName = "RFactor" ; |
1303 | static constexpr bool kIsPure = false; |
1304 | |
1305 | private: |
1306 | static constexpr size_t kNumInputs = 1; |
1307 | static constexpr size_t kNumAttrs = 1; |
1308 | static constexpr size_t kNumDecisions = 0; |
1309 | |
1310 | static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { |
1311 | return sch->RFactor(loop_rv, factor_axis->value); |
1312 | } |
1313 | |
1314 | static String UnpackedAsPython(Array<String> outputs, String loop_rv, Integer factor_axis) { |
1315 | PythonAPICall py("rfactor" ); |
1316 | py.Input("loop" , loop_rv); |
1317 | py.Input("factor_axis" , factor_axis->value); |
1318 | py.SingleOutput(outputs); |
1319 | return py.Str(); |
1320 | } |
1321 | |
1322 | template <typename> |
1323 | friend struct ::tvm::tir::UnpackedInstTraits; |
1324 | }; |
1325 | |
// Register the instruction-kind traits of the `RFactor` and `DecomposeReduction` primitives.
TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits);
TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits);
1328 | |
1329 | /******** FFI ********/ |
1330 | |
1331 | TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer" ) |
1332 | .set_body_typed([](int n_buffers, PackedFunc combiner_getter, PackedFunc identity_getter) { |
1333 | ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), |
1334 | std::move(identity_getter)); |
1335 | }); |
1336 | |
1337 | } // namespace tir |
1338 | } // namespace tvm |
1339 | |