1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/packed_func.h |
22 | * \brief Type-erased function used across TVM API. |
23 | */ |
24 | #ifndef TVM_RUNTIME_PACKED_FUNC_H_ |
25 | #define TVM_RUNTIME_PACKED_FUNC_H_ |
26 | |
27 | #include <tvm/runtime/c_runtime_api.h> |
28 | #include <tvm/runtime/container/array.h> |
29 | #include <tvm/runtime/container/map.h> |
30 | #include <tvm/runtime/data_type.h> |
31 | #include <tvm/runtime/logging.h> |
32 | #include <tvm/runtime/module.h> |
33 | #include <tvm/runtime/ndarray.h> |
34 | #include <tvm/runtime/object.h> |
35 | |
36 | #include <functional> |
37 | #include <limits> |
38 | #include <memory> |
39 | #include <string> |
40 | #include <tuple> |
41 | #include <type_traits> |
42 | #include <utility> |
43 | #include <vector> |
44 | |
// Whether to use the TVM runtime in header-only mode.
// Defaults to 0 unless the macro has already been defined before this point.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
49 | |
50 | namespace tvm { |
51 | namespace runtime { |
52 | |
// Forward declarations, so the types below can be named before they are defined.

// Plain classes.
class TVMArgs;
class TVMArgValue;
class TVMMovableArgValueWithContext_;
class TVMRetValue;
class TVMArgsSetter;

// Templates parameterized on a function type / signature.
template <typename FType> class TypedPackedFunc;
template <typename TSignature> struct SignaturePrinter;
63 | |
64 | /*! |
65 | * \brief Object container class that backs PackedFunc. |
66 | * \note Do not use this function directly, use PackedFunc. |
67 | */ |
68 | class PackedFuncObj : public Object { |
69 | public: |
70 | /*! |
71 | * \brief Call the function in packed format. |
72 | * \param args The arguments |
73 | * \param rv The return value. |
74 | */ |
75 | TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; |
76 | |
77 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; |
78 | static constexpr const char* _type_key = "runtime.PackedFunc" ; |
79 | TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); |
80 | |
81 | protected: |
82 | /*! |
83 | * \brief Internal struct for extracting the callable method from callable type. |
84 | */ |
85 | template <class TPackedFuncSubObj> |
86 | struct { |
87 | /*! |
88 | * \brief Extracting the callable method from callable type. |
89 | * \param obj The base packed function object class. |
90 | * \param args The arguments |
91 | * \param rv The return value. |
92 | */ |
93 | static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv); |
94 | }; |
95 | |
96 | /*! \brief The internal callable function type. */ |
97 | using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); |
98 | |
99 | /*! |
100 | * \brief Constructing a packed function object from a function pointer. |
101 | * \param f_call_pack The function pointer used to call the packed function. |
102 | */ |
103 | explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} |
104 | |
105 | /*! \brief Delete the default constructor explicitly. */ |
106 | PackedFuncObj() = delete; |
107 | |
108 | /*! \brief Internal callable function pointer used to call the packed function. */ |
109 | FCallPacked* f_call_packed_; |
110 | }; |
111 | |
112 | /*! \brief Derived object class for constructing PackedFuncObj. */ |
113 | template <class TCallable> |
114 | class PackedFuncSubObj : public PackedFuncObj { |
115 | using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type; |
116 | |
117 | public: |
118 | /*! \brief The type of derived object class */ |
119 | using TSelf = PackedFuncSubObj<TCallable>; |
120 | /*! |
121 | * \brief Derived object class for constructing PackedFuncObj. |
122 | * \param callable The type-erased callable object. |
123 | */ |
124 | explicit PackedFuncSubObj(TCallable callable) |
125 | : PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {} |
126 | /*! \brief Type-erased filed for storing callable object*/ |
127 | mutable TStorage callable_; |
128 | }; |
129 | |
130 | /*! |
131 | * \brief Packed function is a type-erased function. |
132 | * The arguments are passed by packed format. |
133 | * |
134 | * This is an useful unified interface to call generated functions, |
135 | * It is the unified function function type of TVM. |
136 | * It corresponds to TVMFunctionHandle in C runtime API. |
137 | */ |
138 | class PackedFunc : public ObjectRef { |
139 | public: |
140 | /*! \brief Constructor from null */ |
141 | PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*) |
142 | /*! |
143 | * \brief Constructing a packed function from a callable type |
144 | * whose signature is consistent with `PackedFunc` |
145 | * \param data the internal container of packed function. |
146 | */ |
147 | template <typename TCallable, |
148 | typename = std::enable_if_t< |
149 | std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value && |
150 | !std::is_base_of<TCallable, PackedFunc>::value>> |
151 | explicit PackedFunc(TCallable data) { |
152 | using ObjType = PackedFuncSubObj<TCallable>; |
153 | data_ = make_object<ObjType>(std::forward<TCallable>(data)); |
154 | } |
155 | /*! |
156 | * \brief Call packed function by directly passing in unpacked format. |
157 | * \param args Arguments to be passed. |
158 | * \tparam Args arguments to be passed. |
159 | * |
160 | * \code |
161 | * // Example code on how to call packed function |
162 | * void CallPacked(PackedFunc f) { |
163 | * // call like normal functions by pass in arguments |
164 | * // return value is automatically converted back |
165 | * int rvalue = f(1, 2.0); |
166 | * } |
167 | * \endcode |
168 | */ |
169 | template <typename... Args> |
170 | inline TVMRetValue operator()(Args&&... args) const; |
171 | /*! |
172 | * \brief Call the function in packed format. |
173 | * \param args The arguments |
174 | * \param rv The return value. |
175 | */ |
176 | TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; |
177 | /*! \return Whether the packed function is nullptr */ |
178 | bool operator==(std::nullptr_t null) const { return data_ == nullptr; } |
179 | /*! \return Whether the packed function is not nullptr */ |
180 | bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } |
181 | |
182 | TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj); |
183 | }; |
184 | |
/*!
 * \brief Using static function to output TypedPackedFunc signature
 *
 * FSig is a function type that takes no arguments and returns std::string.
 */
using FSig = std::string();

/*!
 * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>"
 */
template <typename FType>
class TypedPackedFunc;
193 | |
/*!
 * \anchor TypedPackedFuncAnchor
 * \brief A PackedFunc wrapper to provide typed function signature.
 * It is backed by a PackedFunc internally.
 *
 * TypedPackedFunc enables compile time type checking.
 * TypedPackedFunc works with the runtime system:
 * - It can be passed as an argument of PackedFunc.
 * - It can be assigned to TVMRetValue.
 * - It can be directly converted to a type-erased PackedFunc.
 *
 * Developers should prefer TypedPackedFunc over PackedFunc in C++ code
 * as it enables compile time checking.
 * We can construct a TypedPackedFunc from a lambda function
 * with the same signature.
 *
 * \code
 *  // user defined lambda function.
 *  auto addone = [](int x)->int {
 *    return x + 1;
 *  };
 *  // We can directly convert
 *  // lambda function to TypedPackedFunc
 *  TypedPackedFunc<int(int)> ftyped(addone);
 *  // invoke the function.
 *  int y = ftyped(1);
 *  // Can be directly converted to PackedFunc
 *  PackedFunc packed = ftyped;
 * \endcode
 * \tparam R The return value of the function.
 * \tparam Args The argument signature of the function.
 */
