1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
 * \file tvm/src/tir/transforms/lower_custom_datatypes.cc
21 | * \brief Pass for lowering custom datatypes |
22 | */ |
23 | |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/target/target.h> |
26 | #include <tvm/tir/op.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include "../../target/datatype/registry.h" |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | /*! |
36 | * \brief Helper mutator to implement lowering of custom datatypes. |
37 | * |
38 | * Lowering datatypes works as follows: for every expression containing a custom |
39 | * datatype, we search for a global (registered by the implementer of the custom |
40 | * datatype) for lowering this type of expression, and uses it to lower the |
41 | * expression. |
42 | */ |
43 | class CustomDatatypesLowerer : public StmtExprMutator { |
44 | public: |
45 | explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} |
46 | |
47 | PrimExpr VisitExpr_(const CastNode* op) final { |
48 | auto type_code = op->dtype.code(); |
49 | auto src_type_code = op->value.dtype().code(); |
50 | // If either datatype is a registered custom datatype, we must lower. |
51 | bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || |
52 | datatype::Registry::Global()->GetTypeRegistered(src_type_code); |
53 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
54 | if (to_be_lowered) { |
55 | auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); |
56 | ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type " |
57 | << static_cast<unsigned>(type_code) << " source type " |
58 | << static_cast<unsigned>(src_type_code) << " not found" ; |
59 | return (*lower)(expr); |
60 | } |
61 | return expr; |
62 | } |
63 | |
64 | PrimExpr VisitExpr_(const FloatImmNode* imm) final { |
65 | auto type_code = imm->dtype.code(); |
66 | auto e = GetRef<PrimExpr>(imm); |
67 | if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { |
68 | auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); |
69 | ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " |
70 | << static_cast<unsigned>(type_code) << " not found" ; |
71 | return (*lower)(e); |
72 | } |
73 | return e; |
74 | } |
75 | |
76 | PrimExpr VisitExpr_(const VarNode* op) final { |
77 | Var var = GetRef<Var>(op); |
78 | |
79 | auto itr = var_remap_.find(var); |
80 | if (itr != var_remap_.end()) { |
81 | return itr->second; |
82 | } else { |
83 | return std::move(var); |
84 | } |
85 | } |
86 | |
87 | Stmt VisitStmt_(const AllocateNode* allocate) final { |
88 | bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code()); |
89 | |
90 | if (to_be_lowered) { |
91 | auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); |
92 | auto new_buffer_var = |
93 | Var(allocate->buffer_var->name_hint, PointerType(PrimType(new_allocate_type))); |
94 | var_remap_[allocate->buffer_var] = new_buffer_var; |
95 | |
96 | Stmt stmt = StmtExprMutator::VisitStmt_(allocate); |
97 | allocate = stmt.as<AllocateNode>(); |
98 | |
99 | return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition, |
100 | allocate->body); |
101 | } else { |
102 | return StmtExprMutator::VisitStmt_(allocate); |
103 | } |
104 | } |
105 | |
106 | PrimExpr VisitExpr_(const LoadNode* op) final { |
107 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
108 | } |
109 | |
110 | Stmt VisitStmt_(const StoreNode* op) final { |
111 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
112 | } |
113 | |
114 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
115 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
116 | auto modified = VisitBufferAccess(node); |
117 | |
118 | // Not needed for BufferStoreNode, so we can't just call |
119 | // LegalizeDtype() in VisitBufferAccess. |
120 | if (node.same_as(modified)) { |
121 | return std::move(node); |
122 | } else { |
123 | auto writer = modified.CopyOnWrite(); |
124 | writer->LegalizeDType(); |
125 | return std::move(modified); |
126 | } |
127 | } |
128 | |
129 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
130 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
131 | return VisitBufferAccess(std::move(node)); |
132 | } |
133 | |
134 | template <typename Node> |
135 | Node VisitBufferAccess(Node node) { |
136 | Buffer new_buf = GetRemappedBuffer(node->buffer); |
137 | if (!new_buf.same_as(node->buffer)) { |
138 | auto writer = node.CopyOnWrite(); |
139 | writer->buffer = new_buf; |
140 | } |
141 | |
142 | return node; |
143 | } |
144 | |
145 | Buffer GetRemappedBuffer(Buffer buf) { |
146 | auto key = buf; |
147 | auto cache_it = buf_remap_.find(key); |
148 | if (cache_it != buf_remap_.end()) { |
149 | return cache_it->second; |
150 | } |
151 | |
152 | bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code()); |
153 | |
154 | if (to_be_lowered) { |
155 | auto new_load_type = DataType::UInt(buf->dtype.bits()); |
156 | auto writer = buf.CopyOnWrite(); |
157 | writer->dtype = new_load_type; |
158 | |
159 | auto var_it = var_remap_.find(buf->data); |
160 | if (var_it != var_remap_.end()) { |
161 | writer->data = var_it->second; |
162 | } |
163 | } |
164 | |
165 | buf_remap_[key] = buf; |
166 | return buf; |
167 | } |
168 | |
169 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
170 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
171 | op = ret.as<AttrStmtNode>(); |
172 | // Due to legacy reasons, some attr node can contain |
173 | // information(e.g. alignment) of buffer variables. |
174 | // remap these vars when needed |
175 | // TODO(tvm-team): remove the rewriting once the buffer var |
176 | // attrs are being refactored into the corresponding definition node |
177 | if (const auto* var_node = op->node.as<VarNode>()) { |
178 | auto it = var_remap_.find(GetRef<Var>(var_node)); |
179 | if (it != var_remap_.end()) { |
180 | return AttrStmt(it->second, op->attr_key, op->value, op->body); |
181 | } |
182 | } |
183 | return ret; |
184 | } |
185 | |
186 | PrimExpr VisitExpr_(const CallNode* call) final { |
187 | bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code()); |
188 | PrimExpr expr = StmtExprMutator::VisitExpr_(call); |
189 | call = expr.as<CallNode>(); |
190 | if (to_be_lowered) { |
191 | auto op = call->op.as<OpNode>(); |
192 | ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented" ; |
193 | auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code()); |
194 | ICHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name " |
195 | << op->name << ", type " << static_cast<unsigned>(call->dtype.code()) |
196 | << " not found" ; |
197 | return (*lower)(expr); |
198 | } |
199 | return expr; |
200 | } |
201 | |
202 | #define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \ |
203 | PrimExpr VisitExpr_(const NodeName* op) final { \ |
204 | auto type_code = op->dtype.code(); \ |
205 | bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ |
206 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ |
207 | op = expr.as<NodeName>(); \ |
208 | if (to_be_lowered) { \ |
209 | auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ |
210 | ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \ |
211 | << static_cast<unsigned>(type_code) << " not found"; \ |
212 | return (*lower)(expr); \ |
213 | } \ |
214 | return expr; \ |
215 | } |
216 | |
217 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode); |
218 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode); |
219 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode); |
220 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode); |
221 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode); |
222 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode); |
223 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode); |
224 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode); |
225 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode); |
226 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode); |
227 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode); |
228 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode); |
229 | TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode); |
230 | // Later changes may need to add more mutate functions as we support workloads with more ops. |
231 | |
232 | #undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE |
233 | |
234 | private: |
235 | std::string target_; |
236 | // remap buffer vars |
237 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; |
238 | std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_; |
239 | }; |
240 | |
241 | namespace transform { |
242 | |
243 | Pass LowerCustomDatatypes() { |
244 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
245 | auto* n = f.CopyOnWrite(); |
246 | auto target = f->GetAttr<Target>(tvm::attr::kTarget); |
247 | ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute" ; |
248 | |
249 | n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); |
250 | return f; |
251 | }; |
252 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes" , {}); |
253 | } |
254 | |
255 | TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes" ).set_body_typed(LowerCustomDatatypes); |
256 | |
257 | } // namespace transform |
258 | |
259 | } // namespace tir |
260 | } // namespace tvm |
261 | |