1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file pack_args.h |
 * \brief Utility to pack TVMArgs to other type-erased function calling convention.
23 | * |
24 | * Two type erased function signatures are supported. |
25 | * - cuda_style(void** args, int num_args); |
26 | * - Pack everything by address |
27 | * - metal_style(void** buffers, int num_buffers, |
28 | * union_32bit args[N], int num_args); |
29 | * - Pack buffer by address, pack rest parameter into 32bit union buffer. |
30 | */ |
31 | #ifndef TVM_RUNTIME_PACK_ARGS_H_ |
32 | #define TVM_RUNTIME_PACK_ARGS_H_ |
33 | |
34 | #include <tvm/runtime/c_runtime_api.h> |
35 | #include <tvm/runtime/packed_func.h> |
36 | |
37 | #include <cstring> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | namespace runtime { |
42 | /*! |
43 | * \brief argument union type of 32bit. |
44 | */ |
45 | union ArgUnion32 { |
46 | int32_t v_int32; |
47 | uint32_t v_uint32; |
48 | float v_float32; |
49 | }; |
50 | |
51 | /*! |
52 | * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime. |
53 | */ |
54 | union ArgUnion64 { |
55 | int32_t v_int32[2]; |
56 | uint32_t v_uint32[2]; |
57 | float v_float32[2]; |
58 | int64_t v_int64; |
59 | uint64_t v_uint64; |
60 | double v_float64; |
61 | }; |
62 | /*! |
63 | * \brief Create a packed function from void addr types. |
64 | * |
65 | * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) |
66 | * \param arg_types The arguments type information. |
67 | * \tparam F the function type |
68 | * |
69 | * \return The wrapped packed function. |
70 | */ |
71 | template <typename F> |
72 | inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types); |
73 | /*! |
74 | * \brief Create a packed function that from function only packs buffer arguments. |
75 | * |
76 | * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args) |
77 | * \param arg_types The arguments type information. |
78 | * \tparam F the function type |
79 | * |
80 | * \return The wrapped packed function. |
81 | */ |
82 | template <typename F> |
83 | inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types); |
84 | /*! |
85 | * \brief Create a packed function that from function that takes a packed arguments. |
86 | * |
87 | * \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes) |
88 | * \param arg_types The arguments that wish to get from |
89 | * \tparam F the function type |
90 | * |
91 | * \return The wrapped packed function. |
92 | */ |
93 | template <typename F> |
94 | inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types); |
95 | /*! |
96 | * \brief Extract number of buffer argument from the argument types. |
97 | * \param arg_types The argument types. |
98 | * \return number of buffer arguments |
99 | */ |
100 | inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types); |
101 | |
102 | // implementations details |
103 | namespace detail { |
104 | template <typename T, int kSize> |
105 | class TempArray { |
106 | public: |
107 | explicit TempArray(int size) {} |
108 | T* data() { return data_; } |
109 | |
110 | private: |
111 | T data_[kSize]; |
112 | }; |
113 | template <typename T> |
114 | class TempArray<T, 0> { |
115 | public: |
116 | explicit TempArray(int size) : data_(size) {} |
117 | T* data() { return data_.data(); } |
118 | |
119 | private: |
120 | std::vector<T> data_; |
121 | }; |
122 | |
123 | /*! \brief conversion code used in void arg. */ |
124 | enum ArgConvertCode { |
125 | INT64_TO_INT64, |
126 | INT64_TO_INT32, |
127 | INT64_TO_UINT32, |
128 | FLOAT64_TO_FLOAT32, |
129 | FLOAT64_TO_FLOAT64, |
130 | HANDLE_TO_HANDLE |
131 | }; |
132 | |
133 | inline ArgConvertCode GetArgConvertCode(DLDataType t) { |
134 | ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now" ; |
135 | if (t.code == kDLInt) { |
136 | if (t.bits == 64U) return INT64_TO_INT64; |
137 | if (t.bits == 32U) return INT64_TO_INT32; |
138 | } else if (t.code == kDLUInt) { |
139 | if (t.bits == 32U) return INT64_TO_UINT32; |
140 | } else if (t.code == kDLFloat) { |
141 | if (t.bits == 64U) return FLOAT64_TO_FLOAT64; |
142 | if (t.bits == 32U) return FLOAT64_TO_FLOAT32; |
143 | } else if (t.code == kTVMOpaqueHandle) { |
144 | return HANDLE_TO_HANDLE; |
145 | } |
146 | LOG(FATAL) << "Cannot handle " << t << " as device function argument" ; |
147 | } |
148 | |
149 | template <int N, typename F> |
150 | inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) { |
151 | int num_args = static_cast<int>(codes.size()); |
152 | auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { |
153 | TempArray<void*, N> addr_(num_args); |
154 | TempArray<ArgUnion32, N> holder_(num_args); |
155 | void** addr = addr_.data(); |
156 | ArgUnion32* holder = holder_.data(); |
157 | for (int i = 0; i < num_args; ++i) { |
158 | switch (codes[i]) { |
159 | case INT64_TO_INT64: |
160 | case FLOAT64_TO_FLOAT64: |
161 | case HANDLE_TO_HANDLE: { |
162 | addr[i] = (void*)&(args.values[i]); // NOLINT(*) |
163 | break; |
164 | } |
165 | case INT64_TO_INT32: { |
166 | holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64); |
167 | addr[i] = &(holder[i]); |
168 | break; |
169 | } |
170 | case INT64_TO_UINT32: { |
171 | holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64); |
172 | addr[i] = &(holder[i]); |
173 | break; |
174 | } |
175 | case FLOAT64_TO_FLOAT32: { |
176 | holder[i].v_float32 = static_cast<float>(args.values[i].v_float64); |
177 | addr[i] = &(holder[i]); |
178 | break; |
179 | } |
180 | } |
181 | } |
182 | f(args, ret, addr); |
183 | }; |
184 | return PackedFunc(ret); |
185 | } |
186 | |
187 | template <int N, typename F> |
188 | inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) { |
189 | int num_args = static_cast<int>(codes.size()); |
190 | auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { |
191 | TempArray<ArgUnion64, N> holder_(num_args); |
192 | ArgUnion64* holder = holder_.data(); |
193 | for (int i = 0; i < num_args; ++i) { |
194 | switch (codes[i]) { |
195 | case INT64_TO_INT64: { |
196 | holder[i].v_int64 = args.values[base + i].v_int64; |
197 | break; |
198 | } |
199 | case FLOAT64_TO_FLOAT64: { |
200 | holder[i].v_float64 = args.values[base + i].v_float64; |
201 | break; |
202 | } |
203 | case INT64_TO_INT32: { |
204 | holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64); |
205 | break; |
206 | } |
207 | case INT64_TO_UINT32: { |
208 | holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64); |
209 | break; |
210 | } |
211 | case FLOAT64_TO_FLOAT32: { |
212 | holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64); |
213 | break; |
214 | } |
215 | case HANDLE_TO_HANDLE: { |
216 | LOG(FATAL) << "not reached" ; |
217 | break; |
218 | } |
219 | } |
220 | } |
221 | f(args, ret, holder); |
222 | }; |
223 | return PackedFunc(ret); |
224 | } |
225 | |
226 | template <int N, typename F> |
227 | inline PackedFunc PackFuncPackedArg_(F f, const std::vector<ArgConvertCode>& codes) { |
228 | int num_args = static_cast<int>(codes.size()); |
229 | auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { |
230 | TempArray<uint64_t, N> pack_(num_args); |
231 | int32_t* pack = reinterpret_cast<int32_t*>(pack_.data()); |
232 | int32_t* ptr = pack; |
233 | static_assert(sizeof(TVMValue) == 8, "invariant" ); |
234 | static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant" ); |
235 | for (int i = 0; i < num_args; ++i) { |
236 | switch (codes[i]) { |
237 | case HANDLE_TO_HANDLE: { |
238 | std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*)); |
239 | ptr += sizeof(void*) / sizeof(int32_t); |
240 | break; |
241 | } |
242 | case INT64_TO_INT64: |
243 | case FLOAT64_TO_FLOAT64: { |
244 | std::memcpy(ptr, &args.values[i], sizeof(TVMValue)); |
245 | ptr += 2; |
246 | break; |
247 | } |
248 | case INT64_TO_INT32: { |
249 | *ptr = static_cast<int32_t>(args.values[i].v_int64); |
250 | ++ptr; |
251 | break; |
252 | } |
253 | case INT64_TO_UINT32: { |
254 | *reinterpret_cast<uint32_t*>(ptr) = static_cast<uint32_t>(args.values[i].v_int64); |
255 | ++ptr; |
256 | break; |
257 | } |
258 | case FLOAT64_TO_FLOAT32: { |
259 | *reinterpret_cast<float*>(ptr) = static_cast<float>(args.values[i].v_float64); |
260 | ++ptr; |
261 | break; |
262 | } |
263 | default: { |
264 | LOG(FATAL) << "not reached" ; |
265 | break; |
266 | } |
267 | } |
268 | } |
269 | f(args, ret, pack, (ptr - pack) * sizeof(int32_t)); |
270 | }; |
271 | return PackedFunc(ret); |
272 | } |
273 | } // namespace detail |
274 | |
275 | template <typename F> |
276 | inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) { |
277 | std::vector<detail::ArgConvertCode> codes(arg_types.size()); |
278 | for (size_t i = 0; i < arg_types.size(); ++i) { |
279 | codes[i] = detail::GetArgConvertCode(arg_types[i]); |
280 | } |
281 | size_t num_void_args = arg_types.size(); |
282 | // specialization |
283 | if (num_void_args <= 4) { |
284 | return detail::PackFuncVoidAddr_<4>(f, codes); |
285 | } else if (num_void_args <= 8) { |
286 | return detail::PackFuncVoidAddr_<8>(f, codes); |
287 | } else { |
288 | return detail::PackFuncVoidAddr_<0>(f, codes); |
289 | } |
290 | } |
291 | |
292 | inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types) { |
293 | size_t base = arg_types.size(); |
294 | for (size_t i = 0; i < arg_types.size(); ++i) { |
295 | if (arg_types[i].code != kTVMOpaqueHandle) { |
296 | base = i; |
297 | break; |
298 | } |
299 | } |
300 | for (size_t i = base; i < arg_types.size(); ++i) { |
301 | ICHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized" ; |
302 | } |
303 | return base; |
304 | } |
305 | |
306 | template <typename F> |
307 | inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types) { |
308 | size_t num_buffer = NumBufferArgs(arg_types); |
309 | std::vector<detail::ArgConvertCode> codes; |
310 | for (size_t i = num_buffer; i < arg_types.size(); ++i) { |
311 | codes.push_back(detail::GetArgConvertCode(arg_types[i])); |
312 | } |
313 | int base = static_cast<int>(num_buffer); |
314 | size_t nargs = codes.size(); |
315 | // specialization |
316 | if (nargs <= 4) { |
317 | return detail::PackFuncNonBufferArg_<4>(f, base, codes); |
318 | } else { |
319 | return detail::PackFuncNonBufferArg_<0>(f, base, codes); |
320 | } |
321 | } |
322 | |
323 | template <typename F> |
324 | inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types) { |
325 | std::vector<detail::ArgConvertCode> codes; |
326 | for (size_t i = 0; i < arg_types.size(); ++i) { |
327 | codes.push_back(detail::GetArgConvertCode(arg_types[i])); |
328 | } |
329 | size_t nargs = codes.size(); |
330 | // specialization |
331 | if (nargs <= 4) { |
332 | return detail::PackFuncPackedArg_<4>(f, codes); |
333 | } else { |
334 | return detail::PackFuncPackedArg_<0>(f, codes); |
335 | } |
336 | } |
337 | } // namespace runtime |
338 | } // namespace tvm |
339 | #endif // TVM_RUNTIME_PACK_ARGS_H_ |
340 | |