template <typename R, typename... Args>
class TypedPackedFunc<R(Args...)> {
 public:
  /*! \brief short hand for this function type */
  using TSelf = TypedPackedFunc<R(Args...)>;
  /*! \brief default constructor */
  TypedPackedFunc() {}
  /*! \brief constructor from null */
  TypedPackedFunc(std::nullptr_t null) {}  // NOLINT(*)
  /*!
   * \brief construct by wrapping a PackedFunc
   *
   * Example usage:
   * \code
   * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
   *   int x = args[0];
   *   *rv = x + 1;
   *  });
   * // construct from packed function
   * TypedPackedFunc<int(int)> ftyped(packed);
   * // call the typed version.
   * ICHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param packed The packed function
   */
  inline TypedPackedFunc(PackedFunc packed);  // NOLINT(*)
  /*!
   * \brief constructor from TVMRetValue
   * \param value The TVMRetValue
   */
  inline TypedPackedFunc(const TVMRetValue& value);  // NOLINT(*)
  /*!
   * \brief constructor from TVMArgValue
   * \param value The TVMArgValue
   */
  inline TypedPackedFunc(const TVMArgValue& value);  // NOLINT(*)
  /*!
   * \brief constructor from TVMMovableArgValueWithContext_
   * \param value The TVMMovableArgValueWithContext_
   */
  inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value);  // NOLINT(*)
  /*!
   * \brief construct from a lambda function with the same signature.
   *
   * Example usage:
   * \code
   * auto typed_lambda = [](int x)->int { return x + 1; };
   * // construct from the typed lambda, giving it a name
   * TypedPackedFunc<int(int)> ftyped(typed_lambda, "add_one");
   * // call the typed version.
   * ICHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param typed_lambda typed lambda function.
   * \param name the name of the lambda function.
   * \tparam FLambda the type of the lambda function.
   */
  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
                                  FLambda, std::function<R(Args...)>>::value>::type>
  TypedPackedFunc(const FLambda& typed_lambda, std::string name) {  // NOLINT(*)
    this->AssignTypedLambda(typed_lambda, name);
  }
  /*!
   * \brief construct from a lambda function with the same signature.
   *
   * This version does not take a name. It is highly recommended that you use
   * the version that takes a name for the lambda.
   *
   * Example usage:
   * \code
   * auto typed_lambda = [](int x)->int { return x + 1; };
   * // construct from the typed lambda
   * TypedPackedFunc<int(int)> ftyped(typed_lambda);
   * // call the typed version.
   * ICHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param typed_lambda typed lambda function.
   * \tparam FLambda the type of the lambda function.
   */
  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
                                  FLambda, std::function<R(Args...)>>::value>::type>
  TypedPackedFunc(const FLambda& typed_lambda) {  // NOLINT(*)
    this->AssignTypedLambda(typed_lambda);
  }
  /*!
   * \brief copy assignment operator from typed lambda
   *
   * Example usage:
   * \code
   * // start from a default-constructed typed function
   * TypedPackedFunc<int(int)> ftyped;
   * ftyped = [](int x) { return x + 1; };
   * // call the typed version.
   * ICHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param typed_lambda typed lambda function.
   * \tparam FLambda the type of the lambda function.
   * \returns reference to self.
   */
  template <typename FLambda, typename = typename std::enable_if<
                                  std::is_convertible<FLambda,
                                                      std::function<R(Args...)>>::value>::type>
  TSelf& operator=(FLambda typed_lambda) {  // NOLINT(*)
    this->AssignTypedLambda(typed_lambda);
    return *this;
  }
  /*!
   * \brief copy assignment operator from PackedFunc.
   * \param packed The packed function.
   * \returns reference to self.
   */
  TSelf& operator=(PackedFunc packed) {
    packed_ = packed;
    return *this;
  }
  /*!
   * \brief Invoke the operator.
   * \param args The arguments
   * \returns The return value.
   */
  TVM_ALWAYS_INLINE R operator()(Args... args) const;
  /*!
   * \brief convert to PackedFunc
   * \return the internal PackedFunc
   */
  operator PackedFunc() const { return packed(); }
  /*!
   * \return reference to the internal PackedFunc
   */
  const PackedFunc& packed() const { return packed_; }
  /*! \return Whether the packed function is nullptr */
  bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
  /*! \return Whether the packed function is not nullptr */
  bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }

 private:
  friend class TVMRetValue;
  /*! \brief The internal packed function */
  PackedFunc packed_;
  /*!
   * \brief Assign the packed field using a typed lambda function.
   *
   * \param flambda The lambda function.
   * \param name The name associated with this lambda.
   * \tparam FLambda The lambda function type.
   * \note We capture the lambda when possible for maximum efficiency.
   */
  template <typename FLambda>
  inline void AssignTypedLambda(FLambda flambda, std::string name);
  /*!
   * \brief Assign the packed field using a typed lambda function. This variant is for functions
   * without names.
   *
   * \param flambda The lambda function.
   * \tparam FLambda The lambda function type.
   * \note We capture the lambda when possible for maximum efficiency.
   */
  template <typename FLambda>
  inline void AssignTypedLambda(FLambda flambda);
};
389 | |
390 | /*! \brief Arguments into TVM functions. */ |
391 | class TVMArgs { |
392 | public: |
393 | const TVMValue* values; |
394 | const int* type_codes; |
395 | int num_args; |
396 | /*! |
397 | * \brief constructor |
398 | * \param values The argument values |
399 | * \param type_codes The argument type codes |
400 | * \param num_args number of arguments. |
401 | */ |
402 | TVMArgs(const TVMValue* values, const int* type_codes, int num_args) |
403 | : values(values), type_codes(type_codes), num_args(num_args) {} |
404 | /*! \return size of the arguments */ |
405 | inline int size() const; |
406 | /*! |
407 | * \brief Get i-th argument |
408 | * \param i the index. |
409 | * \return the ith argument. |
410 | */ |
411 | inline TVMArgValue operator[](int i) const; |
412 | }; |
413 | |
/*!
 * \brief Convert argument type code to string.
 * \param type_code The input type code.
 * \return The corresponding string repr.
 */
inline const char* ArgTypeCode2Str(int type_code);

// Macro to check a type code: expands to ICHECK_EQ(CODE, T) followed by a
// streamed message that names the expected (T) and the actual (CODE) type
// code via ArgTypeCode2Str. The expansion carries no trailing semicolon, so
// the caller supplies it.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
  ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
