1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Loop unrolling as in Halide pipeline.
22 * \file unroll_loop.cc
23 */
24// Unrolls the loop as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir_utils.h"
37
38namespace tvm {
39namespace tir {
40
41struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
42 int auto_max_step;
43 int auto_max_depth;
44 int auto_max_extent;
45 int explicit_unroll;
46
47 TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
48 TVM_ATTR_FIELD(auto_max_step)
49 .describe("Threshold of number of steps in the loop to be automatically unrolled")
50 .set_default(0);
51 TVM_ATTR_FIELD(auto_max_depth)
52 .describe("The maximum nested level of loops that can be automatically unrolled.")
53 .set_default(8);
54 TVM_ATTR_FIELD(auto_max_extent)
55 .describe("The maximum extent of loop that will be unrolled.")
56 .set_default(0);
57 TVM_ATTR_FIELD(explicit_unroll)
58 .describe("Whether to explicitly unroll the loop instead of setting a pragma")
59 .set_default(true);
60 }
61};
62
63class UnrollLoopConfig : public Attrs {
64 public:
65 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
66};
67
68TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
69TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
70
71class LoopUnroller : public StmtExprMutator {
72 public:
73 explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
74 bool explicit_unroll)
75 : auto_max_step_(auto_max_step),
76 auto_max_depth_(auto_max_depth),
77 auto_max_extent_(auto_max_extent),
78 explicit_unroll_(explicit_unroll) {}
79
80 Stmt VisitStmt_(const AttrStmtNode* op) final {
81 if (op->attr_key == "pragma_auto_unroll_max_step") {
82 int value = static_cast<int>(Downcast<Integer>(op->value)->value);
83 std::swap(value, auto_max_step_);
84 Stmt ret = this->VisitStmt(op->body);
85 std::swap(value, auto_max_step_);
86 return ret;
87 } else if (op->attr_key == "pragma_unroll_explicit") {
88 bool explicit_unroll = Downcast<Integer>(op->value)->value;
89 std::swap(explicit_unroll, explicit_unroll_);
90 Stmt ret = this->VisitStmt(op->body);
91 std::swap(explicit_unroll, explicit_unroll_);
92 return ret;
93 } else {
94 return StmtExprMutator::VisitStmt_(op);
95 }
96 }
97
98 Stmt VisitStmt_(const ForNode* op) {
99 Stmt stmt = StmtExprMutator::VisitStmt_(op);
100 op = stmt.as<ForNode>();
101 int value = GetExtent(op);
102 // condition for auto unroll
103 bool auto_unroll = (op->kind == ForKind::kSerial && value >= 0 && normal_loop_depth_ == 0 &&
104 unroll_depth_ <= auto_max_depth_);
105
106 auto_unroll =
107 auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_);
108
109 if (op->kind == ForKind::kUnrolled) {
110 ICHECK_GE(value, 0) << "Cannot unroll non-constant loop";
111 auto_unroll = true;
112 }
113
114 if (auto_unroll) {
115 step_count_ *= value;
116 unroll_depth_ += 1;
117 } else {
118 normal_loop_depth_ += 1;
119 }
120
121 if ((auto_unroll && explicit_unroll_) ||
122 // unroll loops with extent = 1, no matter how many steps in body
123 (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
124 return Unroll(op);
125 } else {
126 if (auto_unroll) {
127 if (op->kind != ForKind::kUnrolled) {
128 return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
129 op->thread_binding, op->annotations);
130 }
131 }
132 return stmt;
133 }
134 }
135
136 Stmt VisitStmt_(const StoreNode* op) final {
137 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
138 }
139
140 Stmt VisitStmt_(const BufferStoreNode* op) final {
141 ++step_count_;
142 return StmtExprMutator::VisitStmt_(op);
143 }
144
145 Stmt VisitStmt_(const EvaluateNode* op) final {
146 ++step_count_;
147 return StmtExprMutator::VisitStmt_(op);
148 }
149
150 Stmt VisitStmt_(const SeqStmtNode* op) final {
151 auto fmutate = [this](const Stmt& s) {
152 int step_count = step_count_;
153 int unroll_depth = unroll_depth_;
154 int normal_loop_depth = normal_loop_depth_;
155 step_count_ = 0;
156 unroll_depth_ = 0;
157 normal_loop_depth_ = 0;
158 Stmt ret = this->VisitStmt(s);
159 step_count_ += step_count;
160 normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
161 unroll_depth_ = std::max(unroll_depth_, unroll_depth);
162 return ret;
163 };
164 return StmtMutator::VisitSeqStmt_(op, false, fmutate);
165 }
166
167 Stmt Unroll(const ForNode* op) {
168 int value = GetExtent(op);
169 // For loop must have a constant integer extent
170 ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
171 if (value == 0) return Evaluate(0);
172 Stmt body = op->body;
173 Map<Var, PrimExpr> vmap;
174 Array<Stmt> unrolled;
175 for (int i = 0; i < value; ++i) {
176 vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
177 Stmt step = Substitute(body, vmap);
178 unrolled.push_back(step);
179 }
180 return SeqStmt::Flatten(unrolled);
181 }
182
183 private:
184 // returns the extent of the loop if it's a constant integer, otherwise return -1
185 int GetExtent(const ForNode* op) {
186 // constant folding.
187 PrimExpr extent = analyzer_.Simplify(op->extent);
188 const IntImmNode* v1 = extent.as<IntImmNode>();
189 int value = -1;
190 // integers that do not fit in int32_t are treated as symbolic,
191 // as it's impossible to unroll such large loops
192 if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
193 value = static_cast<int>(v1->value);
194 }
195 return value;
196 }
197
198 // maximum number of step to perform auto unroll.
199 int auto_max_step_;
200 int auto_max_depth_;
201 // max extent of loop to auto unroll
202 // this not not count the total steps, only count the number of loops
203 int auto_max_extent_;
204 bool explicit_unroll_;
205 // Number of normal loops in scope
206 int normal_loop_depth_{0};
207 // number of unrolled cases in current scope.
208 int unroll_depth_{0};
209 // Number of total steps unrolled
210 int step_count_{0};
211 // analyzer
212 arith::Analyzer analyzer_;
213};
214
215Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
216 Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
217 cfg->explicit_unroll)(stmt);
218 if (!ret.same_as(stmt)) {
219 return ConvertSSA(ret);
220 } else {
221 return ret;
222 }
223}
224
225namespace transform {
226
227Pass UnrollLoop() {
228 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
229 auto* n = f.CopyOnWrite();
230 auto cfg = ctx->GetConfig<UnrollLoopConfig>("tir.UnrollLoop");
231 if (!cfg.defined()) {
232 cfg = AttrsWithDefaultValues<UnrollLoopConfig>();
233 }
234 n->body = UnrollLoop(std::move(f->body), cfg.value());
235 return f;
236 };
237 return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
238}
239
240TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop);
241
242} // namespace transform
243
244} // namespace tir
245} // namespace tvm
246