1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file narrow_datatype.cc |
22 | * \brief narrow the datatype of indexing vars |
23 | */ |
24 | |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/data_type_rewriter.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include "../../arith/ir_mutator_with_analyzer.h" |
32 | #include "../../arith/ir_visitor_with_analyzer.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
// This pass narrows indexing expressions (like BufferLoad/BufferStore indices)
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// i32/i16. Considering that i32/i16 indices may be more
// efficient on some backends (while i64 may be more efficient
// on others, like llvm), we may want this pass when i32/i16
// indices are more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExprs
// that contain v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32/i16, then we think v can be narrowed
// to i32/i16.
//
// To make an indexing expression i32/i16, we must make sure that every
// component of that expression is of dtype i32/i16. So we rewrite the
// following node kinds inside an indexing expression:
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use NarrowDataTypeRewriter to rewrite the components of an indexing expression.
59 | |
60 | using arith::Analyzer; |
61 | using arith::ConstIntBound; |
62 | using arith::IRMutatorWithAnalyzer; |
63 | |
64 | // Determine the result dtype for Var, IntImm and Cast, |
65 | // which will be stored in `vmap` eventually. |
66 | // |
67 | // Algorithm: |
68 | // We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`. |
69 | // To be more specific, if for each Expr `e` which contains `var` |
70 | // (`var` is a child node of `e` in AST), `e` fits into `target_bits_`, |
71 | // then we narrow `var` into `target_bits_`. That is, |
72 | // `vmap[var] = min(target_bits_, var.dtype.bits())` |
73 | // Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` |
74 | class DataTypeVisitor final : public StmtExprVisitor { |
75 | public: |
76 | explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} |
77 | |
78 | void VisitExpr(const PrimExpr& e) { |
79 | if (e.dtype().is_int()) { |
80 | int bits = max_bits_; |
81 | if (bound_.find(e) == bound_.end()) { |
82 | analyzer_.const_int_bound(e, &bound_); |
83 | } |
84 | ConstIntBound bound = bound_[e]; |
85 | int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value; |
86 | int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value; |
87 | if (e.dtype().bits() <= target_bits_ || |
88 | (bound->max_value <= ubound && bound->min_value >= lbound)) { |
89 | bits = target_bits_; |
90 | } |
91 | int tmp = bits > bits_ ? bits : bits_; |
92 | std::swap(bits_, tmp); |
93 | StmtExprVisitor::VisitExpr(e); |
94 | std::swap(bits_, tmp); |
95 | } else { |
96 | StmtExprVisitor::VisitExpr(e); |
97 | } |
98 | } |
99 | |
100 | void VisitStmt_(const ForNode* op) { |
101 | analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
102 | vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype(); |
103 | return StmtExprVisitor::VisitStmt_(op); |
104 | } |
105 | |
106 | void VisitStmt_(const BlockNode* op) { |
107 | for (const IterVar& iter : op->iter_vars) { |
108 | analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); |
109 | vextent_[iter->var.as<VarNode>()] = iter->dom->extent.dtype(); |
110 | } |
111 | StmtExprVisitor::VisitStmt_(op); |
112 | } |
113 | |
114 | void VisitStmt_(const AttrStmtNode* op) { |
115 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { |
116 | IterVar iv = Downcast<IterVar>(op->node); |
117 | ICHECK_NE(iv->thread_tag.length(), 0U); |
118 | analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); |
119 | vextent_[iv->var.as<VarNode>()] = op->value.dtype(); |
120 | StmtExprVisitor::VisitStmt_(op); |
121 | } else { |
122 | StmtExprVisitor::VisitStmt_(op); |
123 | } |
124 | } |
125 | |
126 | void VisitExpr_(const ReduceNode* op) { |
127 | // Setup the domain information before simplification. |
128 | for (const IterVar& iv : op->axis) { |
129 | analyzer_.Bind(iv->var, iv->dom); |
130 | vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype(); |
131 | } |
132 | // Recursively call simplification when necessary. |
133 | StmtExprVisitor::VisitExpr_(op); |
134 | } |
135 | |
136 | void VisitExpr_(const VarNode* op) { |
137 | if (vextent_.find(op) != vextent_.end()) { |
138 | // We only narrow and never promote, so the result dtype |
139 | // is upperbounded by its original dtype before rewrite. |
140 | int bits = std::min(vextent_[op].bits(), bits_); |
141 | if (vmap.find(op) == vmap.end()) { |
142 | vmap[op] = op->dtype.with_bits(bits); |
143 | } else { |
144 | // We take maximum bits for all the possible Expr where a var occurs |
145 | vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); |
146 | } |
147 | } |
148 | StmtExprVisitor::VisitExpr_(op); |
149 | } |
150 | |
151 | void VisitExpr_(const IntImmNode* op) { |
152 | if (op->dtype.is_int()) { |
153 | // We only narrow and never promote, so the result dtype |
154 | // is upperbounded by its original dtype before rewrite. |
155 | int bits = std::min(op->dtype.bits(), bits_); |
156 | if (vmap.find(op) == vmap.end()) { |
157 | vmap[op] = op->dtype.with_bits(bits); |
158 | } else { |
159 | vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); |
160 | } |
161 | } |
162 | StmtExprVisitor::VisitExpr_(op); |
163 | } |
164 | |
165 | void VisitExpr_(const CastNode* op) { |
166 | if (op->dtype.is_int()) { |
167 | // We only narrow and never promote, so the result dtype |
168 | // is upperbounded by its original dtype before rewrite. |
169 | int bits = std::min(op->dtype.bits(), bits_); |
170 | if (vmap.find(op) == vmap.end()) { |
171 | vmap[op] = op->dtype.with_bits(bits); |
172 | } else { |
173 | vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); |
174 | } |
175 | } |
176 | StmtExprVisitor::VisitExpr_(op); |
177 | } |
178 | |
179 | // the narrowed datatype of Var and IntImm |
180 | std::unordered_map<const PrimExprNode*, DataType> vmap; |
181 | |
182 | protected: |
183 | // internal analyzer |
184 | arith::Analyzer analyzer_; |
185 | |
186 | private: |
187 | // the maximum possible bits, which serves as an init value |
188 | static constexpr const int max_bits_ = 64; |
189 | // the maximum possible bit of the current expression's return dtype |
190 | int bits_; |
191 | // the target bits |
192 | int target_bits_; |
193 | // the extent of vars to be rewritten |
194 | std::unordered_map<const VarNode*, DataType> vextent_; |
195 | // the memorized bound generated by ConstIntBoundAnalyzer |
196 | arith::ConstIntBoundAnalyzer::BoundMapType bound_; |
197 | }; |
198 | |
199 | class NarrowDataTypeRewriter : public IndexDataTypeRewriter { |
200 | public: |
201 | using Parent = IndexDataTypeRewriter; |
202 | explicit NarrowDataTypeRewriter(int target_bits) : visitor_(target_bits) {} |
203 | |
204 | Stmt operator()(Stmt s) { |
205 | visitor_(s); |
206 | for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { |
207 | PrimExpr e = GetRef<PrimExpr>(i->first); |
208 | if (e.dtype() == i->second) { |
209 | i = visitor_.vmap.erase(i); |
210 | } else { |
211 | ++i; |
212 | } |
213 | } |
214 | return VisitStmt(s); |
215 | } |
216 | |
217 | protected: |
218 | // This class adds some overrides of `VisitStmt_` and `VisitExpr_` that |
219 | // are *not* present in the parent class. |
220 | // These `using` statements ensure that all of the *other* overrides |
221 | // provided by the parent class are fully visible to users of this class. |
222 | // (Discussed further in https://github.com/apache/tvm/pull/13267) |
223 | using Parent::VisitExpr_; |
224 | using Parent::VisitStmt_; |
225 | |
226 | Stmt VisitStmt_(const StoreNode* op) final { |
227 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
228 | } |
229 | |
230 | PrimExpr VisitExpr_(const LoadNode* op) final { |
231 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
232 | } |
233 | |
234 | PrimExpr VisitExpr_(const VarNode* op) final { |
235 | if (auto it = visitor_.vmap.find(op); !var_remap_.count(op) && it != visitor_.vmap.end()) { |
236 | var_remap_[op] = Var(op->name_hint, it->second); |
237 | } |
238 | return Parent::VisitExpr_(op); |
239 | } |
240 | |
241 | PrimExpr VisitExpr_(const IntImmNode* op) final { |
242 | if (is_enabled_) { |
243 | if (visitor_.vmap.find(op) != visitor_.vmap.end()) { |
244 | return IntImm(visitor_.vmap[op], op->value); |
245 | } |
246 | } |
247 | return Parent::VisitExpr_(op); |
248 | } |
249 | |
250 | PrimExpr VisitExpr_(const CastNode* op) final { |
251 | if (is_enabled_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { |
252 | PrimExpr e = Parent::VisitExpr_(op); |
253 | const CastNode* new_op = e.as<CastNode>(); |
254 | ICHECK(new_op != nullptr) << "Expected type to be CastNode" |
255 | << ", but get " << e->GetTypeKey(); |
256 | return Cast(visitor_.vmap[op], new_op->value); |
257 | } |
258 | return Parent::VisitExpr_(op); |
259 | } |
260 | |
261 | private: |
262 | // the internal visitor to deduce the narrowed dtype |
263 | DataTypeVisitor visitor_; |
264 | }; |
265 | |
266 | Stmt NarrowDataType(Stmt stmt, int target_bits) { |
267 | return NarrowDataTypeRewriter(target_bits)(stmt); |
268 | } |
269 | |
270 | namespace transform { |
271 | |
272 | Pass NarrowDataType(int target_bits) { |
273 | auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) { |
274 | auto* n = f.CopyOnWrite(); |
275 | n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body)); |
276 | return f; |
277 | }; |
278 | return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType" , {}); |
279 | } |
280 | |
281 | TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType" ).set_body_typed(NarrowDataType); |
282 | |
283 | } // namespace transform |
284 | } // namespace tir |
285 | } // namespace tvm |
286 | |