1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file pack_args.h
 * \brief Utility to pack TVMArgs into other type-erased function calling conventions.
 *
 * The following type-erased function signatures are supported:
 *  - cuda_style(void** args, int num_args);
 *    - Pack every argument by address (PackFuncVoidAddr).
 *  - metal_style(void** buffers, int num_buffers,
 *                union args[N], int num_args);
 *    - Pack buffers by address, and pack the remaining parameters into a union
 *      buffer of ArgUnion64 slots (PackFuncNonBufferArg).
 *  - packed_style(void* pack_args, size_t nbytes);
 *    - Pack every argument contiguously into one byte buffer (PackFuncPackedArg).
 */
31#ifndef TVM_RUNTIME_PACK_ARGS_H_
32#define TVM_RUNTIME_PACK_ARGS_H_
33
34#include <tvm/runtime/c_runtime_api.h>
35#include <tvm/runtime/packed_func.h>
36
37#include <cstring>
38#include <vector>
39
40namespace tvm {
41namespace runtime {
/*!
 * \brief 32-bit argument union.
 *
 * Used as the scratch slot for 32-bit int/uint/float arguments that are
 * narrowed from their 64-bit TVMValue representation before being passed
 * to a device function by address.
 */
union ArgUnion32 {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};
50
/*!
 * \brief Argument union type of 64 bit, for use by Vulkan and Metal runtime.
 *
 * Each non-buffer argument packed by PackFuncNonBufferArg occupies one slot.
 * 64-bit values are stored in v_int64 / v_float64. 32-bit values are stored in
 * element [0] of the matching two-element 32-bit view.
 */
union ArgUnion64 {
  int32_t v_int32[2];
  uint32_t v_uint32[2];
  float v_float32[2];
  int64_t v_int64;
  uint64_t v_uint64;
  double v_float64;
};
/*!
 * \brief Create a packed function that passes every argument to \p f by address.
 *
 * 64-bit integers, 64-bit floats and handles are passed as the address of the
 * corresponding TVMValue. 32-bit int/uint/float arguments are first narrowed
 * into a temporary ArgUnion32, and the address of that temporary is passed.
 * The addresses are only valid for the duration of the call to \p f.
 *
 * \param f Callable with signature
 *  (TVMArgs args, TVMRetValue* rv, void** void_args), where void_args[i]
 *  points at the (possibly narrowed) value of argument i.
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Create a packed function that packs only the non-buffer arguments.
 *
 * The argument list must start with all buffer (opaque handle) arguments,
 * followed only by non-buffer arguments (enforced by NumBufferArgs). Each
 * non-buffer argument is converted into one ArgUnion64 slot. Buffer arguments
 * are not repacked; they remain available through \p args.
 *
 * \param f Callable with signature
 *  (TVMArgs args, TVMRetValue* rv, ArgUnion64* pack_args), where pack_args[i]
 *  holds the converted value of the i-th non-buffer argument.
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Create a packed function that packs all arguments into one contiguous buffer.
 *
 * Arguments are written back to back in 4-byte units with no extra padding:
 * handles take sizeof(void*) bytes, 64-bit int/float take 8 bytes, and 32-bit
 * int/uint/float take 4 bytes. The buffer is only valid for the duration of
 * the call to \p f.
 *
 * \param f Callable with signature
 *  (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes); the buffer
 *  is handed over as an int32_t*, and nbytes is the number of bytes written.
 * \param arg_types The type of each argument to pack, in argument order.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Extract the number of buffer arguments from the argument types.
 *
 * Buffer arguments are those whose type code is kTVMOpaqueHandle. They must
 * all precede every non-buffer argument; otherwise an ICHECK failure is raised.
 * \param arg_types The argument types.
 * \return Number of leading buffer arguments.
 */
inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types);
101
// implementation details
103namespace detail {
/*!
 * \brief Scratch buffer of T used while repacking arguments.
 *
 * With kSize > 0 the storage is an inline array of kSize (uninitialized)
 * elements and the constructor argument is ignored, so the caller must never
 * request more than kSize elements.
 * \tparam T Element type.
 * \tparam kSize Inline capacity; 0 selects the heap-backed specialization.
 */
template <typename T, int kSize>
class TempArray {
 public:
  /*! \param size Requested element count (unused: capacity is fixed at kSize). */
  explicit TempArray(int /*size*/) {}
  /*! \return Pointer to the first inline element. */
  T* data() { return storage_; }

 private:
  T storage_[kSize];
};

/*!
 * \brief Heap-backed specialization, used when the element count has no
 *  compile-time bound.
 */
template <typename T>
class TempArray<T, 0> {
 public:
  /*! \param size Number of value-initialized elements to allocate. */
  explicit TempArray(int size) : storage_(size) {}
  /*! \return Pointer to the first heap-allocated element. */
  T* data() { return storage_.data(); }

 private:
  std::vector<T> storage_;
};
122
/*!
 * \brief How a single argument is converted when packing it for a device function.
 *
 * Names read SOURCE_TO_TARGET. The source is the value as stored in a
 * TVMValue (v_int64, v_float64 or v_handle); the target is the representation
 * passed to the device function.
 */
enum ArgConvertCode {
  /*! \brief 64-bit integer passed unchanged. */
  INT64_TO_INT64,
  /*! \brief 64-bit integer narrowed to int32_t. */
  INT64_TO_INT32,
  /*! \brief 64-bit integer narrowed to uint32_t. */
  INT64_TO_UINT32,
  /*! \brief 64-bit float narrowed to float. */
  FLOAT64_TO_FLOAT32,
  /*! \brief 64-bit float passed unchanged. */
  FLOAT64_TO_FLOAT64,
  /*! \brief Opaque handle (e.g. a buffer) passed unchanged. */
  HANDLE_TO_HANDLE
};
132
133inline ArgConvertCode GetArgConvertCode(DLDataType t) {
134 ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
135 if (t.code == kDLInt) {
136 if (t.bits == 64U) return INT64_TO_INT64;
137 if (t.bits == 32U) return INT64_TO_INT32;
138 } else if (t.code == kDLUInt) {
139 if (t.bits == 32U) return INT64_TO_UINT32;
140 } else if (t.code == kDLFloat) {
141 if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
142 if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
143 } else if (t.code == kTVMOpaqueHandle) {
144 return HANDLE_TO_HANDLE;
145 }
146 LOG(FATAL) << "Cannot handle " << t << " as device function argument";
147}
148
149template <int N, typename F>
150inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
151 int num_args = static_cast<int>(codes.size());
152 auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
153 TempArray<void*, N> addr_(num_args);
154 TempArray<ArgUnion32, N> holder_(num_args);
155 void** addr = addr_.data();
156 ArgUnion32* holder = holder_.data();
157 for (int i = 0; i < num_args; ++i) {
158 switch (codes[i]) {
159 case INT64_TO_INT64:
160 case FLOAT64_TO_FLOAT64:
161 case HANDLE_TO_HANDLE: {
162 addr[i] = (void*)&(args.values[i]); // NOLINT(*)
163 break;
164 }
165 case INT64_TO_INT32: {
166 holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
167 addr[i] = &(holder[i]);
168 break;
169 }
170 case INT64_TO_UINT32: {
171 holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
172 addr[i] = &(holder[i]);
173 break;
174 }
175 case FLOAT64_TO_FLOAT32: {
176 holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
177 addr[i] = &(holder[i]);
178 break;
179 }
180 }
181 }
182 f(args, ret, addr);
183 };
184 return PackedFunc(ret);
185}
186
187template <int N, typename F>
188inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
189 int num_args = static_cast<int>(codes.size());
190 auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
191 TempArray<ArgUnion64, N> holder_(num_args);
192 ArgUnion64* holder = holder_.data();
193 for (int i = 0; i < num_args; ++i) {
194 switch (codes[i]) {
195 case INT64_TO_INT64: {
196 holder[i].v_int64 = args.values[base + i].v_int64;
197 break;
198 }
199 case FLOAT64_TO_FLOAT64: {
200 holder[i].v_float64 = args.values[base + i].v_float64;
201 break;
202 }
203 case INT64_TO_INT32: {
204 holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
205 break;
206 }
207 case INT64_TO_UINT32: {
208 holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
209 break;
210 }
211 case FLOAT64_TO_FLOAT32: {
212 holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
213 break;
214 }
215 case HANDLE_TO_HANDLE: {
216 LOG(FATAL) << "not reached";
217 break;
218 }
219 }
220 }
221 f(args, ret, holder);
222 };
223 return PackedFunc(ret);
224}
225
226template <int N, typename F>
227inline PackedFunc PackFuncPackedArg_(F f, const std::vector<ArgConvertCode>& codes) {
228 int num_args = static_cast<int>(codes.size());
229 auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
230 TempArray<uint64_t, N> pack_(num_args);
231 int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
232 int32_t* ptr = pack;
233 static_assert(sizeof(TVMValue) == 8, "invariant");
234 static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
235 for (int i = 0; i < num_args; ++i) {
236 switch (codes[i]) {
237 case HANDLE_TO_HANDLE: {
238 std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
239 ptr += sizeof(void*) / sizeof(int32_t);
240 break;
241 }
242 case INT64_TO_INT64:
243 case FLOAT64_TO_FLOAT64: {
244 std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
245 ptr += 2;
246 break;
247 }
248 case INT64_TO_INT32: {
249 *ptr = static_cast<int32_t>(args.values[i].v_int64);
250 ++ptr;
251 break;
252 }
253 case INT64_TO_UINT32: {
254 *reinterpret_cast<uint32_t*>(ptr) = static_cast<uint32_t>(args.values[i].v_int64);
255 ++ptr;
256 break;
257 }
258 case FLOAT64_TO_FLOAT32: {
259 *reinterpret_cast<float*>(ptr) = static_cast<float>(args.values[i].v_float64);
260 ++ptr;
261 break;
262 }
263 default: {
264 LOG(FATAL) << "not reached";
265 break;
266 }
267 }
268 }
269 f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
270 };
271 return PackedFunc(ret);
272}
273} // namespace detail
274
275template <typename F>
276inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
277 std::vector<detail::ArgConvertCode> codes(arg_types.size());
278 for (size_t i = 0; i < arg_types.size(); ++i) {
279 codes[i] = detail::GetArgConvertCode(arg_types[i]);
280 }
281 size_t num_void_args = arg_types.size();
282 // specialization
283 if (num_void_args <= 4) {
284 return detail::PackFuncVoidAddr_<4>(f, codes);
285 } else if (num_void_args <= 8) {
286 return detail::PackFuncVoidAddr_<8>(f, codes);
287 } else {
288 return detail::PackFuncVoidAddr_<0>(f, codes);
289 }
290}
291
292inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types) {
293 size_t base = arg_types.size();
294 for (size_t i = 0; i < arg_types.size(); ++i) {
295 if (arg_types[i].code != kTVMOpaqueHandle) {
296 base = i;
297 break;
298 }
299 }
300 for (size_t i = base; i < arg_types.size(); ++i) {
301 ICHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized";
302 }
303 return base;
304}
305
306template <typename F>
307inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types) {
308 size_t num_buffer = NumBufferArgs(arg_types);
309 std::vector<detail::ArgConvertCode> codes;
310 for (size_t i = num_buffer; i < arg_types.size(); ++i) {
311 codes.push_back(detail::GetArgConvertCode(arg_types[i]));
312 }
313 int base = static_cast<int>(num_buffer);
314 size_t nargs = codes.size();
315 // specialization
316 if (nargs <= 4) {
317 return detail::PackFuncNonBufferArg_<4>(f, base, codes);
318 } else {
319 return detail::PackFuncNonBufferArg_<0>(f, base, codes);
320 }
321}
322
323template <typename F>
324inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types) {
325 std::vector<detail::ArgConvertCode> codes;
326 for (size_t i = 0; i < arg_types.size(); ++i) {
327 codes.push_back(detail::GetArgConvertCode(arg_types[i]));
328 }
329 size_t nargs = codes.size();
330 // specialization
331 if (nargs <= 4) {
332 return detail::PackFuncPackedArg_<4>(f, codes);
333 } else {
334 return detail::PackFuncPackedArg_<0>(f, codes);
335 }
336}
337} // namespace runtime
338} // namespace tvm
339#endif // TVM_RUNTIME_PACK_ARGS_H_
340