1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file make_packed_api.cc Lower PrimFunc to use the packed function API. |
22 | */ |
23 | #include <tvm/runtime/device_api.h> |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/target/target.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/buffer.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | #include <tvm/tir/transform.h> |
32 | |
33 | #include <unordered_set> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | #include "arg_binder.h" |
38 | #include "ir_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace tir { |
42 | |
/*!
 * \brief Name of the PrimFunc parameter that carries the device API context.
 *
 * In MakePackedAPI, a parameter with exactly this name is not unpacked from
 * the packed argument array; it is used as the function's resource_handle.
 */
static constexpr const char kDeviceContextVar[] = "device_api_context";
44 | |
45 | class ReturnRewriter : public StmtMutator { |
46 | public: |
47 | explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} |
48 | |
49 | Stmt VisitStmt_(const ForNode* node) override { |
50 | if (node->kind == ForKind::kParallel) in_parallel_ += 1; |
51 | Stmt ret = StmtMutator::VisitStmt_(node); |
52 | if (node->kind == ForKind::kParallel) in_parallel_ -= 1; |
53 | return ret; |
54 | } |
55 | |
56 | Stmt VisitStmt_(const EvaluateNode* node) override { |
57 | Stmt ret = StmtMutator::VisitStmt_(node); |
58 | const EvaluateNode* eval = ret.as<EvaluateNode>(); |
59 | ICHECK(eval); |
60 | if (const CallNode* call = eval->value.as<CallNode>()) { |
61 | if (call->op.same_as(builtin::ret())) { |
62 | ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope." ; |
63 | ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument." ; |
64 | ret = WriteToOut(call->args[0]); |
65 | } |
66 | } |
67 | return ret; |
68 | } |
69 | |
70 | private: |
71 | struct ConvertedInfo { |
72 | int tcode{-1}; |
73 | PrimExpr expr; |
74 | Buffer dummy_val_buffer; |
75 | Buffer dummy_tcode_buffer; |
76 | }; |
77 | |
78 | ConvertedInfo ConvertForFFI(PrimExpr val) { |
79 | ConvertedInfo info; |
80 | |
81 | // convert val's data type to FFI data type, return type code |
82 | DataType dtype = val.dtype(); |
83 | if (dtype.is_int() || dtype.is_uint()) { |
84 | info.tcode = kTVMArgInt; |
85 | info.expr = Cast(DataType::Int(64), val); |
86 | } else if (dtype.is_float()) { |
87 | info.tcode = kTVMArgFloat; |
88 | info.expr = Cast(DataType::Float(64), val); |
89 | } else if (dtype.is_void()) { |
90 | info.tcode = kTVMNullptr; |
91 | info.expr = val; |
92 | } else { |
93 | LOG(FATAL) << "data type " << dtype << " not supported yet" ; |
94 | } |
95 | |
96 | // If multiple return locations have the same data type, use the |
97 | // same dummy buffer declaration. |
98 | auto it = dummy_val_buffer_map_.find(info.tcode); |
99 | if (it != dummy_val_buffer_map_.end()) { |
100 | info.dummy_val_buffer = it->second; |
101 | } else { |
102 | info.dummy_val_buffer = Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0), |
103 | ret_var_->name_hint, 0, 0, kDefault); |
104 | dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer; |
105 | } |
106 | |
107 | // The tcode is always a 32-bit int, so we don't need to have a separate map. |
108 | if (!dummy_tcode_buffer_.defined()) { |
109 | dummy_tcode_buffer_ = Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0), |
110 | ret_tcode_->name_hint, 0, 0, kDefault); |
111 | } |
112 | info.dummy_tcode_buffer = dummy_tcode_buffer_; |
113 | |
114 | return info; |
115 | } |
116 | |
117 | Stmt WriteToOut(PrimExpr val) { |
118 | auto info = ConvertForFFI(val); |
119 | Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0}); |
120 | Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0}); |
121 | Stmt ret_zero = Evaluate(tvm::ret(0)); |
122 | return SeqStmt({store_val, store_tcode, ret_zero}); |
123 | } |
124 | |
125 | Var ret_var_; |
126 | Var ret_tcode_; |
127 | int in_parallel_{0}; |
128 | |
129 | std::unordered_map<int, Buffer> dummy_val_buffer_map_; |
130 | Buffer dummy_tcode_buffer_; |
131 | }; |
132 | |
133 | Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { |
134 | ReturnRewriter rewriter(ret_var, ret_tcode); |
135 | return rewriter(body); |
136 | } |
137 | |
138 | inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { |
139 | return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); |
140 | } |
141 | |
/*!
 * \brief Rewrite \p func so that it uses the packed function calling convention.
 *
 * The lowered function takes the parameters
 *   (args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle).
 * Its body, from the outside in:
 *   - asserts that num_args matches the expected argument count;
 *   - loads each original parameter from `args` and its type code from `arg_type_ids`;
 *   - emits ArgBinder bindings and the type-code checks;
 *   - emits ArgBinder asserts;
 *   - runs the original body, with every `tir.ret` rewritten into writes to
 *     out_ret_value / out_ret_tcode and wrapped in a compute_scope attribute.
 *
 * A parameter named "device_api_context" is not loaded from `args`; it becomes
 * the resource_handle parameter instead.
 *
 * \param func The function to lower. It must carry the global_symbol and target
 *             attributes (checked with ICHECK).
 * \return The lowered function, with:
 *         - calling_conv set to kCPackedFunc;
 *         - an empty buffer_map;
 *         - an int32 return type.
 */
