1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Loop unrolling as in Halide pipeline. |
22 | * \file unroll_loop.cc |
23 | */ |
24 | // Unrolls the loop as in Halide pipeline. |
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ir_utils.h"
37 | |
38 | namespace tvm { |
39 | namespace tir { |
40 | |
/*!
 * \brief Attributes configuring the loop-unrolling pass.
 *
 * Registered as the "tir.UnrollLoop" pass-config option; when the option is
 * absent from the PassContext, the defaults declared here are used.
 */
struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
  // Loops are auto-unrolled when (extent * steps in the body) does not exceed this.
  int auto_max_step;
  // Upper bound on the number of already-unrolled nested loops below a loop
  // for it to still be auto-unrolled.
  int auto_max_depth;
  // Loops whose constant extent does not exceed this are auto-unrolled
  // regardless of step count; a value of 1 forces unrolling of extent-0/1 loops.
  int auto_max_extent;
  // Stored as int but used as a boolean (default true): when set, auto-unrolled
  // loops are expanded in the IR instead of being marked ForKind::kUnrolled.
  int explicit_unroll;

  TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
    TVM_ATTR_FIELD(auto_max_step)
        .describe("Threshold of number of steps in the loop to be automatically unrolled")
        .set_default(0);
    TVM_ATTR_FIELD(auto_max_depth)
        .describe("The maximum nested level of loops that can be automatically unrolled.")
        .set_default(8);
    TVM_ATTR_FIELD(auto_max_extent)
        .describe("The maximum extent of loop that will be unrolled.")
        .set_default(0);
    TVM_ATTR_FIELD(explicit_unroll)
        .describe("Whether to explicitly unroll the loop instead of setting a pragma")
        .set_default(true);
  }
};
62 | |
/*!
 * \brief Managed reference to UnrollLoopConfigNode.
 *  Declared not-nullable through TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS.
 */
class UnrollLoopConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
};
67 | |
// Register the attrs node type, and expose UnrollLoopConfig as the
// "tir.UnrollLoop" option that PassContext configs may carry.
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
70 | |
71 | class LoopUnroller : public StmtExprMutator { |
72 | public: |
73 | explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, |
74 | bool explicit_unroll) |
75 | : auto_max_step_(auto_max_step), |
76 | auto_max_depth_(auto_max_depth), |
77 | auto_max_extent_(auto_max_extent), |
78 | explicit_unroll_(explicit_unroll) {} |
79 | |
80 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
81 | if (op->attr_key == "pragma_auto_unroll_max_step" ) { |
82 | int value = static_cast<int>(Downcast<Integer>(op->value)->value); |
83 | std::swap(value, auto_max_step_); |
84 | Stmt ret = this->VisitStmt(op->body); |
85 | std::swap(value, auto_max_step_); |
86 | return ret; |
87 | } else if (op->attr_key == "pragma_unroll_explicit" ) { |
88 | bool explicit_unroll = Downcast<Integer>(op->value)->value; |
89 | std::swap(explicit_unroll, explicit_unroll_); |
90 | Stmt ret = this->VisitStmt(op->body); |
91 | std::swap(explicit_unroll, explicit_unroll_); |
92 | return ret; |
93 | } else { |
94 | return StmtExprMutator::VisitStmt_(op); |
95 | } |
96 | } |
97 | |
98 | Stmt VisitStmt_(const ForNode* op) { |
99 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
100 | op = stmt.as<ForNode>(); |
101 | int value = GetExtent(op); |
102 | // condition for auto unroll |
103 | bool auto_unroll = (op->kind == ForKind::kSerial && value >= 0 && normal_loop_depth_ == 0 && |
104 | unroll_depth_ <= auto_max_depth_); |
105 | |
106 | auto_unroll = |
107 | auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); |
108 | |
109 | if (op->kind == ForKind::kUnrolled) { |
110 | ICHECK_GE(value, 0) << "Cannot unroll non-constant loop" ; |
111 | auto_unroll = true; |
112 | } |
113 | |
114 | if (auto_unroll) { |
115 | step_count_ *= value; |
116 | unroll_depth_ += 1; |
117 | } else { |
118 | normal_loop_depth_ += 1; |
119 | } |
120 | |
121 | if ((auto_unroll && explicit_unroll_) || |
122 | // unroll loops with extent = 1, no matter how many steps in body |
123 | (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) { |
124 | return Unroll(op); |
125 | } else { |
126 | if (auto_unroll) { |
127 | if (op->kind != ForKind::kUnrolled) { |
128 | return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, |
129 | op->thread_binding, op->annotations); |
130 | } |
131 | } |
132 | return stmt; |
133 | } |
134 | } |
135 | |
136 | Stmt VisitStmt_(const StoreNode* op) final { |
137 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
138 | } |
139 | |
140 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
141 | ++step_count_; |
142 | return StmtExprMutator::VisitStmt_(op); |
143 | } |
144 | |
145 | Stmt VisitStmt_(const EvaluateNode* op) final { |
146 | ++step_count_; |
147 | return StmtExprMutator::VisitStmt_(op); |
148 | } |
149 | |
150 | Stmt VisitStmt_(const SeqStmtNode* op) final { |
151 | auto fmutate = [this](const Stmt& s) { |
152 | int step_count = step_count_; |
153 | int unroll_depth = unroll_depth_; |
154 | int normal_loop_depth = normal_loop_depth_; |
155 | step_count_ = 0; |
156 | unroll_depth_ = 0; |
157 | normal_loop_depth_ = 0; |
158 | Stmt ret = this->VisitStmt(s); |
159 | step_count_ += step_count; |
160 | normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); |
161 | unroll_depth_ = std::max(unroll_depth_, unroll_depth); |
162 | return ret; |
163 | }; |
164 | return StmtMutator::VisitSeqStmt_(op, false, fmutate); |
165 | } |
166 | |
167 | Stmt Unroll(const ForNode* op) { |
168 | int value = GetExtent(op); |
169 | // For loop must have a constant integer extent |
170 | ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent" ; |
171 | if (value == 0) return Evaluate(0); |
172 | Stmt body = op->body; |
173 | Map<Var, PrimExpr> vmap; |
174 | Array<Stmt> unrolled; |
175 | for (int i = 0; i < value; ++i) { |
176 | vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); |
177 | Stmt step = Substitute(body, vmap); |
178 | unrolled.push_back(step); |
179 | } |
180 | return SeqStmt::Flatten(unrolled); |
181 | } |
182 | |
183 | private: |
184 | // returns the extent of the loop if it's a constant integer, otherwise return -1 |
185 | int GetExtent(const ForNode* op) { |
186 | // constant folding. |
187 | PrimExpr extent = analyzer_.Simplify(op->extent); |
188 | const IntImmNode* v1 = extent.as<IntImmNode>(); |
189 | int value = -1; |
190 | // integers that do not fit in int32_t are treated as symbolic, |
191 | // as it's impossible to unroll such large loops |
192 | if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) { |
193 | value = static_cast<int>(v1->value); |
194 | } |
195 | return value; |
196 | } |
197 | |
198 | // maximum number of step to perform auto unroll. |
199 | int auto_max_step_; |
200 | int auto_max_depth_; |
201 | // max extent of loop to auto unroll |
202 | // this not not count the total steps, only count the number of loops |
203 | int auto_max_extent_; |
204 | bool explicit_unroll_; |
205 | // Number of normal loops in scope |
206 | int normal_loop_depth_{0}; |
207 | // number of unrolled cases in current scope. |
208 | int unroll_depth_{0}; |
209 | // Number of total steps unrolled |
210 | int step_count_{0}; |
211 | // analyzer |
212 | arith::Analyzer analyzer_; |
213 | }; |
214 | |
215 | Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { |
216 | Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, |
217 | cfg->explicit_unroll)(stmt); |
218 | if (!ret.same_as(stmt)) { |
219 | return ConvertSSA(ret); |
220 | } else { |
221 | return ret; |
222 | } |
223 | } |
224 | |
225 | namespace transform { |
226 | |
227 | Pass UnrollLoop() { |
228 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
229 | auto* n = f.CopyOnWrite(); |
230 | auto cfg = ctx->GetConfig<UnrollLoopConfig>("tir.UnrollLoop" ); |
231 | if (!cfg.defined()) { |
232 | cfg = AttrsWithDefaultValues<UnrollLoopConfig>(); |
233 | } |
234 | n->body = UnrollLoop(std::move(f->body), cfg.value()); |
235 | return f; |
236 | }; |
237 | return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop" , {}); |
238 | } |
239 | |
// Expose the pass factory through the global function registry as
// "tir.transform.UnrollLoop".
TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop);
241 | |
242 | } // namespace transform |
243 | |
244 | } // namespace tir |
245 | } // namespace tvm |
246 | |