1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file schedule_ops.cc
22 */
23#include <tvm/runtime/registry.h>
24#include <tvm/te/operation.h>
25#include <tvm/te/schedule_pass.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/stmt_functor.h>
29
30#include <unordered_map>
31#include <unordered_set>
32#include <utility>
33
34#include "../../tir/transforms/ir_utils.h"
35#include "../operation/op_utils.h"
36#include "graph.h"
37
38namespace tvm {
39namespace te {
40
41using namespace tir;
42
43// Annotate the statement with the layout transforms and axis
44// separators of the stage. These annotations are removed during
45// SchedulePostProcToPrimFunc. Afterwards, layout transforms are
46// specified in the PrimFunc attrs, and the axis_separators are
47// specified in the BufferNode.
48Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) {
49 if (stage->layout_transforms.size()) {
50 for (int i = 0; i < stage->op->num_outputs(); i++) {
51 body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->layout_transforms},
52 tir::attr::layout_transforms, 1, body);
53 }
54 }
55
56 if (stage->axis_separators.size()) {
57 for (int i = 0; i < stage->op->num_outputs(); i++) {
58 body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->axis_separators},
59 tir::attr::axis_separators, 1, body);
60 }
61 }
62
63 return body;
64}
65
66Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, Stmt consumer,
67 bool debug_keep_trivial_loop) {
68 Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
69 if (s->double_buffer) {
70 producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer);
71 }
72 producer = WrapLayoutTransformationAttrs(s, producer);
73 Stmt pipeline = producer;
74
75 if (consumer.defined() && !is_no_op(consumer)) {
76 pipeline = SeqStmt({producer, consumer});
77 }
78
79 if (s->rolling_buffer) {
80 pipeline = AttrStmt(s->op, tir::attr::rolling_buffer_scope, Bool(true), pipeline);
81 }
82
83 return s->op->BuildRealize(s, dom_map, pipeline, s->scope);
84}
85
86// inject the operator's realization on the stmt.
87class InjectAttach : public StmtMutator {
88 public:
89 InjectAttach(const Stage& stage, const Stage& attach_spec,
90 const std::unordered_map<IterVar, Range>& dom_map, bool debug_keep_trivial_loop)
91 : stage_(stage),
92 attach_spec_(attach_spec),
93 dom_map_(dom_map),
94 debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
95
96 Stmt VisitStmt(const Stmt& input_stmt) final {
97 ICHECK(input_stmt.defined());
98 auto stmt = StmtMutator::VisitStmt(input_stmt);
99 const AttrStmtNode* op = stmt.as<AttrStmtNode>();
100 if (op != nullptr && op->attr_key == tir::attr::loop_scope) {
101 if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) {
102 ICHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar
103 << " in multiple places in the IR";
104 found_attach = true;
105 stmt = AttrStmt(op->node, op->attr_key, op->value,
106 MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
107 }
108 }
109 return stmt;
110 }
111 // whether attach point is found
112 bool found_attach{false};
113
114 private:
115 // The stage.
116 const Stage& stage_;
117 // The attach spec, may not contain op.
118 const Stage& attach_spec_;
119 // domain map
120 const std::unordered_map<IterVar, Range>& dom_map_;
121 // Whether keep trivial loops with extent of 1 during lowering.
122 // This is a debug feature for dataflow/axis analysis
123 bool debug_keep_trivial_loop_;
124};
125
126// inject the operator's realization on the stmt.
127class InjectScanStep : public StmtMutator {
128 public:
129 InjectScanStep(const Stage& stage, const Operation& scan_op,
130 const std::unordered_map<IterVar, Range>& dom_map, bool is_init,
131 bool debug_keep_trivial_loop)
132 : stage_(stage),
133 scan_op_(scan_op),
134 dom_map_(dom_map),
135 is_init_(is_init),
136 debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
137
138 Stmt VisitStmt(const Stmt& input_stmt) final {
139 ICHECK(input_stmt.defined());
140 auto stmt = StmtMutator::VisitStmt(input_stmt);
141 // update
142 const AttrStmtNode* op = stmt.as<AttrStmtNode>();
143 if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) ||
144 (op->attr_key == tir::attr::scan_init_scope && is_init_))) {
145 if (op->node.same_as(scan_op_)) {
146 found_attach = true;
147 stmt = AttrStmt(op->node, op->attr_key, op->value,
148 MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
149 }
150 }
151 return stmt;
152 }
153
154 // whether attach point is found
155 bool found_attach{false};
156
157 private:
158 // the operations to be carried
159 const Stage& stage_;
160 const Operation& scan_op_;
161 // domain map
162 const std::unordered_map<IterVar, Range>& dom_map_;
163 // whether it is init.
164 bool is_init_;
165 // Whether keep trivial loops with extent of 1 during lowering.
166 // This is a debug feature for dataflow/axis analysis
167 bool debug_keep_trivial_loop_;
168};
169
170// Postprocessing of schedule op
171// Replace the init and update's expression by scan's buffer.
172class SchedulePostProc : public StmtExprMutator {
173 public:
174 Stmt VisitStmt_(const LetStmtNode* op) final {
175 if (SideEffect(op->value) <= CallEffectKind::kPure) {
176 var_value_[op->var.get()] = this->VisitExpr(op->value);
177 return this->VisitStmt(op->body);
178 } else {
179 return StmtExprMutator::VisitStmt_(op);
180 }
181 }
182
183 Stmt VisitStmt_(const AttrStmtNode* op) final {
184 if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) {
185 return this->VisitStmt(op->body);
186 } else if (op->attr_key == tir::attr::scan_update_scope) {
187 const ScanOpNode* scan = op->node.as<ScanOpNode>();
188 ICHECK(scan);
189 var_value_[scan->scan_axis->var.get()] = op->value;
190 return this->VisitStmt(op->body);
191 } else if (op->attr_key == tir::attr::thread_extent) {
192 // delete duplicated thread extent attr
193 auto it = thread_extent_scope_.find(op->node.get());
194 if (it != thread_extent_scope_.end()) {
195 ICHECK(is_zero(analyzer_.Simplify(it->second - op->value)));
196 return this->VisitStmt(op->body);
197 } else {
198 thread_extent_scope_[op->node.get()] = op->value;
199 Stmt ret = StmtExprMutator::VisitStmt_(op);
200 thread_extent_scope_.erase(op->node.get());
201 return ret;
202 }
203 } else if (op->attr_key == tir::attr::double_buffer_scope) {
204 auto it = replace_op_.find(op->node.get());
205 if (it != replace_op_.end()) {
206 if (it->second.defined()) {
207 Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body);
208 return this->VisitStmt(ret);
209 } else {
210 return this->VisitStmt(op->body);
211 }
212 }
213 } else if (op->attr_key == tir::attr::buffer_bind_scope) {
214 Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
215 Tensor tensor = Downcast<Tensor>(tuple[1]);
216 auto it = replace_op_.find(tensor->op.get());
217 if (it != replace_op_.end()) {
218 if (it->second.defined()) {
219 return AttrStmt(Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
220 op->attr_key, op->value, this->VisitStmt(op->body));
221 } else {
222 return this->VisitStmt(op->body);
223 }
224 }
225 } else if (op->attr_key == tir::attr::buffer_dim_align) {
226 Tensor tensor = Downcast<Tensor>(op->node);
227 auto it = replace_op_.find(tensor->op.get());
228 if (it != replace_op_.end()) {
229 if (it->second.defined()) {
230 return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value,
231 this->VisitStmt(op->body));
232 } else {
233 return this->VisitStmt(op->body);
234 }
235 }
236 } else if (op->attr_key == tir::attr::layout_transforms ||
237 op->attr_key == tir::attr::axis_separators) {
238 auto arr = Downcast<Array<ObjectRef>>(op->node);
239 ICHECK_EQ(arr.size(), 2);
240
241 Stmt body = op->body;
242
243 Tensor tensor = Downcast<Tensor>(arr[0]);
244 auto it = replace_op_.find(tensor->op.get());
245 if (it != replace_op_.end()) {
246 if (it->second.defined()) {
247 return AttrStmt(Array<ObjectRef>{it->second.output(tensor->value_index), arr[1]},
248 op->attr_key, op->value, this->VisitStmt(op->body));
249 } else {
250 return this->VisitStmt(op->body);
251 }
252 }
253 }
254 return StmtExprMutator::VisitStmt_(op);
255 }
256
257 Stmt VisitStmt_(const ProducerRealizeNode* op) final {
258 auto key = Downcast<Tensor>(op->producer);
259 auto it = replace_realize_.find(key);
260 if (it != replace_realize_.end()) {
261 if (it->second.defined()) {
262 Stmt ret =
263 ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope);
264 return this->VisitStmt(ret);
265 } else {
266 return this->VisitStmt(op->body);
267 }
268 } else {
269 return StmtExprMutator::VisitStmt_(op);
270 }
271 }
272
273 Stmt VisitStmt_(const ProducerStoreNode* op) final {
274 auto key = Downcast<Tensor>(op->producer);
275 auto it = replace_buffer_.find(key);
276 if (it != replace_buffer_.end()) {
277 const Tensor& dst = it->second;
278 Stmt ret = ProducerStore(dst, op->value, op->indices);
279 return this->VisitStmt(ret);
280 } else {
281 return StmtExprMutator::VisitStmt_(op);
282 }
283 }
284
285 PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
286 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
287 op = expr.as<ProducerLoadNode>();
288 ICHECK(op != nullptr);
289
290 auto key = Downcast<Tensor>(op->producer);
291 auto it = replace_buffer_.find(key);
292 if (it != replace_buffer_.end()) {
293 const Tensor& dst = it->second;
294 return ProducerLoad(dst, op->indices);
295 } else {
296 return expr;
297 }
298 }
299
300 PrimExpr VisitExpr_(const VarNode* op) final {
301 auto it = var_value_.find(op);
302 if (it != var_value_.end()) {
303 return it->second;
304 } else {
305 return GetRef<PrimExpr>(op);
306 }
307 }
308
309 void Init(const Schedule& sch) {
310 for (Stage s : sch->stages) {
311 for (auto kv : s->iter_var_attrs) {
312 // Update bind thread information.
313 if (kv.second->bind_thread.defined()) {
314 const Var& from = kv.first->var;
315 const Var& to = kv.second->bind_thread->var;
316 ICHECK(!var_value_.count(from.get()));
317 var_value_[from.get()] = to;
318 }
319 }
320 // This must be checked for all ops, including scan.
321 if (!s->op.same_as(s->origin_op)) {
322 for (int i = 0; i < s->op->num_outputs(); ++i) {
323 Tensor target = s->origin_op.output(i);
324 AddReplace(s->op.output(i), target, target, s->origin_op);
325 }
326 }
327 // Specially add replacements for scan op.
328 if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
329 for (size_t i = 0; i < scan->update.size(); ++i) {
330 Tensor t = s->origin_op.output(i);
331 AddReplace(scan->init[i], t);
332 AddReplace(scan->update[i], t);
333 AddReplace(scan->state_placeholder[i], t);
334 }
335 }
336 }
337 }
338
339 private:
340 void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(),
341 Operation repl_op = Operation()) {
342 replace_buffer_[src] = dst;
343 replace_realize_[src] = repl_realize;
344 replace_op_[src->op.get()] = repl_op;
345 }
346 // The thread extent scope.
347 std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
348 // The scan value
349 std::unordered_map<const VarNode*, PrimExpr> var_value_;
350 // buffer replacement
351 std::unordered_map<Tensor, Tensor> replace_buffer_;
352 // buffere realization to be replaced
353 std::unordered_map<Tensor, Tensor> replace_realize_;
354 // replace producer consumer.
355 std::unordered_map<const Object*, Operation> replace_op_;
356 // integer analyzer
357 arith::Analyzer analyzer_;
358};
359
// Lower a schedule to a single Stmt by building each stage's pipeline
// and injecting it at its attachment point, then running SchedulePostProc.
//
// \param sch The (normalized) schedule; every stage must have a defined
//            op and must not be kInline (call schedule.normalize first).
// \param dom_map_ Ranges for every IterVar, from InferBound.
// \param debug_keep_trivial_loop Keep extent-1 loops (debug feature).
// \return The lowered statement for the whole schedule.
Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
  Stmt body = Stmt();
  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
  // scan init and scan updates
  // Map each init tensor's op to the scan op that consumes it, so the
  // init stage can later be injected into that scan's init scope.
  std::unordered_map<Operation, Operation> scan_init;
  for (Stage s : sch->stages) {
    const ScanOpNode* scan = s->op.as<ScanOpNode>();
    if (!scan) continue;
    for (Tensor t : scan->init) {
      if (scan_init.count(t->op)) {
        ICHECK(scan_init.at(t->op).same_as(s->op))
            << "Scan init tensor can only belong to one scan";
      } else {
        scan_init[t->op] = s->op;
      }
    }
  }
  // verify correctness of group.
  for (Stage g : sch->groups) {
    ICHECK(!g->op.defined());
    ICHECK_EQ(g->leaf_iter_vars.size(), 0U);
  }
  // reverse the post DFS order.
  // Consumers appear before producers, so each stage is injected into
  // the body built from the stages that consume it.
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage s = sch->stages[i - 1];
    ICHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops";
    ICHECK(s->op.defined());
    // Remove grouping sugar, get the real attach spec.
    Stage attach_spec = s.GetAttachSpec();

    if (s->op.as<PlaceholderOpNode>()) {
      // Placeholders don't need any realize/provide statements, but
      // may be annotated with set_physical_layout to indicate the
      // physical layout of an input, and must still have the
      // attribute given.
      body = WrapLayoutTransformationAttrs(s, std::move(body));
    } else if (scan_init.count(s->op)) {
      // This stage initializes a scan: inject into the scan's init scope.
      ICHECK(body.defined());
      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
      body = mu(std::move(body));
      ICHECK(mu.found_attach) << "did not find attachment point for scan.init";
    } else if (attach_spec->attach_type == kScanUpdate) {
      // Handle scan update
      ICHECK(body.defined());
      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
      body = mu(std::move(body));
      ICHECK(mu.found_attach) << "did not find attachment point for scan.update";
    } else if (attach_spec->attach_type == kInlinedAlready) {
      // do nothing
    } else if (attach_spec->attach_type == kGroupRoot) {
      ICHECK(!s->group.defined());
      // Root attachment: prepend this stage's pipeline to the body built so far.
      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
    } else {
      // Scoped attachment: splice the pipeline at the attach IterVar.
      ICHECK_EQ(attach_spec->attach_type, kScope);
      ICHECK(body.defined());
      InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
      body = mutator(std::move(body));
      ICHECK(mutator.found_attach)
          << "did not find attachment point for " << s << " in " << attach_spec->attach_stage->op
          << " x " << attach_spec->attach_ivar << ", body:\n"
          << body;
    }
  }

  SchedulePostProc post_proc;
  post_proc.Init(sch);
  return post_proc(std::move(body));
}
428
429TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) {
430 if (args.size() == 2)
431 *ret = ScheduleOps(args[0], args[1], false);
432 else
433 *ret = ScheduleOps(args[0], args[1], args[2]);
434});
435
436} // namespace te
437} // namespace tvm
438