424 | |
425 | /*! |
426 | * \brief Type traits for runtime type check during FFI conversion. |
427 | * \tparam T the type to be checked. |
428 | */ |
429 | template <typename T> |
430 | struct ObjectTypeChecker { |
431 | /*! |
432 | * \brief Check if an object matches the template type and return the |
433 | * mismatched type if it exists. |
434 | * \param ptr The object to check the type of. |
435 | * \return An Optional containing the actual type of the pointer if it does not match the |
436 | * template type. If the Optional does not contain a value, then the types match. |
437 | */ |
438 | static Optional<String> CheckAndGetMismatch(const Object* ptr) { |
439 | using ContainerType = typename T::ContainerType; |
440 | if (ptr == nullptr) { |
441 | if (T::_type_is_nullable) { |
442 | return NullOpt; |
443 | } else { |
444 | return String("nullptr" ); |
445 | } |
446 | } |
447 | if (ptr->IsInstance<ContainerType>()) { |
448 | return NullOpt; |
449 | } else { |
450 | return String(ptr->GetTypeKey()); |
451 | } |
452 | } |
453 | /*! |
454 | * \brief Check if an object matches the template type. |
455 | * \param ptr The object to check the type of. |
456 | * \return Whether or not the template type matches the objects type. |
457 | */ |
458 | static bool Check(const Object* ptr) { |
459 | using ContainerType = typename T::ContainerType; |
460 | if (ptr == nullptr) return T::_type_is_nullable; |
461 | return ptr->IsInstance<ContainerType>(); |
462 | } |
463 | static std::string TypeName() { |
464 | using ContainerType = typename T::ContainerType; |
465 | return ContainerType::_type_key; |
466 | } |
467 | }; |
468 | |
469 | // Additional overloads for PackedFunc checking. |
470 | template <typename T> |
471 | struct ObjectTypeChecker<Array<T>> { |
472 | static Optional<String> CheckAndGetMismatch(const Object* ptr) { |
473 | if (ptr == nullptr) { |
474 | return NullOpt; |
475 | } |
476 | if (!ptr->IsInstance<ArrayNode>()) { |
477 | return String(ptr->GetTypeKey()); |
478 | } |
479 | const ArrayNode* n = static_cast<const ArrayNode*>(ptr); |
480 | for (size_t i = 0; i < n->size(); i++) { |
481 | const ObjectRef& p = (*n)[i]; |
482 | Optional<String> check_subtype = ObjectTypeChecker<T>::CheckAndGetMismatch(p.get()); |
483 | if (check_subtype.defined()) { |
484 | return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]" ); |
485 | } |
486 | } |
487 | return NullOpt; |
488 | } |
489 | static bool Check(const Object* ptr) { |
490 | if (ptr == nullptr) return true; |
491 | if (!ptr->IsInstance<ArrayNode>()) return false; |
492 | const ArrayNode* n = static_cast<const ArrayNode*>(ptr); |
493 | for (const ObjectRef& p : *n) { |
494 | if (!ObjectTypeChecker<T>::Check(p.get())) { |
495 | return false; |
496 | } |
497 | } |
498 | return true; |
499 | } |
500 | static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]" ; } |
501 | }; |
502 | template <typename K, typename V> |
503 | struct ObjectTypeChecker<Map<K, V>> { |
504 | static Optional<String> CheckAndGetMismatch(const Object* ptr) { |
505 | if (ptr == nullptr) return NullOpt; |
506 | if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey()); |
507 | const MapNode* n = static_cast<const MapNode*>(ptr); |
508 | for (const auto& kv : *n) { |
509 | Optional<String> key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get()); |
510 | Optional<String> value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get()); |
511 | if (key_type.defined() || value_type.defined()) { |
512 | std::string key_name = |
513 | key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName(); |
514 | std::string value_name = value_type.defined() ? std::string(value_type.value()) |
515 | : ObjectTypeChecker<V>::TypeName(); |
516 | return String("Map[" + key_name + ", " + value_name + "]" ); |
517 | } |
518 | } |
519 | return NullOpt; |
520 | } |
521 | static bool Check(const Object* ptr) { |
522 | if (ptr == nullptr) return true; |
523 | if (!ptr->IsInstance<MapNode>()) return false; |
524 | const MapNode* n = static_cast<const MapNode*>(ptr); |
525 | for (const auto& kv : *n) { |
526 | if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false; |
527 | if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false; |
528 | } |
529 | return true; |
530 | } |
531 | static std::string TypeName() { |
532 | return "Map[" + ObjectTypeChecker<K>::TypeName() + ", " + ObjectTypeChecker<V>::TypeName() + |
533 | ']'; |
534 | } |
535 | }; |
536 | |
537 | /*! |
538 | * \brief Internal base class to |
539 | * handle conversion to POD values. |
540 | */ |
541 | class TVMPODValue_ { |
542 | public: |
543 | operator double() const { |
544 | // Allow automatic conversion from int to float |
545 | // This avoids errors when user pass in int from |
546 | // the frontend while the API expects a float. |
547 | if (type_code_ == kDLInt) { |
548 | return static_cast<double>(value_.v_int64); |
549 | } |
550 | TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); |
551 | return value_.v_float64; |
552 | } |
553 | operator int64_t() const { |
554 | TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
555 | return value_.v_int64; |
556 | } |
557 | operator uint64_t() const { |
558 | TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
559 | return value_.v_int64; |
560 | } |
561 | operator int() const { |
562 | TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
563 | ICHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); |
564 | ICHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); |
565 | return static_cast<int>(value_.v_int64); |
566 | } |
567 | operator bool() const { |
568 | TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
569 | return value_.v_int64 != 0; |
570 | } |
571 | operator void*() const { |
572 | if (type_code_ == kTVMNullptr) return nullptr; |
573 | if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; |
574 | TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle); |
575 | return value_.v_handle; |
576 | } |
577 | operator DLTensor*() const { |
578 | if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { |
579 | return static_cast<DLTensor*>(value_.v_handle); |
580 | } else { |
581 | if (type_code_ == kTVMNullptr) return nullptr; |
582 | LOG(FATAL) << "Expected " |
583 | << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_); |
584 | return nullptr; |
585 | } |
586 | } |
587 | operator NDArray() const { |
588 | if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr)); |
589 | TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); |
590 | return NDArray(NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle))); |
591 | } |
592 | operator Module() const { |
593 | if (type_code_ == kTVMNullptr) { |
594 | return Module(ObjectPtr<Object>(nullptr)); |
595 | } |
596 | TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); |
597 | return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
598 | } |
599 | operator PackedFunc() const { |
600 | if (type_code_ == kTVMNullptr) { |
601 | return PackedFunc(ObjectPtr<Object>(nullptr)); |
602 | } |
603 | TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); |
604 | return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
605 | } |
606 | operator Device() const { |
607 | TVM_CHECK_TYPE_CODE(type_code_, kDLDevice); |
608 | return value_.v_device; |
609 | } |
610 | int type_code() const { return type_code_; } |
611 | /*! |
612 | * \brief return handle as specific pointer type. |
613 | * \tparam T the data type. |
614 | * \return The pointer type. |
615 | */ |
616 | template <typename T> |
617 | T* ptr() const { |
618 | return static_cast<T*>(value_.v_handle); |
619 | } |
620 | // ObjectRef handling |
621 | template <typename TObjectRef, |
622 | typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
623 | inline bool IsObjectRef() const; |
624 | template <typename TObjectRef> |
625 | inline TObjectRef AsObjectRef() const; |
626 | |
627 | protected: |
628 | friend class TVMArgsSetter; |
629 | friend class TVMRetValue; |
630 | friend class TVMMovableArgValue_; |
631 | TVMPODValue_() : type_code_(kTVMNullptr) {} |
632 | TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} |
633 | |
634 | /*! \brief The value */ |
635 | TVMValue value_; |
636 | /*! \brief the type code */ |
637 | int type_code_; |
638 | }; |
639 | |
640 | /*! |
641 | * \brief A single argument value to PackedFunc. |
642 | * Containing both type_code and TVMValue |
643 | * |
644 | * Provides utilities to do type cast into other types. |
645 | */ |
646 | class TVMArgValue : public TVMPODValue_ { |
647 | public: |
648 | /*! \brief default constructor */ |
649 | TVMArgValue() {} |
650 | /*! |
651 | * \brief constructor |
652 | * \param value of the function |
653 | * \param type_code The type code. |
654 | */ |
655 | TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} |
656 | // reuse converter from parent |
657 | using TVMPODValue_::operator double; |
658 | using TVMPODValue_::operator int64_t; |
659 | using TVMPODValue_::operator uint64_t; |
660 | using TVMPODValue_::operator int; |
661 | using TVMPODValue_::operator bool; |
662 | using TVMPODValue_::operator void*; |
663 | using TVMPODValue_::operator DLTensor*; |
664 | using TVMPODValue_::operator NDArray; |
665 | using TVMPODValue_::operator Device; |
666 | using TVMPODValue_::operator Module; |
667 | using TVMPODValue_::operator PackedFunc; |
668 | using TVMPODValue_::AsObjectRef; |
669 | using TVMPODValue_::IsObjectRef; |
670 | |
671 | // conversion operator. |
672 | operator std::string() const { |
673 | if (type_code_ == kTVMDataType) { |
674 | return DLDataType2String(operator DLDataType()); |
675 | } else if (type_code_ == kTVMBytes) { |
676 | TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle); |
677 | return std::string(arr->data, arr->size); |
678 | } else if (type_code_ == kTVMStr) { |
679 | return std::string(value_.v_str); |
680 | } else { |
681 | ICHECK(IsObjectRef<tvm::runtime::String>()) |
682 | << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_) |
683 | << " to a string." ; |
684 | return AsObjectRef<tvm::runtime::String>().operator std::string(); |
685 | } |
686 | } |
687 | template <typename FType> |
688 | operator TypedPackedFunc<FType>() const { |
689 | return TypedPackedFunc<FType>(operator PackedFunc()); |
690 | } |
691 | const TVMValue& value() const { return value_; } |
692 | |
693 | template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type> |
694 | inline operator T() const; |
695 | inline operator DLDataType() const; |
696 | inline operator DataType() const; |
697 | }; |
698 | |
/*!
 * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
 *
 *  We can only construct a movable argument once from a single argument position.
 *  If the argument is passed as RValue reference, the result will be moved.
 *  We should only construct a MovableArg from an argument once,
 *  as the result can be moved.
 *
 * \note For internal development purpose only.
 */
class TVMMovableArgValue_ : public TVMPODValue_ {
 public:
  /*!
   * \brief Construct from a raw value and its type code.
   * \param value The value.
   * \param type_code The type code of the value.
   */
  TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
  // reuse converter from parent
  using TVMPODValue_::operator double;
  using TVMPODValue_::operator int64_t;
  using TVMPODValue_::operator uint64_t;
  using TVMPODValue_::operator int;
  using TVMPODValue_::operator bool;
  using TVMPODValue_::operator void*;
  using TVMPODValue_::operator DLTensor*;
  using TVMPODValue_::operator NDArray;
  using TVMPODValue_::operator Device;
  using TVMPODValue_::operator Module;
  using TVMPODValue_::operator PackedFunc;
  // reuse conversion rule from ArgValue.
  operator std::string() const { return AsArgValue().operator std::string(); }
  // Typed functions are produced by first converting to PackedFunc.
  template <typename FType>
  operator TypedPackedFunc<FType>() const {
    return TypedPackedFunc<FType>(operator PackedFunc());
  }
  operator DLDataType() const { return AsArgValue().operator DLDataType(); }
  operator DataType() const { return AsArgValue().operator DataType(); }
  operator TVMArgValue() const { return AsArgValue(); }
  /*!
   * \brief Helper converter function.
   *  Try to move out an argument if possible,
   *  fall back to normal argument conversion rule otherwise.
   */
  template <typename T,
            typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
  inline operator T() const;

