1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
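/*!
 *
 * \brief Lift AttrStmt scopes that use a specified attribute key outward,
 *  so that adjacent statements sharing the same scope (same node and value)
 *  are wrapped by a single enclosing AttrStmt instead of one per statement.
 * \file lift_attr_scope.cc
 */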
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/stmt_functor.h>
28#include <tvm/tir/transform.h>
29
30#include "ir_utils.h"
31
32namespace tvm {
33namespace tir {
34
35// NOTE: this optimization can only be applied
36// to a few specified attr keys
37class AttrScopeLifter : public StmtMutator {
38 public:
39 explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {}
40
41 Stmt Lift(Stmt stmt) {
42 stmt = operator()(std::move(stmt));
43 if (attr_node_.defined()) {
44 stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt);
45 }
46 return stmt;
47 }
48
49 // do not go beyond
50 Stmt VisitStmt_(const AllocateNode* op) final {
51 Stmt stmt = StmtMutator::VisitStmt_(op);
52 op = stmt.as<AllocateNode>();
53 if (attr_node_.defined()) {
54 Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body);
55 // undefine them
56 attr_node_ = ObjectRef();
57 attr_value_ = PrimExpr();
58 return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
59 } else {
60 return stmt;
61 }
62 }
63
64 Stmt VisitStmt_(const AttrStmtNode* op) final {
65 if (op->attr_key == attr_key_) {
66 attr_node_ = op->node;
67 attr_value_ = op->value;
68 return op->body;
69 } else {
70 return StmtMutator::VisitStmt_(op);
71 }
72 }
73
74 Stmt VisitStmt_(const SeqStmtNode* op) final {
75 // remember the decorations.
76 std::vector<ObjectRef> attr_node;
77 std::vector<PrimExpr> attr_value;
78
79 auto fmutate = [&](const Stmt& s) {
80 attr_node_ = ObjectRef();
81 attr_value_ = PrimExpr();
82 Stmt ret = this->VisitStmt(s);
83 attr_node.push_back(attr_node_);
84 attr_value.push_back(attr_value_);
85 return ret;
86 };
87 Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
88 if (attr_node.size() == 0) return ret;
89
90 op = ret.as<SeqStmtNode>();
91 ICHECK(op != nullptr);
92 Array<Stmt> reorg;
93 // check if all decorations are common.
94 for (size_t begin = 0; begin < attr_node.size();) {
95 size_t end = begin + 1;
96 while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) &&
97 ValueSame(attr_value[end], attr_value[begin])) {
98 ++end;
99 }
100 // covers everything
101 // lift attr to parent.
102 if (begin == 0 && end == attr_node.size()) {
103 attr_node_ = attr_node[0];
104 attr_value_ = attr_value[0];
105 return ret;
106 }
107 // construct subsegments.
108 Array<Stmt> seq;
109 for (size_t i = begin; i < end; ++i) {
110 seq.push_back(op->seq[i]);
111 }
112 Stmt stmt = SeqStmt::Flatten(seq);
113 if (attr_node[begin].defined()) {
114 stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt);
115 }
116 reorg.push_back(stmt);
117 begin = end;
118 }
119 attr_node_ = ObjectRef();
120 attr_value_ = PrimExpr();
121 return SeqStmt::Flatten(reorg);
122 }
123
124 Stmt VisitStmt_(const IfThenElseNode* op) final {
125 if (!op->else_case) {
126 return StmtMutator::VisitStmt_(op);
127 }
128 Stmt then_case = this->VisitStmt(op->then_case);
129 ObjectRef first_node;
130 PrimExpr first_value;
131 std::swap(first_node, attr_node_);
132 std::swap(first_value, attr_value_);
133 Stmt else_case = this->VisitStmt(op->else_case.value());
134 if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
135 first_value.defined() && attr_node_.same_as(first_node) &&
136 ValueSame(attr_value_, first_value)) {
137 if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
138 return GetRef<Stmt>(op);
139 } else {
140 return IfThenElse(op->condition, then_case, else_case);
141 }
142 } else {
143 if (first_node.defined()) {
144 then_case = AttrStmt(first_node, attr_key_, first_value, then_case);
145 }
146 if (attr_node_.defined()) {
147 else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case);
148 // undefine them
149 attr_node_ = ObjectRef();
150 attr_value_ = PrimExpr();
151 }
152 if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
153 return GetRef<Stmt>(op);
154 } else {
155 return IfThenElse(op->condition, then_case, else_case);
156 }
157 }
158 }
159
160 Stmt VisitStmt_(const WhileNode* op) final {
161 // TODO(masahi): Do we need a special handling for While nodes?
162 LOG(FATAL) << "WhileNode not supported in LiftAttrScope.";
163 }
164
165 private:
166 // value comparison that also compares content of int constant
167 static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
168 if (a.same_as(b)) return true;
169 if (!a.defined() || !b.defined()) return false;
170 if (a->type_index() != b->type_index()) return false;
171 if (a.dtype() != b.dtype()) return false;
172 if (const IntImmNode* op = a.as<IntImmNode>()) {
173 return op->value == b.as<IntImmNode>()->value;
174 }
175 return false;
176 }
177
178 std::string attr_key_;
179 ObjectRef attr_node_;
180 PrimExpr attr_value_;
181};
182
183Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
184 return AttrScopeLifter(attr_key).Lift(std::move(stmt));
185}
186
187namespace transform {
188
189Pass LiftAttrScope(String attr_key) {
190 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
191 auto* n = f.CopyOnWrite();
192 n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
193 return f;
194 };
195 return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
196}
197
198TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope);
199
200} // namespace transform
201
202} // namespace tir
203} // namespace tvm
204