1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file narrow_datatype.cc
22 * \brief narrow the datatype of indexing vars
23 */
24
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/data_type_rewriter.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/transform.h>
30
31#include "../../arith/ir_mutator_with_analyzer.h"
32#include "../../arith/ir_visitor_with_analyzer.h"
33
34namespace tvm {
35namespace tir {
36
37// This pass narrows indexing expressions (like StoreNode::Index)
38// that trivially fit into i32/i16 (denoted by `target_bits_`) to
39// i32/i16. Considering that i32/i16 indices may be more
40// efficient on some backends (while i64 may be more efficient
41// on others, like llvm), we may want this pass when i32/i16
42// indices are more efficient.
43//
44// For Var v, we determine its dtype by examining all the PrimExpr
45// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
46// If all expressions in E fit into i32/i16, then we think v can be narrowed
47// to i32/i16.
48//
49// To make an indexing expression i32/i16, we must make sure that every
50// component of that expression is of dtype i32/i16. So besides Var, we
51// rewrite the following inside an indexing expression
52// - Var
53// - IntImm
54// - Cast
55//
56// Algorithm:
57// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewriter to rewrite the components of an indexing expression.
59
60using arith::Analyzer;
61using arith::ConstIntBound;
62using arith::IRMutatorWithAnalyzer;
63
64// Determine the result dtype for Var, IntImm and Cast,
65// which will be stored in `vmap` eventually.
66//
67// Algorithm:
// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
69// To be more specific, if for each Expr `e` which contains `var`
70// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`,
71// then we narrow `var` into `target_bits_`. That is,
72// `vmap[var] = min(target_bits_, var.dtype.bits())`
73// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
74class DataTypeVisitor final : public StmtExprVisitor {
75 public:
76 explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {}
77
78 void VisitExpr(const PrimExpr& e) {
79 if (e.dtype().is_int()) {
80 int bits = max_bits_;
81 if (bound_.find(e) == bound_.end()) {
82 analyzer_.const_int_bound(e, &bound_);
83 }
84 ConstIntBound bound = bound_[e];
85 int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
86 int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
87 if (e.dtype().bits() <= target_bits_ ||
88 (bound->max_value <= ubound && bound->min_value >= lbound)) {
89 bits = target_bits_;
90 }
91 int tmp = bits > bits_ ? bits : bits_;
92 std::swap(bits_, tmp);
93 StmtExprVisitor::VisitExpr(e);
94 std::swap(bits_, tmp);
95 } else {
96 StmtExprVisitor::VisitExpr(e);
97 }
98 }
99
100 void VisitStmt_(const ForNode* op) {
101 analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
102 vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
103 return StmtExprVisitor::VisitStmt_(op);
104 }
105
106 void VisitStmt_(const BlockNode* op) {
107 for (const IterVar& iter : op->iter_vars) {
108 analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
109 vextent_[iter->var.as<VarNode>()] = iter->dom->extent.dtype();
110 }
111 StmtExprVisitor::VisitStmt_(op);
112 }
113
114 void VisitStmt_(const AttrStmtNode* op) {
115 if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
116 IterVar iv = Downcast<IterVar>(op->node);
117 ICHECK_NE(iv->thread_tag.length(), 0U);
118 analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
119 vextent_[iv->var.as<VarNode>()] = op->value.dtype();
120 StmtExprVisitor::VisitStmt_(op);
121 } else {
122 StmtExprVisitor::VisitStmt_(op);
123 }
124 }
125
126 void VisitExpr_(const ReduceNode* op) {
127 // Setup the domain information before simplification.
128 for (const IterVar& iv : op->axis) {
129 analyzer_.Bind(iv->var, iv->dom);
130 vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
131 }
132 // Recursively call simplification when necessary.
133 StmtExprVisitor::VisitExpr_(op);
134 }
135
136 void VisitExpr_(const VarNode* op) {
137 if (vextent_.find(op) != vextent_.end()) {
138 // We only narrow and never promote, so the result dtype
139 // is upperbounded by its original dtype before rewrite.
140 int bits = std::min(vextent_[op].bits(), bits_);
141 if (vmap.find(op) == vmap.end()) {
142 vmap[op] = op->dtype.with_bits(bits);
143 } else {
144 // We take maximum bits for all the possible Expr where a var occurs
145 vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
146 }
147 }
148 StmtExprVisitor::VisitExpr_(op);
149 }
150
151 void VisitExpr_(const IntImmNode* op) {
152 if (op->dtype.is_int()) {
153 // We only narrow and never promote, so the result dtype
154 // is upperbounded by its original dtype before rewrite.
155 int bits = std::min(op->dtype.bits(), bits_);
156 if (vmap.find(op) == vmap.end()) {
157 vmap[op] = op->dtype.with_bits(bits);
158 } else {
159 vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
160 }
161 }
162 StmtExprVisitor::VisitExpr_(op);
163 }
164
165 void VisitExpr_(const CastNode* op) {
166 if (op->dtype.is_int()) {
167 // We only narrow and never promote, so the result dtype
168 // is upperbounded by its original dtype before rewrite.
169 int bits = std::min(op->dtype.bits(), bits_);
170 if (vmap.find(op) == vmap.end()) {
171 vmap[op] = op->dtype.with_bits(bits);
172 } else {
173 vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
174 }
175 }
176 StmtExprVisitor::VisitExpr_(op);
177 }
178
179 // the narrowed datatype of Var and IntImm
180 std::unordered_map<const PrimExprNode*, DataType> vmap;
181
182 protected:
183 // internal analyzer
184 arith::Analyzer analyzer_;
185
186 private:
187 // the maximum possible bits, which serves as an init value
188 static constexpr const int max_bits_ = 64;
189 // the maximum possible bit of the current expression's return dtype
190 int bits_;
191 // the target bits
192 int target_bits_;
193 // the extent of vars to be rewritten
194 std::unordered_map<const VarNode*, DataType> vextent_;
195 // the memorized bound generated by ConstIntBoundAnalyzer
196 arith::ConstIntBoundAnalyzer::BoundMapType bound_;
197};
198
199class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
200 public:
201 using Parent = IndexDataTypeRewriter;
202 explicit NarrowDataTypeRewriter(int target_bits) : visitor_(target_bits) {}
203
204 Stmt operator()(Stmt s) {
205 visitor_(s);
206 for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
207 PrimExpr e = GetRef<PrimExpr>(i->first);
208 if (e.dtype() == i->second) {
209 i = visitor_.vmap.erase(i);
210 } else {
211 ++i;
212 }
213 }
214 return VisitStmt(s);
215 }
216
217 protected:
218 // This class adds some overrides of `VisitStmt_` and `VisitExpr_` that
219 // are *not* present in the parent class.
220 // These `using` statements ensure that all of the *other* overrides
221 // provided by the parent class are fully visible to users of this class.
222 // (Discussed further in https://github.com/apache/tvm/pull/13267)
223 using Parent::VisitExpr_;
224 using Parent::VisitStmt_;
225
226 Stmt VisitStmt_(const StoreNode* op) final {
227 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
228 }
229
230 PrimExpr VisitExpr_(const LoadNode* op) final {
231 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
232 }
233
234 PrimExpr VisitExpr_(const VarNode* op) final {
235 if (auto it = visitor_.vmap.find(op); !var_remap_.count(op) && it != visitor_.vmap.end()) {
236 var_remap_[op] = Var(op->name_hint, it->second);
237 }
238 return Parent::VisitExpr_(op);
239 }
240
241 PrimExpr VisitExpr_(const IntImmNode* op) final {
242 if (is_enabled_) {
243 if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
244 return IntImm(visitor_.vmap[op], op->value);
245 }
246 }
247 return Parent::VisitExpr_(op);
248 }
249
250 PrimExpr VisitExpr_(const CastNode* op) final {
251 if (is_enabled_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
252 PrimExpr e = Parent::VisitExpr_(op);
253 const CastNode* new_op = e.as<CastNode>();
254 ICHECK(new_op != nullptr) << "Expected type to be CastNode"
255 << ", but get " << e->GetTypeKey();
256 return Cast(visitor_.vmap[op], new_op->value);
257 }
258 return Parent::VisitExpr_(op);
259 }
260
261 private:
262 // the internal visitor to deduce the narrowed dtype
263 DataTypeVisitor visitor_;
264};
265
266Stmt NarrowDataType(Stmt stmt, int target_bits) {
267 return NarrowDataTypeRewriter(target_bits)(stmt);
268}
269
270namespace transform {
271
272Pass NarrowDataType(int target_bits) {
273 auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
274 auto* n = f.CopyOnWrite();
275 n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body));
276 return f;
277 };
278 return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {});
279}
280
// Expose the pass to the FFI under the name `tir.transform.NarrowDataType`.
TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType);
282
283} // namespace transform
284} // namespace tir
285} // namespace tvm
286