1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file schedule_ops.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/te/operation.h> |
25 | #include <tvm/te/schedule_pass.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | |
30 | #include <unordered_map> |
31 | #include <unordered_set> |
32 | #include <utility> |
33 | |
34 | #include "../../tir/transforms/ir_utils.h" |
35 | #include "../operation/op_utils.h" |
36 | #include "graph.h" |
37 | |
38 | namespace tvm { |
39 | namespace te { |
40 | |
41 | using namespace tir; |
42 | |
43 | // Annotate the statement with the layout transforms and axis |
44 | // separators of the stage. These annotations are removed during |
45 | // SchedulePostProcToPrimFunc. Afterwards, layout transforms are |
46 | // specified in the PrimFunc attrs, and the axis_separators are |
47 | // specified in the BufferNode. |
48 | Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) { |
49 | if (stage->layout_transforms.size()) { |
50 | for (int i = 0; i < stage->op->num_outputs(); i++) { |
51 | body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->layout_transforms}, |
52 | tir::attr::layout_transforms, 1, body); |
53 | } |
54 | } |
55 | |
56 | if (stage->axis_separators.size()) { |
57 | for (int i = 0; i < stage->op->num_outputs(); i++) { |
58 | body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->axis_separators}, |
59 | tir::attr::axis_separators, 1, body); |
60 | } |
61 | } |
62 | |
63 | return body; |
64 | } |
65 | |
66 | Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, Stmt consumer, |
67 | bool debug_keep_trivial_loop) { |
68 | Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); |
69 | if (s->double_buffer) { |
70 | producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer); |
71 | } |
72 | producer = WrapLayoutTransformationAttrs(s, producer); |
73 | Stmt pipeline = producer; |
74 | |
75 | if (consumer.defined() && !is_no_op(consumer)) { |
76 | pipeline = SeqStmt({producer, consumer}); |
77 | } |
78 | |
79 | if (s->rolling_buffer) { |
80 | pipeline = AttrStmt(s->op, tir::attr::rolling_buffer_scope, Bool(true), pipeline); |
81 | } |
82 | |
83 | return s->op->BuildRealize(s, dom_map, pipeline, s->scope); |
84 | } |
85 | |
86 | // inject the operator's realization on the stmt. |
87 | class InjectAttach : public StmtMutator { |
88 | public: |
89 | InjectAttach(const Stage& stage, const Stage& attach_spec, |
90 | const std::unordered_map<IterVar, Range>& dom_map, bool debug_keep_trivial_loop) |
91 | : stage_(stage), |
92 | attach_spec_(attach_spec), |
93 | dom_map_(dom_map), |
94 | debug_keep_trivial_loop_(debug_keep_trivial_loop) {} |
95 | |
96 | Stmt VisitStmt(const Stmt& input_stmt) final { |
97 | ICHECK(input_stmt.defined()); |
98 | auto stmt = StmtMutator::VisitStmt(input_stmt); |
99 | const AttrStmtNode* op = stmt.as<AttrStmtNode>(); |
100 | if (op != nullptr && op->attr_key == tir::attr::loop_scope) { |
101 | if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) { |
102 | ICHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar |
103 | << " in multiple places in the IR" ; |
104 | found_attach = true; |
105 | stmt = AttrStmt(op->node, op->attr_key, op->value, |
106 | MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); |
107 | } |
108 | } |
109 | return stmt; |
110 | } |
111 | // whether attach point is found |
112 | bool found_attach{false}; |
113 | |
114 | private: |
115 | // The stage. |
116 | const Stage& stage_; |
117 | // The attach spec, may not contain op. |
118 | const Stage& attach_spec_; |
119 | // domain map |
120 | const std::unordered_map<IterVar, Range>& dom_map_; |
121 | // Whether keep trivial loops with extent of 1 during lowering. |
122 | // This is a debug feature for dataflow/axis analysis |
123 | bool debug_keep_trivial_loop_; |
124 | }; |
125 | |
126 | // inject the operator's realization on the stmt. |
127 | class InjectScanStep : public StmtMutator { |
128 | public: |
129 | InjectScanStep(const Stage& stage, const Operation& scan_op, |
130 | const std::unordered_map<IterVar, Range>& dom_map, bool is_init, |
131 | bool debug_keep_trivial_loop) |
132 | : stage_(stage), |
133 | scan_op_(scan_op), |
134 | dom_map_(dom_map), |
135 | is_init_(is_init), |
136 | debug_keep_trivial_loop_(debug_keep_trivial_loop) {} |
137 | |
138 | Stmt VisitStmt(const Stmt& input_stmt) final { |
139 | ICHECK(input_stmt.defined()); |
140 | auto stmt = StmtMutator::VisitStmt(input_stmt); |
141 | // update |
142 | const AttrStmtNode* op = stmt.as<AttrStmtNode>(); |
143 | if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || |
144 | (op->attr_key == tir::attr::scan_init_scope && is_init_))) { |
145 | if (op->node.same_as(scan_op_)) { |
146 | found_attach = true; |
147 | stmt = AttrStmt(op->node, op->attr_key, op->value, |
148 | MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); |
149 | } |
150 | } |
151 | return stmt; |
152 | } |
153 | |
154 | // whether attach point is found |
155 | bool found_attach{false}; |
156 | |
157 | private: |
158 | // the operations to be carried |
159 | const Stage& stage_; |
160 | const Operation& scan_op_; |
161 | // domain map |
162 | const std::unordered_map<IterVar, Range>& dom_map_; |
163 | // whether it is init. |
164 | bool is_init_; |
165 | // Whether keep trivial loops with extent of 1 during lowering. |
166 | // This is a debug feature for dataflow/axis analysis |
167 | bool debug_keep_trivial_loop_; |
168 | }; |
169 | |
170 | // Postprocessing of schedule op |
171 | // Replace the init and update's expression by scan's buffer. |
172 | class SchedulePostProc : public StmtExprMutator { |
173 | public: |
174 | Stmt VisitStmt_(const LetStmtNode* op) final { |
175 | if (SideEffect(op->value) <= CallEffectKind::kPure) { |
176 | var_value_[op->var.get()] = this->VisitExpr(op->value); |
177 | return this->VisitStmt(op->body); |
178 | } else { |
179 | return StmtExprMutator::VisitStmt_(op); |
180 | } |
181 | } |
182 | |
183 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
184 | if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) { |
185 | return this->VisitStmt(op->body); |
186 | } else if (op->attr_key == tir::attr::scan_update_scope) { |
187 | const ScanOpNode* scan = op->node.as<ScanOpNode>(); |
188 | ICHECK(scan); |
189 | var_value_[scan->scan_axis->var.get()] = op->value; |
190 | return this->VisitStmt(op->body); |
191 | } else if (op->attr_key == tir::attr::thread_extent) { |
192 | // delete duplicated thread extent attr |
193 | auto it = thread_extent_scope_.find(op->node.get()); |
194 | if (it != thread_extent_scope_.end()) { |
195 | ICHECK(is_zero(analyzer_.Simplify(it->second - op->value))); |
196 | return this->VisitStmt(op->body); |
197 | } else { |
198 | thread_extent_scope_[op->node.get()] = op->value; |
199 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
200 | thread_extent_scope_.erase(op->node.get()); |
201 | return ret; |
202 | } |
203 | } else if (op->attr_key == tir::attr::double_buffer_scope) { |
204 | auto it = replace_op_.find(op->node.get()); |
205 | if (it != replace_op_.end()) { |
206 | if (it->second.defined()) { |
207 | Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body); |
208 | return this->VisitStmt(ret); |
209 | } else { |
210 | return this->VisitStmt(op->body); |
211 | } |
212 | } |
213 | } else if (op->attr_key == tir::attr::buffer_bind_scope) { |
214 | Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node); |
215 | Tensor tensor = Downcast<Tensor>(tuple[1]); |
216 | auto it = replace_op_.find(tensor->op.get()); |
217 | if (it != replace_op_.end()) { |
218 | if (it->second.defined()) { |
219 | return AttrStmt(Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)}, |
220 | op->attr_key, op->value, this->VisitStmt(op->body)); |
221 | } else { |
222 | return this->VisitStmt(op->body); |
223 | } |
224 | } |
225 | } else if (op->attr_key == tir::attr::buffer_dim_align) { |
226 | Tensor tensor = Downcast<Tensor>(op->node); |
227 | auto it = replace_op_.find(tensor->op.get()); |
228 | if (it != replace_op_.end()) { |
229 | if (it->second.defined()) { |
230 | return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value, |
231 | this->VisitStmt(op->body)); |
232 | } else { |
233 | return this->VisitStmt(op->body); |
234 | } |
235 | } |
236 | } else if (op->attr_key == tir::attr::layout_transforms || |
237 | op->attr_key == tir::attr::axis_separators) { |
238 | auto arr = Downcast<Array<ObjectRef>>(op->node); |
239 | ICHECK_EQ(arr.size(), 2); |
240 | |
241 | Stmt body = op->body; |
242 | |
243 | Tensor tensor = Downcast<Tensor>(arr[0]); |
244 | auto it = replace_op_.find(tensor->op.get()); |
245 | if (it != replace_op_.end()) { |
246 | if (it->second.defined()) { |
247 | return AttrStmt(Array<ObjectRef>{it->second.output(tensor->value_index), arr[1]}, |
248 | op->attr_key, op->value, this->VisitStmt(op->body)); |
249 | } else { |
250 | return this->VisitStmt(op->body); |
251 | } |
252 | } |
253 | } |
254 | return StmtExprMutator::VisitStmt_(op); |
255 | } |
256 | |
257 | Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
258 | auto key = Downcast<Tensor>(op->producer); |
259 | auto it = replace_realize_.find(key); |
260 | if (it != replace_realize_.end()) { |
261 | if (it->second.defined()) { |
262 | Stmt ret = |
263 | ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); |
264 | return this->VisitStmt(ret); |
265 | } else { |
266 | return this->VisitStmt(op->body); |
267 | } |
268 | } else { |
269 | return StmtExprMutator::VisitStmt_(op); |
270 | } |
271 | } |
272 | |
273 | Stmt VisitStmt_(const ProducerStoreNode* op) final { |
274 | auto key = Downcast<Tensor>(op->producer); |
275 | auto it = replace_buffer_.find(key); |
276 | if (it != replace_buffer_.end()) { |
277 | const Tensor& dst = it->second; |
278 | Stmt ret = ProducerStore(dst, op->value, op->indices); |
279 | return this->VisitStmt(ret); |
280 | } else { |
281 | return StmtExprMutator::VisitStmt_(op); |
282 | } |
283 | } |
284 | |
285 | PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
286 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
287 | op = expr.as<ProducerLoadNode>(); |
288 | ICHECK(op != nullptr); |
289 | |
290 | auto key = Downcast<Tensor>(op->producer); |
291 | auto it = replace_buffer_.find(key); |
292 | if (it != replace_buffer_.end()) { |
293 | const Tensor& dst = it->second; |
294 | return ProducerLoad(dst, op->indices); |
295 | } else { |
296 | return expr; |
297 | } |
298 | } |
299 | |
300 | PrimExpr VisitExpr_(const VarNode* op) final { |
301 | auto it = var_value_.find(op); |
302 | if (it != var_value_.end()) { |
303 | return it->second; |
304 | } else { |
305 | return GetRef<PrimExpr>(op); |
306 | } |
307 | } |
308 | |
309 | void Init(const Schedule& sch) { |
310 | for (Stage s : sch->stages) { |
311 | for (auto kv : s->iter_var_attrs) { |
312 | // Update bind thread information. |
313 | if (kv.second->bind_thread.defined()) { |
314 | const Var& from = kv.first->var; |
315 | const Var& to = kv.second->bind_thread->var; |
316 | ICHECK(!var_value_.count(from.get())); |
317 | var_value_[from.get()] = to; |
318 | } |
319 | } |
320 | // This must be checked for all ops, including scan. |
321 | if (!s->op.same_as(s->origin_op)) { |
322 | for (int i = 0; i < s->op->num_outputs(); ++i) { |
323 | Tensor target = s->origin_op.output(i); |
324 | AddReplace(s->op.output(i), target, target, s->origin_op); |
325 | } |
326 | } |
327 | // Specially add replacements for scan op. |
328 | if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) { |
329 | for (size_t i = 0; i < scan->update.size(); ++i) { |
330 | Tensor t = s->origin_op.output(i); |
331 | AddReplace(scan->init[i], t); |
332 | AddReplace(scan->update[i], t); |
333 | AddReplace(scan->state_placeholder[i], t); |
334 | } |
335 | } |
336 | } |
337 | } |
338 | |
339 | private: |
340 | void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), |
341 | Operation repl_op = Operation()) { |
342 | replace_buffer_[src] = dst; |
343 | replace_realize_[src] = repl_realize; |
344 | replace_op_[src->op.get()] = repl_op; |
345 | } |
346 | // The thread extent scope. |
347 | std::unordered_map<const Object*, PrimExpr> thread_extent_scope_; |
348 | // The scan value |
349 | std::unordered_map<const VarNode*, PrimExpr> var_value_; |
350 | // buffer replacement |
351 | std::unordered_map<Tensor, Tensor> replace_buffer_; |
352 | // buffere realization to be replaced |
353 | std::unordered_map<Tensor, Tensor> replace_realize_; |
354 | // replace producer consumer. |
355 | std::unordered_map<const Object*, Operation> replace_op_; |
356 | // integer analyzer |
357 | arith::Analyzer analyzer_; |
358 | }; |
359 | |
360 | Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) { |
361 | Stmt body = Stmt(); |
362 | std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_); |
363 | // scan init and scan updates |
364 | std::unordered_map<Operation, Operation> scan_init; |
365 | for (Stage s : sch->stages) { |
366 | const ScanOpNode* scan = s->op.as<ScanOpNode>(); |
367 | if (!scan) continue; |
368 | for (Tensor t : scan->init) { |
369 | if (scan_init.count(t->op)) { |
370 | ICHECK(scan_init.at(t->op).same_as(s->op)) |
371 | << "Scan init tensor can only belong to one scan" ; |
372 | } else { |
373 | scan_init[t->op] = s->op; |
374 | } |
375 | } |
376 | } |
377 | // verify correctness of group. |
378 | for (Stage g : sch->groups) { |
379 | ICHECK(!g->op.defined()); |
380 | ICHECK_EQ(g->leaf_iter_vars.size(), 0U); |
381 | } |
382 | // reverse the post DFS order. |
383 | for (size_t i = sch->stages.size(); i != 0; --i) { |
384 | Stage s = sch->stages[i - 1]; |
385 | ICHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops" ; |
386 | ICHECK(s->op.defined()); |
387 | // Remove grouping sugar, get the real attach spec. |
388 | Stage attach_spec = s.GetAttachSpec(); |
389 | |
390 | if (s->op.as<PlaceholderOpNode>()) { |
391 | // Placeholders don't need any realize/provide statements, but |
392 | // may be annotated with set_physical_layout to indicate the |
393 | // physical layout of an input, and must still have the |
394 | // attribute given. |
395 | body = WrapLayoutTransformationAttrs(s, std::move(body)); |
396 | } else if (scan_init.count(s->op)) { |
397 | ICHECK(body.defined()); |
398 | InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); |
399 | body = mu(std::move(body)); |
400 | ICHECK(mu.found_attach) << "did not find attachment point for scan.init" ; |
401 | } else if (attach_spec->attach_type == kScanUpdate) { |
402 | // Handle scan update |
403 | ICHECK(body.defined()); |
404 | InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); |
405 | body = mu(std::move(body)); |
406 | ICHECK(mu.found_attach) << "did not find attachment point for scan.update" ; |
407 | } else if (attach_spec->attach_type == kInlinedAlready) { |
408 | // do nothing |
409 | } else if (attach_spec->attach_type == kGroupRoot) { |
410 | ICHECK(!s->group.defined()); |
411 | body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop); |
412 | } else { |
413 | ICHECK_EQ(attach_spec->attach_type, kScope); |
414 | ICHECK(body.defined()); |
415 | InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); |
416 | body = mutator(std::move(body)); |
417 | ICHECK(mutator.found_attach) |
418 | << "did not find attachment point for " << s << " in " << attach_spec->attach_stage->op |
419 | << " x " << attach_spec->attach_ivar << ", body:\n" |
420 | << body; |
421 | } |
422 | } |
423 | |
424 | SchedulePostProc post_proc; |
425 | post_proc.Init(sch); |
426 | return post_proc(std::move(body)); |
427 | } |
428 | |
429 | TVM_REGISTER_GLOBAL("schedule.ScheduleOps" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
430 | if (args.size() == 2) |
431 | *ret = ScheduleOps(args[0], args[1], false); |
432 | else |
433 | *ret = ScheduleOps(args[0], args[1], args[2]); |
434 | }); |
435 | |
436 | } // namespace te |
437 | } // namespace tvm |
438 | |