PrimFunc MakePackedAPI(PrimFunc&& func) {
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
  int target_device_type = target.value()->GetTargetDeviceType();

  // Used as the prefix of every generated error message.
  std::string name_hint = global_symbol.value();

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  // Value the runtime num_args is asserted against. It is decremented below for
  // the device_api_context parameter, which is not passed through `args`.
  int num_args = static_cast<int>(func_ptr->params.size());

  // Parameters of the lowered (packed) function:
  //   args            - packed argument values, read with tvm_struct_get
  //   arg_type_ids    - one int32 type code per original parameter
  //   num_args        - argument count supplied by the caller
  //   out_ret_value   - destination of the return value
  //   out_ret_tcode   - destination of the return type code
  //   resource_handle - replaced by the device_api_context parameter if present
  Var v_packed_args("args", DataType::Handle());
  Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())},
                                               DataType::Int(32), "arg_type_ids");
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void())));
  Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32))));
  Var v_resource_handle("resource_handle", DataType::Handle());

  // The device context.
  // device_id is handed to BindDLTensor below; whether it ended up bound is
  // checked later through vmap. device_type comes from the target.
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // Load the i-th packed argument as type t. The value is read as APIType(t)
  // and then cast to t when the two types differ.
  auto f_arg_value = [&](DataType t, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMValueContent)};
    // load 64 bit version
    DataType api_type = APIType(t);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != t) {
      res = Cast(t, res);
    }
    return res;
  };

  // Need to re-declare vars, in case some arguments also appear in the buffer.
  // var_def:    (fresh arg var, original scalar param)
  // buffer_def: (fresh arg var, buffer the param maps to in buffer_map)
  std::vector<std::pair<Var, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    std::string param_name;
    if (param->name_hint.defined() && (!param->name_hint.empty())) {
      param_name = "arg." + param->name_hint;
    } else {
      param_name = "arg" + std::to_string(i);
    }
    Var v_arg = Var(param_name, param->dtype);

    // Pluck the device API context out based on name.
    // NOTE(review): `i` keeps counting this parameter. Later parameters therefore
    // still read args[i] / arg_type_ids[i], while the expected num_args is one
    // smaller. This appears to assume the context parameter comes last; confirm.
    if (param->name_hint == kDeviceContextVar) {
      num_args--;
      v_resource_handle = param;
      continue;
    }

    if (func_ptr->buffer_map.count(param)) {
      buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
    } else {
      var_def.emplace_back(v_arg, param);
    }

    // Value loads
    seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
    // type code checks
    Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
    seq_init.emplace_back(
        LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
    DataType t = v_arg.dtype();
    if (t.is_handle()) {
      // A handle argument may be an opaque handle, an NDArray, a DLTensor, or null.
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
                                            tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
                                        tvm::tir::StringImm(msg.str()), nop));
    } else if (t.is_int() || t.is_uint()) {
      // Signed and unsigned scalars are both expected to arrive with code kDLInt.
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
    } else {
      ICHECK(t.is_float());
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
    }
  }

  // Parameter list of the lowered function (see the header comment).
  Array<Var> args{v_packed_args,     buf_packed_arg_type_ids->data,
                  v_num_packed_args, v_out_ret_value,
                  v_out_ret_tcode,   v_resource_handle};

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& kv : var_def) {
    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
  }

  for (const auto& kv : buffer_def) {
    binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
  }

  // NOTE(review): func_ptr was obtained from func.CopyOnWrite() before this
  // reassignment, and func_ptr keeps being written below. That is only valid if
  // WithAttr updates the (now uniquely owned) node in place rather than copying
  // it; confirm before reordering.
  func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));

  // Lower tir.ret into writes to the output parameters, then mark the compute scope.
  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context, but only when binding a buffer bound device_id.
  if (vmap.count(device_id.get())) {
    PrimExpr node = StringImm("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      // Emit a packed call to tvm_set_device ahead of the compute body.
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  std::ostringstream num_args_error;
  num_args_error << name_hint << ": num_args should be " << num_args;
  std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
  // Wrap the body in these statement groups, in the listed order: count check,
  // argument/type-code loads, binder definitions, type-code checks (plus device
  // attributes), binder asserts.
  func_ptr->body =
      MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  func_ptr->params = args;

  // Every free variable of the new body must be one of the new parameters.
  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }

  // Buffer parameters are now unpacked explicitly, so the buffer_map is cleared.
  func_ptr->buffer_map = Map<Var, Buffer>();
  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  // std::move is needed: func is an rvalue-reference parameter, which C++17 does not
  // implicitly move on return.
  return std::move(func);
}
305 | |
306 | namespace transform { |
307 | |
308 | Pass MakePackedAPI() { |
309 | auto pass_func = [](IRModule m, PassContext ctx) { |
310 | IRModuleNode* mptr = m.CopyOnWrite(); |
311 | std::vector<std::pair<GlobalVar, PrimFunc>> updates; |
312 | |
313 | for (const auto& kv : mptr->functions) { |
314 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
315 | PrimFunc func = GetRef<PrimFunc>(n); |
316 | if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == |
317 | CallingConv::kDefault) { |
318 | auto updated_func = MakePackedAPI(std::move(func)); |
319 | updates.push_back({kv.first, updated_func}); |
320 | } |
321 | } |
322 | } |
323 | |
324 | for (const auto& pair : updates) { |
325 | mptr->AddUnchecked(pair.first, pair.second); |
326 | } |
327 | return m; |
328 | }; |
329 | |
330 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI" , {}); |
331 | } |
332 | |
333 | TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI" ).set_body_typed([]() { return MakePackedAPI(); }); |
334 | } // namespace transform |
335 | } // namespace tir |
336 | } // namespace tvm |
337 | |