1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
 * \file tvm/src/tir/transforms/lower_custom_datatypes.cc
21 * \brief Pass for lowering custom datatypes
22 */
23
24#include <tvm/runtime/registry.h>
25#include <tvm/target/target.h>
26#include <tvm/tir/op.h>
27#include <tvm/tir/stmt_functor.h>
28#include <tvm/tir/transform.h>
29
30#include "../../target/datatype/registry.h"
31
32namespace tvm {
33namespace tir {
34
35/*!
36 * \brief Helper mutator to implement lowering of custom datatypes.
37 *
38 * Lowering datatypes works as follows: for every expression containing a custom
39 * datatype, we search for a global (registered by the implementer of the custom
40 * datatype) for lowering this type of expression, and uses it to lower the
41 * expression.
42 */
43class CustomDatatypesLowerer : public StmtExprMutator {
44 public:
45 explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
46
47 PrimExpr VisitExpr_(const CastNode* op) final {
48 auto type_code = op->dtype.code();
49 auto src_type_code = op->value.dtype().code();
50 // If either datatype is a registered custom datatype, we must lower.
51 bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
52 datatype::Registry::Global()->GetTypeRegistered(src_type_code);
53 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
54 if (to_be_lowered) {
55 auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
56 ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
57 << static_cast<unsigned>(type_code) << " source type "
58 << static_cast<unsigned>(src_type_code) << " not found";
59 return (*lower)(expr);
60 }
61 return expr;
62 }
63
64 PrimExpr VisitExpr_(const FloatImmNode* imm) final {
65 auto type_code = imm->dtype.code();
66 auto e = GetRef<PrimExpr>(imm);
67 if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
68 auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
69 ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
70 << static_cast<unsigned>(type_code) << " not found";
71 return (*lower)(e);
72 }
73 return e;
74 }
75
76 PrimExpr VisitExpr_(const VarNode* op) final {
77 Var var = GetRef<Var>(op);
78
79 auto itr = var_remap_.find(var);
80 if (itr != var_remap_.end()) {
81 return itr->second;
82 } else {
83 return std::move(var);
84 }
85 }
86
87 Stmt VisitStmt_(const AllocateNode* allocate) final {
88 bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
89
90 if (to_be_lowered) {
91 auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
92 auto new_buffer_var =
93 Var(allocate->buffer_var->name_hint, PointerType(PrimType(new_allocate_type)));
94 var_remap_[allocate->buffer_var] = new_buffer_var;
95
96 Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
97 allocate = stmt.as<AllocateNode>();
98
99 return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
100 allocate->body);
101 } else {
102 return StmtExprMutator::VisitStmt_(allocate);
103 }
104 }
105
106 PrimExpr VisitExpr_(const LoadNode* op) final {
107 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
108 }
109
110 Stmt VisitStmt_(const StoreNode* op) final {
111 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
112 }
113
114 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
115 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
116 auto modified = VisitBufferAccess(node);
117
118 // Not needed for BufferStoreNode, so we can't just call
119 // LegalizeDtype() in VisitBufferAccess.
120 if (node.same_as(modified)) {
121 return std::move(node);
122 } else {
123 auto writer = modified.CopyOnWrite();
124 writer->LegalizeDType();
125 return std::move(modified);
126 }
127 }
128
129 Stmt VisitStmt_(const BufferStoreNode* op) final {
130 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
131 return VisitBufferAccess(std::move(node));
132 }
133
134 template <typename Node>
135 Node VisitBufferAccess(Node node) {
136 Buffer new_buf = GetRemappedBuffer(node->buffer);
137 if (!new_buf.same_as(node->buffer)) {
138 auto writer = node.CopyOnWrite();
139 writer->buffer = new_buf;
140 }
141
142 return node;
143 }
144
145 Buffer GetRemappedBuffer(Buffer buf) {
146 auto key = buf;
147 auto cache_it = buf_remap_.find(key);
148 if (cache_it != buf_remap_.end()) {
149 return cache_it->second;
150 }
151
152 bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code());
153
154 if (to_be_lowered) {
155 auto new_load_type = DataType::UInt(buf->dtype.bits());
156 auto writer = buf.CopyOnWrite();
157 writer->dtype = new_load_type;
158
159 auto var_it = var_remap_.find(buf->data);
160 if (var_it != var_remap_.end()) {
161 writer->data = var_it->second;
162 }
163 }
164
165 buf_remap_[key] = buf;
166 return buf;
167 }
168
169 Stmt VisitStmt_(const AttrStmtNode* op) final {
170 Stmt ret = StmtExprMutator::VisitStmt_(op);
171 op = ret.as<AttrStmtNode>();
172 // Due to legacy reasons, some attr node can contain
173 // information(e.g. alignment) of buffer variables.
174 // remap these vars when needed
175 // TODO(tvm-team): remove the rewriting once the buffer var
176 // attrs are being refactored into the corresponding definition node
177 if (const auto* var_node = op->node.as<VarNode>()) {
178 auto it = var_remap_.find(GetRef<Var>(var_node));
179 if (it != var_remap_.end()) {
180 return AttrStmt(it->second, op->attr_key, op->value, op->body);
181 }
182 }
183 return ret;
184 }
185
186 PrimExpr VisitExpr_(const CallNode* call) final {
187 bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
188 PrimExpr expr = StmtExprMutator::VisitExpr_(call);
189 call = expr.as<CallNode>();
190 if (to_be_lowered) {
191 auto op = call->op.as<OpNode>();
192 ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
193 auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
194 ICHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name "
195 << op->name << ", type " << static_cast<unsigned>(call->dtype.code())
196 << " not found";
197 return (*lower)(expr);
198 }
199 return expr;
200 }
201
202#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \
203 PrimExpr VisitExpr_(const NodeName* op) final { \
204 auto type_code = op->dtype.code(); \
205 bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
206 PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
207 op = expr.as<NodeName>(); \
208 if (to_be_lowered) { \
209 auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
210 ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \
211 << static_cast<unsigned>(type_code) << " not found"; \
212 return (*lower)(expr); \
213 } \
214 return expr; \
215 }
216
217 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
218 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
219 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
220 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
221 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
222 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
223 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
224 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
225 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
226 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
227 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
228 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
229 TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
230 // Later changes may need to add more mutate functions as we support workloads with more ops.
231
232#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
233
234 private:
235 std::string target_;
236 // remap buffer vars
237 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
238 std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
239};
240
241namespace transform {
242
243Pass LowerCustomDatatypes() {
244 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
245 auto* n = f.CopyOnWrite();
246 auto target = f->GetAttr<Target>(tvm::attr::kTarget);
247 ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute";
248
249 n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body));
250 return f;
251 };
252 return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
253}
254
255TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes);
256
257} // namespace transform
258
259} // namespace tir
260} // namespace tvm
261