1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file make_packed_api.cc Lower PrimFunc to use the packed function API.
22 */
23#include <tvm/runtime/device_api.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/target/target.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/buffer.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt_functor.h>
31#include <tvm/tir/transform.h>
32
33#include <unordered_set>
34#include <utility>
35#include <vector>
36
37#include "arg_binder.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
// Name of the PrimFunc parameter that carries the device API context handle.
// In MakePackedAPI, a parameter with this name is removed from the packed
// argument list and used as the generated function's resource_handle instead.
static constexpr const char* kDeviceContextVar = "device_api_context";
44
45class ReturnRewriter : public StmtMutator {
46 public:
47 explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
48
49 Stmt VisitStmt_(const ForNode* node) override {
50 if (node->kind == ForKind::kParallel) in_parallel_ += 1;
51 Stmt ret = StmtMutator::VisitStmt_(node);
52 if (node->kind == ForKind::kParallel) in_parallel_ -= 1;
53 return ret;
54 }
55
56 Stmt VisitStmt_(const EvaluateNode* node) override {
57 Stmt ret = StmtMutator::VisitStmt_(node);
58 const EvaluateNode* eval = ret.as<EvaluateNode>();
59 ICHECK(eval);
60 if (const CallNode* call = eval->value.as<CallNode>()) {
61 if (call->op.same_as(builtin::ret())) {
62 ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
63 ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
64 ret = WriteToOut(call->args[0]);
65 }
66 }
67 return ret;
68 }
69
70 private:
71 struct ConvertedInfo {
72 int tcode{-1};
73 PrimExpr expr;
74 Buffer dummy_val_buffer;
75 Buffer dummy_tcode_buffer;
76 };
77
78 ConvertedInfo ConvertForFFI(PrimExpr val) {
79 ConvertedInfo info;
80
81 // convert val's data type to FFI data type, return type code
82 DataType dtype = val.dtype();
83 if (dtype.is_int() || dtype.is_uint()) {
84 info.tcode = kTVMArgInt;
85 info.expr = Cast(DataType::Int(64), val);
86 } else if (dtype.is_float()) {
87 info.tcode = kTVMArgFloat;
88 info.expr = Cast(DataType::Float(64), val);
89 } else if (dtype.is_void()) {
90 info.tcode = kTVMNullptr;
91 info.expr = val;
92 } else {
93 LOG(FATAL) << "data type " << dtype << " not supported yet";
94 }
95
96 // If multiple return locations have the same data type, use the
97 // same dummy buffer declaration.
98 auto it = dummy_val_buffer_map_.find(info.tcode);
99 if (it != dummy_val_buffer_map_.end()) {
100 info.dummy_val_buffer = it->second;
101 } else {
102 info.dummy_val_buffer = Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
103 ret_var_->name_hint, 0, 0, kDefault);
104 dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer;
105 }
106
107 // The tcode is always a 32-bit int, so we don't need to have a separate map.
108 if (!dummy_tcode_buffer_.defined()) {
109 dummy_tcode_buffer_ = Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
110 ret_tcode_->name_hint, 0, 0, kDefault);
111 }
112 info.dummy_tcode_buffer = dummy_tcode_buffer_;
113
114 return info;
115 }
116
117 Stmt WriteToOut(PrimExpr val) {
118 auto info = ConvertForFFI(val);
119 Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
120 Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0});
121 Stmt ret_zero = Evaluate(tvm::ret(0));
122 return SeqStmt({store_val, store_tcode, ret_zero});
123 }
124
125 Var ret_var_;
126 Var ret_tcode_;
127 int in_parallel_{0};
128
129 std::unordered_map<int, Buffer> dummy_val_buffer_map_;
130 Buffer dummy_tcode_buffer_;
131};
132
133Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
134 ReturnRewriter rewriter(ret_var, ret_tcode);
135 return rewriter(body);
136}
137
138inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
139 return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
140}
141
/*!
 * \brief Lower a PrimFunc to follow the TVM packed-function calling convention.
 *
 * The function's original parameters are replaced by the packed API signature
 * (args, arg_type_ids, num_args, out_ret_value, out_ret_tcode,
 * resource_handle). Each original argument is loaded from the packed args
 * array, its type code is checked, and DLTensor arguments are bound through
 * ArgBinder. tir.ret calls in the body are rewritten to write to the return
 * slots. Requires the global_symbol and target attributes to be present.
 *
 * \param func The function to transform; consumed by the call.
 * \return The transformed function with calling convention kCPackedFunc and
 *         ret_type int32.
 */