 private:
  /*! \return The arg value repr of the value. */
  TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
};
746 | |
747 | /*! |
748 | * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with |
749 | * additional context information (function name and argument index) for better error reporting. |
750 | * |
751 | * \sa MovableArgValue_ |
752 | * \note For internal development purpose only. |
753 | */ |
754 | class TVMMovableArgValueWithContext_ { |
755 | public: |
756 | /*! |
757 | * \brief move constructor from another return value. |
758 | * \param value The other return value. |
759 | * \param type_code The code associated with the type of the value. |
760 | * \param arg_index In a function call, this argument is at index arg_index (0-indexed). |
761 | * \param optional_name Name of the function being called. Can be nullptr if the function is not. |
762 | * \param f_sig Pointer to static function outputting signature of the function being called. |
763 | * named. |
764 | */ |
765 | TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, |
766 | const std::string* optional_name, FSig* f_sig) |
767 | : value_(value, type_code), |
768 | arg_index_(arg_index), |
769 | optional_name_(optional_name), |
770 | f_sig_(f_sig) {} |
771 | |
772 | template <typename T> |
773 | operator T() const { |
774 | try { |
775 | return value_; // implicit conversion happens here |
776 | } catch (dmlc::Error& e) { |
777 | LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_) |
778 | << (f_sig_ == nullptr ? "" : (*f_sig_)()) << ": error while converting argument " |
779 | << arg_index_ << ": " << e.what(); |
780 | throw; // never reached, LOG(FATAL) throws, but this silences a warning. |
781 | } |
782 | } |
783 | |
784 | private: |
785 | TVMMovableArgValue_ value_; |
786 | int arg_index_; |
787 | const std::string* optional_name_; |
788 | FSig* f_sig_; |
789 | }; |
790 | |
791 | /*! |
792 | * \brief Return Value container, |
793 | * Unlike TVMArgValue, which only holds reference and do not delete |
794 | * the underlying container during destruction. |
795 | * |
796 | * TVMRetValue holds value and will manage the underlying containers |
797 | * when it stores a complicated data type. |
798 | */ |
799 | class TVMRetValue : public TVMPODValue_ { |
800 | public: |
801 | /*! \brief default constructor */ |
802 | TVMRetValue() {} |
803 | /*! |
804 | * \brief move constructor from another return value. |
805 | * \param other The other return value. |
806 | */ |
807 | TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { |
808 | other.value_.v_handle = nullptr; |
809 | other.type_code_ = kTVMNullptr; |
810 | } |
811 | /*! \brief destructor */ |
812 | ~TVMRetValue() { this->Clear(); } |
813 | // reuse converter from parent |
814 | using TVMPODValue_::operator double; |
815 | using TVMPODValue_::operator int64_t; |
816 | using TVMPODValue_::operator uint64_t; |
817 | using TVMPODValue_::operator int; |
818 | using TVMPODValue_::operator bool; |
819 | using TVMPODValue_::operator void*; |
820 | using TVMPODValue_::operator DLTensor*; |
821 | using TVMPODValue_::operator Device; |
822 | using TVMPODValue_::operator NDArray; |
823 | using TVMPODValue_::operator Module; |
824 | using TVMPODValue_::operator PackedFunc; |
825 | using TVMPODValue_::AsObjectRef; |
826 | using TVMPODValue_::IsObjectRef; |
827 | |
828 | TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } |
829 | // conversion operators |
830 | operator std::string() const { |
831 | if (type_code_ == kTVMDataType) { |
832 | return DLDataType2String(operator DLDataType()); |
833 | } else if (type_code_ == kTVMBytes) { |
834 | return *ptr<std::string>(); |
835 | } |
836 | TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); |
837 | return *ptr<std::string>(); |
838 | } |
839 | operator DLDataType() const { |
840 | if (type_code_ == kTVMStr) { |
841 | return String2DLDataType(operator std::string()); |
842 | } |
843 | TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); |
844 | return value_.v_type; |
845 | } |
846 | operator DataType() const { return DataType(operator DLDataType()); } |
847 | template <typename FType> |
848 | operator TypedPackedFunc<FType>() const { |
849 | return TypedPackedFunc<FType>(operator PackedFunc()); |
850 | } |
851 | // Assign operators |
852 | TVMRetValue& operator=(TVMRetValue&& other) { |
853 | this->Clear(); |
854 | value_ = other.value_; |
855 | type_code_ = other.type_code_; |
856 | other.type_code_ = kTVMNullptr; |
857 | return *this; |
858 | } |
859 | TVMRetValue& operator=(double value) { |
860 | this->SwitchToPOD(kDLFloat); |
861 | value_.v_float64 = value; |
862 | return *this; |
863 | } |
864 | TVMRetValue& operator=(std::nullptr_t value) { |
865 | this->SwitchToPOD(kTVMNullptr); |
866 | value_.v_handle = value; |
867 | return *this; |
868 | } |
869 | TVMRetValue& operator=(void* value) { |
870 | this->SwitchToPOD(kTVMOpaqueHandle); |
871 | value_.v_handle = value; |
872 | return *this; |
873 | } |
874 | TVMRetValue& operator=(int64_t value) { |
875 | this->SwitchToPOD(kDLInt); |
876 | value_.v_int64 = value; |
877 | return *this; |
878 | } |
879 | TVMRetValue& operator=(int value) { |
880 | this->SwitchToPOD(kDLInt); |
881 | value_.v_int64 = value; |
882 | return *this; |
883 | } |
884 | TVMRetValue& operator=(DLDevice value) { |
885 | this->SwitchToPOD(kDLDevice); |
886 | value_.v_device = value; |
887 | return *this; |
888 | } |
889 | TVMRetValue& operator=(DLDataType t) { |
890 | this->SwitchToPOD(kTVMDataType); |
891 | value_.v_type = t; |
892 | return *this; |
893 | } |
894 | TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } |
895 | TVMRetValue& operator=(bool value) { |
896 | this->SwitchToPOD(kDLInt); |
897 | value_.v_int64 = value; |
898 | return *this; |
899 | } |
900 | TVMRetValue& operator=(std::string value) { |
901 | this->SwitchToClass(kTVMStr, value); |
902 | return *this; |
903 | } |
904 | TVMRetValue& operator=(TVMByteArray value) { |
905 | this->SwitchToClass(kTVMBytes, std::string(value.data, value.size)); |
906 | return *this; |
907 | } |
908 | TVMRetValue& operator=(NDArray other) { |
909 | if (other.data_ != nullptr) { |
910 | this->Clear(); |
911 | type_code_ = kTVMNDArrayHandle; |
912 | value_.v_handle = NDArray::FFIGetHandle(other); |
913 | ObjectRef::FFIClearAfterMove(&other); |
914 | } else { |
915 | SwitchToPOD(kTVMNullptr); |
916 | value_.v_handle = nullptr; |
917 | } |
918 | return *this; |
919 | } |
920 | TVMRetValue& operator=(Module m) { |
921 | SwitchToObject(kTVMModuleHandle, std::move(m.data_)); |
922 | return *this; |
923 | } |
924 | TVMRetValue& operator=(PackedFunc f) { |
925 | this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_)); |
926 | return *this; |
927 | } |
928 | template <typename FType> |
929 | TVMRetValue& operator=(const TypedPackedFunc<FType>& f) { |
930 | return operator=(f.packed()); |
931 | } |
932 | TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 |
933 | this->Assign(other); |
934 | return *this; |
935 | } |
936 | TVMRetValue& operator=(const TVMArgValue& other) { |
937 | this->Assign(other); |
938 | return *this; |
939 | } |
940 | TVMRetValue& operator=(TVMMovableArgValue_&& other) { |
941 | this->Assign(other); |
942 | return *this; |
943 | } |
944 | /*! |
945 | * \brief Move the value back to front-end via C API. |
946 | * This marks the current container as null. |
947 | * The managed resources are moved to the front-end. |
948 | * The front end should take charge in managing them. |
949 | * |
950 | * \param ret_value The return value. |
951 | * \param ret_type_code The return type code. |
952 | */ |
953 | void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { |
954 | // cannot move str; need specially handle. |
955 | ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); |
956 | *ret_value = value_; |
957 | *ret_type_code = type_code_; |
958 | type_code_ = kTVMNullptr; |
959 | } |
960 | /*! |
961 | * \brief Construct a new TVMRetValue by |
962 | * moving from return value stored via C API. |
963 | * \param value the value. |
964 | * \param type_code The type code. |
965 | * \return The created TVMRetValue. |
966 | */ |
967 | static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { |
968 | // Can move POD and everything under the object system. |
969 | ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); |
970 | TVMRetValue ret; |
971 | ret.value_ = value; |
972 | ret.type_code_ = type_code; |
973 | return ret; |
974 | } |
975 | /*! \return The value field, if the data is POD */ |
976 | const TVMValue& value() const { |
977 | ICHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && |
978 | type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) |
979 | << "TVMRetValue.value can only be used for POD data" ; |
980 | return value_; |
981 | } |
982 | // ObjectRef handling |
983 | template <typename TObjectRef, |
984 | typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
985 | inline TVMRetValue& operator=(TObjectRef other); |
986 | template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type> |
987 | inline operator T() const; |
988 | |
989 | private: |
990 | template <typename T> |
991 | void Assign(const T& other) { |
992 | switch (other.type_code()) { |
993 | case kTVMStr: { |
994 | SwitchToClass<std::string>(kTVMStr, other); |
995 | break; |
996 | } |
997 | case kTVMBytes: { |
998 | SwitchToClass<std::string>(kTVMBytes, other); |
999 | break; |
1000 | } |
1001 | case kTVMPackedFuncHandle: { |
1002 | *this = other.operator PackedFunc(); |
1003 | break; |
1004 | } |
1005 | case kTVMModuleHandle: { |
1006 | *this = other.operator Module(); |
1007 | break; |
1008 | } |
1009 | case kTVMNDArrayHandle: { |
1010 | *this = other.operator NDArray(); |
1011 | break; |
1012 | } |
1013 | case kTVMObjectHandle: { |
1014 | // Avoid operator ObjectRef as we already know it is not NDArray/Module |
1015 | SwitchToObject(kTVMObjectHandle, |
1016 | GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle))); |
1017 | break; |
1018 | } |
1019 | case kTVMObjectRValueRefArg: { |
1020 | operator=(other.operator ObjectRef()); |
1021 | break; |
1022 | } |
1023 | default: { |
1024 | SwitchToPOD(other.type_code()); |
1025 | value_ = other.value_; |
1026 | break; |
1027 | } |
1028 | } |
1029 | } |
1030 | // get the internal container. |
1031 | void SwitchToPOD(int type_code) { |
1032 | if (type_code_ != type_code) { |
1033 | this->Clear(); |
1034 | type_code_ = type_code; |
1035 | } |
1036 | } |
1037 | template <typename T> |
1038 | void SwitchToClass(int type_code, T v) { |
1039 | if (type_code_ != type_code) { |
1040 | this->Clear(); |
1041 | type_code_ = type_code; |
1042 | value_.v_handle = new T(v); |
1043 | } else { |
1044 | *static_cast<T*>(value_.v_handle) = v; |
1045 | } |
1046 | } |
1047 | void SwitchToObject(int type_code, ObjectPtr<Object> other) { |
1048 | if (other.data_ != nullptr) { |
1049 | this->Clear(); |
1050 | type_code_ = type_code; |
1051 | // move the handle out |
1052 | value_.v_handle = other.data_; |
1053 | other.data_ = nullptr; |
1054 | } else { |
1055 | SwitchToPOD(kTVMNullptr); |
1056 | value_.v_handle = nullptr; |
1057 | } |
1058 | } |
1059 | void Clear() { |
1060 | if (type_code_ == kTVMNullptr) return; |
1061 | switch (type_code_) { |
1062 | case kTVMStr: |
1063 | case kTVMBytes: |
1064 | delete ptr<std::string>(); |
1065 | break; |
1066 | case kTVMPackedFuncHandle: |
1067 | static_cast<Object*>(value_.v_handle)->DecRef(); |
1068 | break; |
1069 | case kTVMNDArrayHandle: { |
1070 | NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle)); |
1071 | break; |
1072 | } |
1073 | case kTVMModuleHandle: { |
1074 | static_cast<Object*>(value_.v_handle)->DecRef(); |
1075 | break; |
1076 | } |
1077 | case kTVMObjectHandle: { |
1078 | static_cast<Object*>(value_.v_handle)->DecRef(); |
1079 | break; |
1080 | } |
1081 | } |
1082 | type_code_ = kTVMNullptr; |
1083 | } |
1084 | }; |
1085 | |
1086 | /*! |
1087 | * \brief Type trait to specify special value conversion rules from |
1088 | * TVMArgValue and TVMRetValue. |
1089 | * |
1090 | * The trait can be specialized to add type specific conversion logic |
1091 | * from the TVMArgvalue and TVMRetValue. |
1092 | * |
1093 | * \tparam TObjectRef the specific ObjectRefType. |
1094 | */ |
1095 | template <typename TObjectRef> |
1096 | struct PackedFuncValueConverter { |
1097 | /*! |
1098 | * \brief Convert a TObjectRef from an argument value. |
1099 | * \param val The argument value. |
1100 | * \return the converted result. |
1101 | */ |
1102 | static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); } |
1103 | /*! |
1104 | * \brief Convert a TObjectRef from a return value. |
1105 | * \param val The argument value. |
1106 | * \return the converted result. |
1107 | */ |
1108 | static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); } |
1109 | }; |
1110 | |
/*!
 * \brief Export a function with the PackedFunc signature
 *        as a PackedFunc that can be loaded by LibraryModule.
 *
 * The macro declares and defines an extern "C" function named ExportName.
 * It wraps the raw arguments in TVMArgs, calls Function with a fresh
 * TVMRetValue, hands the result to the caller through MoveToCHost and
 * returns 0. If a std::exception escapes, its message is stored with
 * TVMAPISetLastError and -1 is returned.
 *
 * \param ExportName The symbol name to be exported.
 * \param Function The function with PackedFunc signature.
 * \note MoveToCHost rejects string and bytes results, so Function must not
 *       leave a value of those kinds in the return slot.
 * \sa PackedFunc
 *
 * \code
 *
 * void AddOne_(TVMArgs args, TVMRetValue* rv) {
 *   int value = args[0];
 *   *rv = value + 1;
 * }
 * // Expose the function as "AddOne"
 * TVM_DLL_EXPORT_PACKED_FUNC(AddOne, AddOne_);
 *
 * \endcode
 */
