1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \brief Lift specified AttrStmt scope to outer if |
23 | * the body contains the same scope. |
24 | * \file lift_attr_scope.cc |
25 | */ |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include "ir_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | // NOTE: this optimization can only be applied |
36 | // to a few specified attr keys |
37 | class AttrScopeLifter : public StmtMutator { |
38 | public: |
39 | explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} |
40 | |
41 | Stmt Lift(Stmt stmt) { |
42 | stmt = operator()(std::move(stmt)); |
43 | if (attr_node_.defined()) { |
44 | stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt); |
45 | } |
46 | return stmt; |
47 | } |
48 | |
49 | // do not go beyond |
50 | Stmt VisitStmt_(const AllocateNode* op) final { |
51 | Stmt stmt = StmtMutator::VisitStmt_(op); |
52 | op = stmt.as<AllocateNode>(); |
53 | if (attr_node_.defined()) { |
54 | Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body); |
55 | // undefine them |
56 | attr_node_ = ObjectRef(); |
57 | attr_value_ = PrimExpr(); |
58 | return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); |
59 | } else { |
60 | return stmt; |
61 | } |
62 | } |
63 | |
64 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
65 | if (op->attr_key == attr_key_) { |
66 | attr_node_ = op->node; |
67 | attr_value_ = op->value; |
68 | return op->body; |
69 | } else { |
70 | return StmtMutator::VisitStmt_(op); |
71 | } |
72 | } |
73 | |
74 | Stmt VisitStmt_(const SeqStmtNode* op) final { |
75 | // remember the decorations. |
76 | std::vector<ObjectRef> attr_node; |
77 | std::vector<PrimExpr> attr_value; |
78 | |
79 | auto fmutate = [&](const Stmt& s) { |
80 | attr_node_ = ObjectRef(); |
81 | attr_value_ = PrimExpr(); |
82 | Stmt ret = this->VisitStmt(s); |
83 | attr_node.push_back(attr_node_); |
84 | attr_value.push_back(attr_value_); |
85 | return ret; |
86 | }; |
87 | Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate); |
88 | if (attr_node.size() == 0) return ret; |
89 | |
90 | op = ret.as<SeqStmtNode>(); |
91 | ICHECK(op != nullptr); |
92 | Array<Stmt> reorg; |
93 | // check if all decorations are common. |
94 | for (size_t begin = 0; begin < attr_node.size();) { |
95 | size_t end = begin + 1; |
96 | while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && |
97 | ValueSame(attr_value[end], attr_value[begin])) { |
98 | ++end; |
99 | } |
100 | // covers everything |
101 | // lift attr to parent. |
102 | if (begin == 0 && end == attr_node.size()) { |
103 | attr_node_ = attr_node[0]; |
104 | attr_value_ = attr_value[0]; |
105 | return ret; |
106 | } |
107 | // construct subsegments. |
108 | Array<Stmt> seq; |
109 | for (size_t i = begin; i < end; ++i) { |
110 | seq.push_back(op->seq[i]); |
111 | } |
112 | Stmt stmt = SeqStmt::Flatten(seq); |
113 | if (attr_node[begin].defined()) { |
114 | stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt); |
115 | } |
116 | reorg.push_back(stmt); |
117 | begin = end; |
118 | } |
119 | attr_node_ = ObjectRef(); |
120 | attr_value_ = PrimExpr(); |
121 | return SeqStmt::Flatten(reorg); |
122 | } |
123 | |
124 | Stmt VisitStmt_(const IfThenElseNode* op) final { |
125 | if (!op->else_case) { |
126 | return StmtMutator::VisitStmt_(op); |
127 | } |
128 | Stmt then_case = this->VisitStmt(op->then_case); |
129 | ObjectRef first_node; |
130 | PrimExpr first_value; |
131 | std::swap(first_node, attr_node_); |
132 | std::swap(first_value, attr_value_); |
133 | Stmt else_case = this->VisitStmt(op->else_case.value()); |
134 | if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && |
135 | first_value.defined() && attr_node_.same_as(first_node) && |
136 | ValueSame(attr_value_, first_value)) { |
137 | if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { |
138 | return GetRef<Stmt>(op); |
139 | } else { |
140 | return IfThenElse(op->condition, then_case, else_case); |
141 | } |
142 | } else { |
143 | if (first_node.defined()) { |
144 | then_case = AttrStmt(first_node, attr_key_, first_value, then_case); |
145 | } |
146 | if (attr_node_.defined()) { |
147 | else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case); |
148 | // undefine them |
149 | attr_node_ = ObjectRef(); |
150 | attr_value_ = PrimExpr(); |
151 | } |
152 | if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { |
153 | return GetRef<Stmt>(op); |
154 | } else { |
155 | return IfThenElse(op->condition, then_case, else_case); |
156 | } |
157 | } |
158 | } |
159 | |
160 | Stmt VisitStmt_(const WhileNode* op) final { |
161 | // TODO(masahi): Do we need a special handling for While nodes? |
162 | LOG(FATAL) << "WhileNode not supported in LiftAttrScope." ; |
163 | } |
164 | |
165 | private: |
166 | // value comparison that also compares content of int constant |
167 | static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { |
168 | if (a.same_as(b)) return true; |
169 | if (!a.defined() || !b.defined()) return false; |
170 | if (a->type_index() != b->type_index()) return false; |
171 | if (a.dtype() != b.dtype()) return false; |
172 | if (const IntImmNode* op = a.as<IntImmNode>()) { |
173 | return op->value == b.as<IntImmNode>()->value; |
174 | } |
175 | return false; |
176 | } |
177 | |
178 | std::string attr_key_; |
179 | ObjectRef attr_node_; |
180 | PrimExpr attr_value_; |
181 | }; |
182 | |
183 | Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { |
184 | return AttrScopeLifter(attr_key).Lift(std::move(stmt)); |
185 | } |
186 | |
187 | namespace transform { |
188 | |
189 | Pass LiftAttrScope(String attr_key) { |
190 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
191 | auto* n = f.CopyOnWrite(); |
192 | n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); |
193 | return f; |
194 | }; |
195 | return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope" , {}); |
196 | } |
197 | |
198 | TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope" ).set_body_typed(LiftAttrScope); |
199 | |
200 | } // namespace transform |
201 | |
202 | } // namespace tir |
203 | } // namespace tvm |
204 | |