PrimFunc MakePackedAPI(PrimFunc&& func) {
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
  int target_device_type = target.value()->GetTargetDeviceType();

  std::string name_hint = global_symbol.value();

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  // Expected value of num_args; decremented if a device-context param is found.
  int num_args = static_cast<int>(func_ptr->params.size());

  // Data field definitions
  // The packed fields
  Var v_packed_args("args", DataType::Handle());
  Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())},
                                               DataType::Int(32), "arg_type_ids");
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void())));
  Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32))));
  Var v_resource_handle("resource_handle", DataType::Handle());
  // The arguments of the function.

  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](DataType t, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMValueContent)};
    // load 64 bit version
    DataType api_type = APIType(t);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != t) {
      res = Cast(t, res);
    }
    return res;
  };

  // Need to re-declare vars, in case some arguments also appears in the buffer.
  std::vector<std::pair<Var, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    std::string param_name;
    if (param->name_hint.defined() && (!param->name_hint.empty())) {
      param_name = "arg." + param->name_hint;
    } else {
      param_name = "arg" + std::to_string(i);
    }
    Var v_arg = Var(param_name, param->dtype);

    // Pluck the device API context out based on name
    // NOTE(review): this assumes the device-context parameter, if present,
    // is positioned so that skipping it does not shift the packed index of
    // any argument loaded below — confirm it is always the last param.
    if (param->name_hint == kDeviceContextVar) {
      num_args--;
      v_resource_handle = param;
      continue;
    }

    if (func_ptr->buffer_map.count(param)) {
      buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
    } else {
      var_def.emplace_back(v_arg, param);
    }

    // Value loads
    seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
    // type code checks
    Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
    seq_init.emplace_back(
        LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
    DataType t = v_arg.dtype();
    if (t.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
                                            tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
                                        tvm::tir::StringImm(msg.str()), nop));
    } else if (t.is_int() || t.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
    } else {
      ICHECK(t.is_float());
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
    }
  }

  // The new packed-API parameter list.
  Array<Var> args{v_packed_args, buf_packed_arg_type_ids->data,
                  v_num_packed_args, v_out_ret_value,
                  v_out_ret_tcode, v_resource_handle};

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& kv : var_def) {
    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
  }

  for (const auto& kv : buffer_def) {
    binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
  }

  // Mark the function as following the C packed calling convention.
  // func_ptr remains valid: func is uniquely owned here, so WithAttr's
  // copy-on-write does not reallocate the node.
  func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));

  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context
  if (vmap.count(device_id.get())) {
    PrimExpr node = StringImm("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    // Some device APIs require an explicit set-device call before the body.
    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  // Assemble: num_args check -> arg loads -> binder lets -> tcode checks ->
  // binder asserts -> body.
  std::ostringstream num_args_error;
  num_args_error << name_hint << ": num_args should be " << num_args;
  std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
  func_ptr->body =
      MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  func_ptr->params = args;

  // Every variable used in the body must now be bound through the packed args.
  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }

  // The buffer map is consumed by the binding above; packed funcs return int32.
  func_ptr->buffer_map = Map<Var, Buffer>();
  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  return std::move(func);
}
305
306namespace transform {
307
308Pass MakePackedAPI() {
309 auto pass_func = [](IRModule m, PassContext ctx) {
310 IRModuleNode* mptr = m.CopyOnWrite();
311 std::vector<std::pair<GlobalVar, PrimFunc>> updates;
312
313 for (const auto& kv : mptr->functions) {
314 if (auto* n = kv.second.as<PrimFuncNode>()) {
315 PrimFunc func = GetRef<PrimFunc>(n);
316 if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
317 CallingConv::kDefault) {
318 auto updated_func = MakePackedAPI(std::move(func));
319 updates.push_back({kv.first, updated_func});
320 }
321 }
322 }
323
324 for (const auto& pair : updates) {
325 mptr->AddUnchecked(pair.first, pair.second);
326 }
327 return m;
328 };
329
330 return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
331}
332
// Expose the pass through the FFI registry so it can be constructed from the
// frontend as tir.transform.MakePackedAPI().
TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { return MakePackedAPI(); });
334} // namespace transform
335} // namespace tir
336} // namespace tvm
337