#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function)                                    \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle);                        \
  int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value,         \
                 int* out_type_code, void* resource_handle) {                               \
    try {                                                                                   \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                    \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1147 | |
/*!
 * \brief Export typed function as a PackedFunc
 *        that can be loaded by LibraryModule.
 *
 * The macro defines an extern "C" function named ExportName. It copies
 * Function into a local, derives the call signature from its type with
 * detail::function_signature, unpacks the raw arguments through
 * detail::unpack_call_by_signature, hands the result to the caller through
 * MoveToCHost and returns 0. If a std::exception escapes, its message is
 * stored with TVMAPISetLastError and -1 is returned.
 *
 * \param ExportName The symbol name to be exported.
 * \param Function The typed function.
 * \note ExportName and Function must be different,
 *       see code examples below.
 *
 * \sa TypedPackedFunc
 *
 * \code
 *
 * int AddOne_(int x) {
 *   return x + 1;
 * }
 *
 * // Expose the function as "AddOne"
 * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);
 *
 * // Expose the function as "SubOne"
 * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) {
 *   return x - 1;
 * });
 *
 * // The following code will cause compilation error.
 * // Because the same Function and ExportName
 * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_);
 *
 * // The following code is OK, assuming the macro
 * // is in a different namespace from xyz
 * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_);
 *
 * \endcode
 */
#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                                     \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle) {                       \
    try {                                                                                   \
      auto f = Function;                                                                    \
      using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType;         \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      ::tvm::runtime::detail::unpack_call_by_signature<FType>::run(                         \
          f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                      \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1201 | |
1202 | inline TVMArgValue TVMArgs::operator[](int i) const { |
1203 | ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" |
1204 | << " but request arg[" << i << "]." ; |
1205 | return TVMArgValue(values[i], type_codes[i]); |
1206 | } |
1207 | |
1208 | inline int TVMArgs::size() const { return num_args; } |
1209 | |
1210 | template <class TPackedFuncSubObj> |
1211 | void PackedFuncObj::Extractor<TPackedFuncSubObj>::(const PackedFuncObj* obj, TVMArgs args, |
1212 | TVMRetValue* rv) { |
1213 | (static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv); |
1214 | } |
1215 | |
1216 | TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { |
1217 | (*f_call_packed_)(this, args, rv); |
1218 | } |
1219 | |
1220 | TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { |
1221 | (static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv); |
1222 | } |
1223 | |
1224 | // internal namespace |
1225 | inline const char* ArgTypeCode2Str(int type_code) { |
1226 | switch (type_code) { |
1227 | case kDLInt: |
1228 | return "int" ; |
1229 | case kDLUInt: |
1230 | return "uint" ; |
1231 | case kDLFloat: |
1232 | return "float" ; |
1233 | case kTVMStr: |
1234 | return "str" ; |
1235 | case kTVMBytes: |
1236 | return "bytes" ; |
1237 | case kTVMOpaqueHandle: |
1238 | return "handle" ; |
1239 | case kTVMNullptr: |
1240 | return "NULL" ; |
1241 | case kTVMDLTensorHandle: |
1242 | return "ArrayHandle" ; |
1243 | case kTVMDataType: |
1244 | return "DLDataType" ; |
1245 | case kDLDevice: |
1246 | return "DLDevice" ; |
1247 | case kTVMPackedFuncHandle: |
1248 | return "FunctionHandle" ; |
1249 | case kTVMModuleHandle: |
1250 | return "ModuleHandle" ; |
1251 | case kTVMNDArrayHandle: |
1252 | return "NDArrayContainer" ; |
1253 | case kTVMObjectHandle: |
1254 | return "Object" ; |
1255 | case kTVMObjectRValueRefArg: |
1256 | return "ObjectRValueRefArg" ; |
1257 | default: |
1258 | LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); |
1259 | } |
1260 | } |
1261 | |
1262 | namespace detail { |
1263 | |
/*!
 * \brief Recursive step of for_each: apply the functor to the first value
 *        together with its position, then continue with the rest.
 * \tparam done Whether the argument list is exhausted.
 * \tparam Index Position of the value handled by this step.
 * \tparam F The functor type.
 */
template <bool done, std::size_t Index, typename F>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void run(const F& f, Head&& head, Tail&&... tail) {  // NOLINT(*)
    f(Index, std::forward<Head>(head));
    using Next = for_each_dispatcher<sizeof...(Tail) == 0, (Index + 1), F>;
    Next::run(f, std::forward<Tail>(tail)...);
  }
};

/*! \brief Terminal step of for_each: nothing is left to visit. */
template <std::size_t Index, typename F>
struct for_each_dispatcher<true, Index, F> {
  static void run(const F& f) {}  // NOLINT(*)
};

/*!
 * \brief Call f(i, value) for every argument, where i is the zero-based
 *        position of the argument, in left-to-right order.
 * \param f The functor to apply.
 * \param args The values to visit.
 */
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  using First = for_each_dispatcher<sizeof...(Args) == 0, 0, F>;
  First::run(f, std::forward<Args>(args)...);
}
1282 | |
namespace parameter_pack {

/*!
 * \brief A list of (index, type) items produced by Enumerate.
 * \tparam EnumArgs The items; each exposes a constant `i` and a type `T`.
 */
template <typename... EnumArgs>
struct EnumeratedParamPack {
  struct Invoke {
    /*!
     * \brief Call Functor<i, T>::F(extra_params...) once per item, in order.
     * \tparam Functor Class template that provides a static member F.
     * \param extra_params Arguments passed to every call.
     */
    template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
    static void F(ExtraParams&&... extra_params) {
      // Braced array initialization evaluates its elements left to right,
      // which gives the calls a defined order; the leading 0 keeps the
      // array non-empty when there are no items.
      using TExpander = int[];
      (void)TExpander{
          0,
          (Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
      };
    }
  };
};

/*!
 * \brief Pair every type in Args with its zero-based position.
 * \tparam Args The types to enumerate.
 */
template <typename... Args>
struct EnumerateImpl {
 private:
  template <size_t _i, typename _T>
  struct Item {
    static const constexpr size_t i = _i;
    using T = _T;
  };

  template <typename...>
  struct Zipper;

  template <std::size_t... id>
  struct Zipper<std::integer_sequence<std::size_t, id...>> {
    using T = EnumeratedParamPack<Item<id, Args>...>;
  };

 public:
  using T = typename Zipper<std::index_sequence_for<Args...>>::T;
};

template <typename... Args>
using Enumerate = typename EnumerateImpl<Args...>::T;

/*!
 * \brief A list of parameter types that can be visited together with their index.
 * \tparam Args The parameter types.
 */
template <typename... Args>
struct ParamPack {
  /*!
   * \brief Call Functor<i, Args[i]>::F(extra_params...) for every parameter type.
   * \tparam Functor Class template that provides a static member F.
   * \param extra_params Arguments passed to every call.
   */
  template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
  static void InvokeWithoutArg(ExtraParams&&... extra_params) {
    Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
        std::forward<ExtraParams>(extra_params)...);
  }
};

}  // namespace parameter_pack
1333 | |
1334 | /*! |
1335 | * \brief Template class to get function signature of a function or functor. |
1336 | * \tparam T The function/functor type. |
1337 | */ |
1338 | template <typename T> |
1339 | struct func_signature_helper { |
1340 | using FType = void; |
1341 | }; |
1342 | |
1343 | template <typename T, typename R, typename... Args> |
1344 | struct func_signature_helper<R (T::*)(Args...)> { |
1345 | using FType = R(Args...); |
1346 | using ParamType = parameter_pack::ParamPack<Args...>; |
1347 | using RetType = R; |
1348 | static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference" ); |
1349 | }; |
1350 | |
1351 | template <typename T, typename R, typename... Args> |
1352 | struct func_signature_helper<R (T::*)(Args...) const> { |
1353 | using FType = R(Args...); |
1354 | using ParamType = parameter_pack::ParamPack<Args...>; |
1355 | using RetType = R; |
1356 | static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference" ); |
1357 | }; |
1358 | |
1359 | /*! |
1360 | * \brief Template class to get function signature of a function or functor. |
1361 | * \tparam T The function/functor type. |
1362 | */ |
1363 | template <typename T> |
1364 | struct function_signature { |
1365 | using FType = typename func_signature_helper<decltype(&T::operator())>::FType; |
1366 | using ParamType = typename func_signature_helper<decltype(&T::operator())>::ParamType; |
1367 | using RetType = typename func_signature_helper<decltype(&T::operator())>::RetType; |
1368 | }; |
1369 | |
1370 | // handle case of function. |
1371 | template <typename R, typename... Args> |
1372 | struct function_signature<R(Args...)> { |
1373 | using FType = R(Args...); |
1374 | using ParamType = parameter_pack::ParamPack<Args...>; |
1375 | using RetType = R; |
1376 | static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference" ); |
1377 | }; |
1378 | |
1379 | // handle case of function ptr. |
1380 | template <typename R, typename... Args> |
1381 | struct function_signature<R (*)(Args...)> { |
1382 | using FType = R(Args...); |
1383 | using ParamType = detail::parameter_pack::ParamPack<Args...>; |
1384 | using RetType = R; |
1385 | static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference" ); |
1386 | }; |
1387 | |
// Forward declaration: type2str::Type2Str<TypedPackedFunc<FType>> below calls
// SignaturePrinter<...>::F() before the template is defined.
template <typename TSignature>
struct SignaturePrinter;
1390 | |
1391 | namespace type2str { |
1392 | |
1393 | template <typename T> |
1394 | struct TypeSimplifier; |
1395 | |
1396 | template <typename T> |
1397 | struct Type2Str { |
1398 | template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>> |
1399 | static std::string v() { |
1400 | return T::ContainerType::_type_key; |
1401 | } |
1402 | }; |
1403 | template <> |
1404 | struct Type2Str<int> { |
1405 | static std::string v() { return "int" ; } |
1406 | }; |
1407 | template <> |
1408 | struct Type2Str<double> { |
1409 | static std::string v() { return "double" ; } |
1410 | }; |
1411 | template <> |
1412 | struct Type2Str<int64_t> { |
1413 | static std::string v() { return "int64_t" ; } |
1414 | }; |
1415 | template <> |
1416 | struct Type2Str<uint64_t> { |
1417 | static std::string v() { return "uint64_t" ; } |
1418 | }; |
1419 | template <> |
1420 | struct Type2Str<bool> { |
1421 | static std::string v() { return "bool" ; } |
1422 | }; |
1423 | template <> |
1424 | struct Type2Str<void> { |
1425 | static std::string v() { return "void" ; } |
1426 | }; |
1427 | template <> |
1428 | struct Type2Str<std::basic_string<char>> { |
1429 | static std::string v() { return "basic_string<char>" ; } |
1430 | }; |
1431 | template <typename K, typename V> |
1432 | struct Type2Str<Map<K, V>> { |
1433 | static std::string v() { |
1434 | return "Map<" + TypeSimplifier<K>::v() + ", " + TypeSimplifier<V>::v() + ">" ; |
1435 | } |
1436 | }; |
1437 | template <> |
1438 | struct Type2Str<DLDevice> { |
1439 | static std::string v() { return "DLDevice" ; } |
1440 | }; |
1441 | template <> |
1442 | struct Type2Str<DLTensor> { |
1443 | static std::string v() { return "DLTensor" ; } |
1444 | }; |
1445 | template <> |
1446 | struct Type2Str<DataType> { |
1447 | static std::string v() { return "DataType" ; } |
1448 | }; |
1449 | template <> |
1450 | struct Type2Str<DLDataType> { |
1451 | static std::string v() { return "DLDataType" ; } |
1452 | }; |
1453 | template <> |
1454 | struct Type2Str<TVMRetValue> { |
1455 | static std::string v() { return "TVMRetValue" ; } |
1456 | }; |
1457 | template <> |
1458 | struct Type2Str<TVMArgValue> { |
1459 | static std::string v() { return "TVMArgValue" ; } |
1460 | }; |
1461 | template <typename FType> |
1462 | struct Type2Str<TypedPackedFunc<FType>> { |
1463 | static std::string v() { return SignaturePrinter<function_signature<FType>>::F(); } |
1464 | }; |
1465 | template <typename T> |
1466 | struct Type2Str<Array<T>> { |
1467 | static std::string v() { return "Array<" + TypeSimplifier<T>::v() + ">" ; } |
1468 | }; |
1469 | |
1470 | /*! |
1471 | * \brief Template class to remove const, pointer and reference of original type. |
1472 | * \tparam T The original type. |
1473 | */ |
1474 | template <typename T> |
1475 | struct TypeSimplifier { |
1476 | static std::string v() { |
1477 | using U = typename std::remove_cv< |
1478 | typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type; |
1479 | return (std::is_const<T>::value ? "const " : "" ) + Type2Str<U>::v() + |
1480 | (std::is_pointer<T>::value ? "*" : "" ) + (std::is_reference<T>::value ? "&" : "" ); |
1481 | } |
1482 | }; |
1483 | |
1484 | } // namespace type2str |
1485 | |
1486 | /*! |
1487 | * \brief Template class to generate static function outputting signature of a function or functor. |
1488 | * \tparam TSignature The function/functor signature type generated by `function_signature`. |
1489 | */ |
1490 | template <typename TSignature> |
1491 | struct SignaturePrinter { |
1492 | using ParamType = typename TSignature::ParamType; |
1493 | using RetType = typename TSignature::RetType; |
1494 | |
1495 | template <size_t i, typename TArgument> |
1496 | struct PrintParamType { |
1497 | static void F(std::ostream& os) { |
1498 | os << (i == 0 ? "" : ", " ) << i << ": " << type2str::TypeSimplifier<TArgument>::v(); |
1499 | } |
1500 | }; |
1501 | |
1502 | static std::string F() { |
1503 | std::ostringstream oss; |
1504 | oss << "(" ; |
1505 | ParamType::template InvokeWithoutArg<PrintParamType>(oss); |
1506 | oss << ") -> " << type2str::TypeSimplifier<RetType>::v(); |
1507 | return oss.str(); |
1508 | } |
1509 | }; |
1510 | } // namespace detail |
1511 | |
1512 | /* \brief argument settter to PackedFunc */ |
1513 | class TVMArgsSetter { |
1514 | public: |
1515 | TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} |
1516 | // setters for POD types |
1517 | template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type> |
1518 | TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { |
1519 | values_[i].v_int64 = static_cast<int64_t>(value); |
1520 | type_codes_[i] = kDLInt; |
1521 | } |
1522 | TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { |
1523 | values_[i].v_int64 = static_cast<int64_t>(value); |
1524 | ICHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); |
1525 | type_codes_[i] = kDLInt; |
1526 | } |
1527 | TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { |
1528 | values_[i].v_float64 = value; |
1529 | type_codes_[i] = kDLFloat; |
1530 | } |
1531 | TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { |
1532 | values_[i].v_handle = value; |
1533 | type_codes_[i] = kTVMNullptr; |
1534 | } |
1535 | TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { |
1536 | values_[i] = value.value_; |
1537 | type_codes_[i] = value.type_code_; |
1538 | } |
1539 | TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { |
1540 | values_[i].v_handle = value; |
1541 | type_codes_[i] = kTVMOpaqueHandle; |
1542 | } |
1543 | TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { |
1544 | values_[i].v_handle = value; |
1545 | type_codes_[i] = kTVMDLTensorHandle; |
1546 | } |
1547 | TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const { |
1548 | values_[i].v_device = value; |
1549 | type_codes_[i] = kDLDevice; |
1550 | } |
1551 | TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { |
1552 | values_[i].v_type = value; |
1553 | type_codes_[i] = kTVMDataType; |
1554 | } |
1555 | TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { |
1556 | operator()(i, dtype.operator DLDataType()); |
1557 | } |
1558 | TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { |
1559 | values_[i].v_str = value; |
1560 | type_codes_[i] = kTVMStr; |
1561 | } |
1562 | // setters for container types |
1563 | TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { |
1564 | values_[i].v_str = value.c_str(); |
1565 | type_codes_[i] = kTVMStr; |
1566 | } |
1567 | TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { |
1568 | values_[i].v_handle = const_cast<TVMByteArray*>(&value); |
1569 | type_codes_[i] = kTVMBytes; |
1570 | } |
1571 | template <typename FType> |
1572 | TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const { |
1573 | operator()(i, value.packed()); |
1574 | } |
1575 | void operator()(size_t i, const TVMRetValue& value) const { |
1576 | if (value.type_code() == kTVMStr) { |
1577 | values_[i].v_str = value.ptr<std::string>()->c_str(); |
1578 | type_codes_[i] = kTVMStr; |
1579 | } else { |
1580 | ICHECK_NE(value.type_code(), kTVMBytes) << "not handled." ; |
1581 | values_[i] = value.value_; |
1582 | type_codes_[i] = value.type_code(); |
1583 | } |
1584 | } |
1585 | // ObjectRef handling |
1586 | template <typename TObjectRef, |
1587 | typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
1588 | TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { |
1589 | this->SetObject(i, value); |
1590 | } |
1591 | |
1592 | template <typename TObjectRef, |
1593 | typename = typename std::enable_if<std::is_base_of< |
1594 | ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type> |
1595 | TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { |
1596 | this->SetObject(i, std::forward<TObjectRef>(value)); |
1597 | } |
1598 | |
1599 | private: |
1600 | template <typename TObjectRef> |
1601 | inline void SetObject(size_t i, TObjectRef&& value) const; |
1602 | /*! \brief The values fields */ |
1603 | TVMValue* values_; |
1604 | /*! \brief The type code fields */ |
1605 | int* type_codes_; |
1606 | }; |
1607 | |
1608 | template <typename... Args> |
1609 | inline TVMRetValue PackedFunc::operator()(Args&&... args) const { |
1610 | const int kNumArgs = sizeof...(Args); |
1611 | const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; |
1612 | TVMValue values[kArraySize]; |
1613 | int type_codes[kArraySize]; |
1614 | detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...); |
1615 | TVMRetValue rv; |
1616 | (static_cast<PackedFuncObj*>(data_.get())) |
1617 | ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv); |
1618 | return rv; |
1619 | } |
1620 | |
1621 | namespace detail { |
/*!
 * \brief Recursive step that turns the packed arguments into individual
 *        call arguments, one per instantiation.
 * \tparam R The return type expected from the call.
 * \tparam nleft Number of arguments still to unpack.
 * \tparam index Position of the argument unpacked by this step.
 * \tparam F The callable type.
 */
template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
  template <typename... Args>
  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
                                    const TVMArgs& args_pack, TVMRetValue* rv,
                                    Args&&... unpacked_args) {
    // construct a movable argument value
    // which allows potential move of argument to the input of F.
    // It is appended after the arguments unpacked so far, and also carries
    // the name, signature and index used to report conversion errors.
    unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
        optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
        TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
                                       optional_name, f_sig));
  }
};

/*!
 * \brief Final step for a value-returning call: all arguments are unpacked,
 *        so invoke f and store its result in rv.
 */
template <typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
  template <typename... Args>
  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
                                    const TVMArgs& args_pack, TVMRetValue* rv,
                                    Args&&... unpacked_args) {
    using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
    // Only one of the two branches runs, so the arguments are forwarded once.
    if (std::is_same<RetType, R>::value) {
      *rv = f(std::forward<Args>(unpacked_args)...);
    } else {
      // f returns something other than R: convert to R before storing.
      *rv = R(f(std::forward<Args>(unpacked_args)...));
    }
  }
};

/*!
 * \brief Final step for a call returning void: invoke f and leave rv untouched.
 */
template <int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
  template <typename... Args>
  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
                                    const TVMArgs& args_pack, TVMRetValue* rv,
                                    Args&&... unpacked_args) {
    f(std::forward<Args>(unpacked_args)...);
  }
};
1661 | |
1662 | template <typename R, int nargs, typename F> |
1663 | TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f, |
1664 | const TVMArgs& args, TVMRetValue* rv) { |
1665 | FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F; |
1666 | CHECK_EQ(nargs, args.size()) << "Function " |
1667 | << (optional_name == nullptr ? "<anonymous>" : *optional_name) |
1668 | << (f_sig == nullptr ? "" : (*f_sig)()) << " expects " << nargs |
1669 | << " arguments but " << args.size() << " were provided" ; |
1670 | unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv); |
1671 | } |
1672 | |
1673 | template <typename FType> |
1674 | struct unpack_call_by_signature {}; |
1675 | |
1676 | template <typename R, typename... Args> |
1677 | struct unpack_call_by_signature<R(Args...)> { |
1678 | template <typename F> |
1679 | TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { |
1680 | unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv); |
1681 | } |
1682 | }; |
1683 | |
1684 | template <typename R, typename... Args> |
1685 | TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) { |
1686 | return R(pf(std::forward<Args>(args)...)); |
1687 | } |
1688 | |
1689 | template <typename R> |
1690 | struct typed_packed_call_dispatcher { |
1691 | template <typename... Args> |
1692 | TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) { |
1693 | return pf(std::forward<Args>(args)...); |
1694 | } |
1695 | }; |
1696 | |
1697 | template <> |
1698 | struct typed_packed_call_dispatcher<void> { |
1699 | template <typename... Args> |
1700 | TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) { |
1701 | pf(std::forward<Args>(args)...); |
1702 | } |
1703 | }; |
1704 | } // namespace detail |
1705 | |
1706 | template <typename R, typename... Args> |
1707 | TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {} |
1708 | |
1709 | template <typename R, typename... Args> |
1710 | TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value) |
1711 | : packed_(value.operator PackedFunc()) {} |
1712 | |
1713 | template <typename R, typename... Args> |
1714 | TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value) |
1715 | : packed_(value.operator PackedFunc()) {} |
1716 | |
1717 | template <typename R, typename... Args> |
1718 | TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValueWithContext_&& value) |
1719 | : packed_(value.operator PackedFunc()) {} |
1720 | |
1721 | template <typename R, typename... Args> |
1722 | template <typename FType> |
1723 | inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) { |
1724 | FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F; |
1725 | packed_ = PackedFunc([flambda, name, f_sig](const TVMArgs& args, TVMRetValue* rv) { |
1726 | if (args.size() != sizeof...(Args)) { |
1727 | LOG(FATAL) << "Function " << name << (f_sig == nullptr ? "" : (*f_sig)()) << " expects " |
1728 | << sizeof...(Args) << " arguments, but " << args.size() << " were provided." ; |
1729 | } |
1730 | detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv); |
1731 | }); |
1732 | } |
1733 | |
1734 | template <typename R, typename... Args> |
1735 | template <typename FType> |
1736 | inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { |
1737 | FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F; |
1738 | packed_ = PackedFunc([flambda, f_sig](const TVMArgs& args, TVMRetValue* rv) { |
1739 | if (args.size() != sizeof...(Args)) { |
1740 | LOG(FATAL) << "Function <anonymous> " << (*f_sig)() << " expects " << sizeof...(Args) |
1741 | << " arguments, but " << args.size() << " were provided." ; |
1742 | } |
1743 | detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv); |
1744 | }); |
1745 | } |
1746 | |
1747 | template <typename R, typename... Args> |
1748 | TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { |
1749 | return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...); |
1750 | } |
1751 | |
// ObjectRef related conversion handling
// A defined object can be passed under one of several type codes:
//      kTVMNDArrayHandle, kTVMModuleHandle, kTVMPackedFuncHandle,
//      kTVMObjectHandle (or kTVMObjectRValueRefArg when passed as an rvalue)
//
// We use type traits to eliminate unnecessary checks.
1757 | template <typename T> |
1758 | inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { |
1759 | using ContainerType = typename std::remove_reference<T>::type::ContainerType; |
1760 | if (value.defined()) { |
1761 | Object* ptr = value.data_.data_; |
1762 | if (std::is_base_of<NDArray::ContainerType, ContainerType>::value || |
1763 | (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
1764 | ptr->IsInstance<NDArray::ContainerType>())) { |
1765 | values_[i].v_handle = NDArray::FFIGetHandle(value); |
1766 | type_codes_[i] = kTVMNDArrayHandle; |
1767 | } else if (std::is_base_of<Module::ContainerType, ContainerType>::value || |
1768 | (std::is_base_of<ContainerType, Module::ContainerType>::value && |
1769 | ptr->IsInstance<Module::ContainerType>())) { |
1770 | values_[i].v_handle = ptr; |
1771 | type_codes_[i] = kTVMModuleHandle; |
1772 | } else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value || |
1773 | (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value && |
1774 | ptr->IsInstance<PackedFunc::ContainerType>())) { |
1775 | values_[i].v_handle = ptr; |
1776 | type_codes_[i] = kTVMPackedFuncHandle; |
1777 | } else if (std::is_rvalue_reference<decltype(value)>::value) { |
1778 | values_[i].v_handle = const_cast<Object**>(&(value.data_.data_)); |
1779 | type_codes_[i] = kTVMObjectRValueRefArg; |
1780 | } else { |
1781 | values_[i].v_handle = value.data_.data_; |
1782 | type_codes_[i] = kTVMObjectHandle; |
1783 | } |
1784 | } else { |
1785 | type_codes_[i] = kTVMNullptr; |
1786 | values_[i].v_handle = nullptr; |
1787 | } |
1788 | } |
1789 | |
1790 | template <typename TObjectRef, typename> |
1791 | inline bool TVMPODValue_::IsObjectRef() const { |
1792 | using ContainerType = typename TObjectRef::ContainerType; |
1793 | // NOTE: the following code can be optimized by constant folding. |
1794 | if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) { |
1795 | return type_code_ == kTVMNDArrayHandle && |
1796 | TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle)) |
1797 | ->IsInstance<ContainerType>(); |
1798 | } |
1799 | if (std::is_base_of<Module::ContainerType, ContainerType>::value) { |
1800 | return type_code_ == kTVMModuleHandle && |
1801 | static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>(); |
1802 | } |
1803 | if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) { |
1804 | return type_code_ == kTVMPackedFuncHandle && |
1805 | static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>(); |
1806 | } |
1807 | // NOTE: we don't pass NDArray and runtime::Module as RValue ref. |
1808 | if (type_code_ == kTVMObjectRValueRefArg) { |
1809 | return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle)); |
1810 | } |
1811 | return (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
1812 | type_code_ == kTVMNDArrayHandle) || |
1813 | (std::is_base_of<ContainerType, Module::ContainerType>::value && |
1814 | type_code_ == kTVMModuleHandle) || |
1815 | (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value && |
1816 | type_code_ == kTVMPackedFuncHandle) || |
1817 | (type_code_ == kTVMObjectHandle && |
1818 | ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle))); |
1819 | } |
1820 | |
1821 | template <typename TObjectRef> |
1822 | inline TObjectRef TVMPODValue_::AsObjectRef() const { |
1823 | static_assert(std::is_base_of<ObjectRef, TObjectRef>::value, |
1824 | "Conversion only works for ObjectRef" ); |
1825 | using ContainerType = typename TObjectRef::ContainerType; |
1826 | |
1827 | if (type_code_ == kTVMNullptr) { |
1828 | CHECK(TObjectRef::_type_is_nullable) |
1829 | << "Expect a not null value of " << ContainerType::_type_key; |
1830 | return TObjectRef(ObjectPtr<Object>(nullptr)); |
1831 | } |
1832 | // NOTE: the following code can be optimized by constant folding. |
1833 | if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) { |
1834 | // Casting to a sub-class of NDArray |
1835 | TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); |
1836 | ObjectPtr<Object> data = |
1837 | NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)); |
1838 | CHECK(data->IsInstance<ContainerType>()) |
1839 | << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); |
1840 | return TObjectRef(data); |
1841 | } |
1842 | if (std::is_base_of<Module::ContainerType, ContainerType>::value) { |
1843 | // Casting to a sub-class of Module |
1844 | TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); |
1845 | ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)); |
1846 | CHECK(data->IsInstance<ContainerType>()) |
1847 | << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); |
1848 | return TObjectRef(data); |
1849 | } |
1850 | if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) { |
1851 | // Casting to a sub-class of PackedFunc |
1852 | TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); |
1853 | ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)); |
1854 | CHECK(data->IsInstance<ContainerType>()) |
1855 | << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); |
1856 | return TObjectRef(data); |
1857 | } |
1858 | if (type_code_ == kTVMObjectHandle) { |
1859 | // normal object type check. |
1860 | Object* ptr = static_cast<Object*>(value_.v_handle); |
1861 | Optional<String> checked_type = ObjectTypeChecker<TObjectRef>::CheckAndGetMismatch(ptr); |
1862 | ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName() |
1863 | << ", but got " << checked_type.value(); |
1864 | return TObjectRef(GetObjectPtr<Object>(ptr)); |
1865 | } else if (type_code_ == kTVMObjectRValueRefArg) { |
1866 | Object* ptr = *static_cast<Object**>(value_.v_handle); |
1867 | Optional<String> checked_type = ObjectTypeChecker<TObjectRef>::CheckAndGetMismatch(ptr); |
1868 | ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName() |
1869 | << ", but got " << checked_type.value(); |
1870 | return TObjectRef(GetObjectPtr<Object>(ptr)); |
1871 | } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
1872 | type_code_ == kTVMNDArrayHandle) { |
1873 | // Casting to a base class that NDArray can sub-class |
1874 | ObjectPtr<Object> data = |
1875 | NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)); |
1876 | return TObjectRef(data); |
1877 | } else if (std::is_base_of<ContainerType, Module::ContainerType>::value && |
1878 | type_code_ == kTVMModuleHandle) { |
1879 | // Casting to a base class that Module can sub-class |
1880 | return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
1881 | } else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value && |
1882 | type_code_ == kTVMPackedFuncHandle) { |
1883 | // Casting to a base class that PackedFunc can sub-class |
1884 | return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
1885 | } else { |
1886 | TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); |
1887 | return TObjectRef(ObjectPtr<Object>(nullptr)); |
1888 | } |
1889 | } |
1890 | |
1891 | template <typename TObjectRef, typename> |
1892 | inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { |
1893 | using ContainerType = typename TObjectRef::ContainerType; |
1894 | const Object* ptr = other.get(); |
1895 | if (ptr != nullptr) { |
1896 | if (std::is_base_of<NDArray::ContainerType, ContainerType>::value || |
1897 | (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
1898 | ptr->IsInstance<NDArray::ContainerType>())) { |
1899 | return operator=(NDArray(std::move(other.data_))); |
1900 | } |
1901 | if (std::is_base_of<Module::ContainerType, ContainerType>::value || |
1902 | (std::is_base_of<ContainerType, Module::ContainerType>::value && |
1903 | ptr->IsInstance<Module::ContainerType>())) { |
1904 | return operator=(Module(std::move(other.data_))); |
1905 | } |
1906 | if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value || |
1907 | (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value && |
1908 | ptr->IsInstance<PackedFunc::ContainerType>())) { |
1909 | return operator=(PackedFunc(std::move(other.data_))); |
1910 | } |
1911 | SwitchToObject(kTVMObjectHandle, std::move(other.data_)); |
1912 | } else { |
1913 | SwitchToPOD(kTVMNullptr); |
1914 | value_.v_handle = nullptr; |
1915 | } |
1916 | return *this; |
1917 | } |
1918 | |
1919 | template <typename T, typename> |
1920 | inline TVMArgValue::operator T() const { |
1921 | return PackedFuncValueConverter<T>::From(*this); |
1922 | } |
1923 | |
1924 | template <typename T, typename> |
1925 | inline TVMMovableArgValue_::operator T() const { |
1926 | if (type_code_ == kTVMObjectRValueRefArg) { |
1927 | auto** ref = static_cast<Object**>(value_.v_handle); |
1928 | if (ObjectTypeChecker<T>::Check(*ref)) { |
1929 | return T(ObjectPtr<Object>::MoveFromRValueRefArg(ref)); |
1930 | } |
1931 | } |
1932 | // fallback |
1933 | return PackedFuncValueConverter<T>::From(AsArgValue()); |
1934 | } |
1935 | |
1936 | template <typename T, typename> |
1937 | inline TVMRetValue::operator T() const { |
1938 | return PackedFuncValueConverter<T>::From(*this); |
1939 | } |
1940 | |
1941 | inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { |
1942 | return (*this)->GetFunction(name, query_imports); |
1943 | } |
1944 | |
1945 | // specializations of PackedFuncValueConverter |
1946 | template <> |
1947 | struct PackedFuncValueConverter<::tvm::runtime::String> { |
1948 | static String From(const TVMArgValue& val) { |
1949 | if (val.IsObjectRef<tvm::runtime::String>()) { |
1950 | return val.AsObjectRef<tvm::runtime::String>(); |
1951 | } else { |
1952 | return tvm::runtime::String(val.operator std::string()); |
1953 | } |
1954 | } |
1955 | |
1956 | static String From(const TVMRetValue& val) { |
1957 | if (val.IsObjectRef<tvm::runtime::String>()) { |
1958 | return val.AsObjectRef<tvm::runtime::String>(); |
1959 | } else { |
1960 | return tvm::runtime::String(val.operator std::string()); |
1961 | } |
1962 | } |
1963 | }; |
1964 | |
1965 | template <typename T> |
1966 | struct PackedFuncValueConverter<Optional<T>> { |
1967 | static Optional<T> From(const TVMArgValue& val) { |
1968 | if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr); |
1969 | return PackedFuncValueConverter<T>::From(val); |
1970 | } |
1971 | static Optional<T> From(const TVMRetValue& val) { |
1972 | if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr); |
1973 | return PackedFuncValueConverter<T>::From(val); |
1974 | } |
1975 | }; |
1976 | |
1977 | inline bool String::CanConvertFrom(const TVMArgValue& val) { |
1978 | return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>(); |
1979 | } |
1980 | |
1981 | inline TVMArgValue::operator DLDataType() const { |
1982 | if (String::CanConvertFrom(*this)) { |
1983 | return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string()); |
1984 | } |
1985 | // None type |
1986 | if (type_code_ == kTVMNullptr) { |
1987 | DLDataType t; |
1988 | t.code = kTVMOpaqueHandle; |
1989 | t.bits = 0; |
1990 | t.lanes = 0; |
1991 | return t; |
1992 | } |
1993 | TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); |
1994 | return value_.v_type; |
1995 | } |
1996 | |
1997 | inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } |
1998 | |
1999 | } // namespace runtime |
2000 | } // namespace tvm |
2001 | #endif // TVM_RUNTIME_PACKED_FUNC_H_ |
2002 | |