1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/packed_func.h
22 * \brief Type-erased function used across TVM API.
23 */
24#ifndef TVM_RUNTIME_PACKED_FUNC_H_
25#define TVM_RUNTIME_PACKED_FUNC_H_
26
27#include <tvm/runtime/c_runtime_api.h>
28#include <tvm/runtime/container/array.h>
29#include <tvm/runtime/container/map.h>
30#include <tvm/runtime/data_type.h>
31#include <tvm/runtime/logging.h>
32#include <tvm/runtime/module.h>
33#include <tvm/runtime/ndarray.h>
34#include <tvm/runtime/object.h>
35
36#include <functional>
37#include <limits>
38#include <memory>
39#include <string>
40#include <tuple>
41#include <type_traits>
42#include <utility>
43#include <vector>
44
// Whether to use the TVM runtime in header-only mode.
// Defaults to 0 (disabled) unless TVM_RUNTIME_HEADER_ONLY is already defined,
// e.g. by the build system.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
49
50namespace tvm {
51namespace runtime {
52
// Forward declarations, so that the classes below can refer to these types
// before their full definitions are available.
class TVMArgs;
class TVMArgValue;
class TVMMovableArgValueWithContext_;
class TVMRetValue;
class TVMArgsSetter;
template <typename FType>
class TypedPackedFunc;
template <typename TSignature>
struct SignaturePrinter;
63
64/*!
65 * \brief Object container class that backs PackedFunc.
66 * \note Do not use this function directly, use PackedFunc.
67 */
68class PackedFuncObj : public Object {
69 public:
70 /*!
71 * \brief Call the function in packed format.
72 * \param args The arguments
73 * \param rv The return value.
74 */
75 TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
76
77 static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc;
78 static constexpr const char* _type_key = "runtime.PackedFunc";
79 TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object);
80
81 protected:
82 /*!
83 * \brief Internal struct for extracting the callable method from callable type.
84 */
85 template <class TPackedFuncSubObj>
86 struct Extractor {
87 /*!
88 * \brief Extracting the callable method from callable type.
89 * \param obj The base packed function object class.
90 * \param args The arguments
91 * \param rv The return value.
92 */
93 static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv);
94 };
95
96 /*! \brief The internal callable function type. */
97 using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*);
98
99 /*!
100 * \brief Constructing a packed function object from a function pointer.
101 * \param f_call_pack The function pointer used to call the packed function.
102 */
103 explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {}
104
105 /*! \brief Delete the default constructor explicitly. */
106 PackedFuncObj() = delete;
107
108 /*! \brief Internal callable function pointer used to call the packed function. */
109 FCallPacked* f_call_packed_;
110};
111
112/*! \brief Derived object class for constructing PackedFuncObj. */
113template <class TCallable>
114class PackedFuncSubObj : public PackedFuncObj {
115 using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
116
117 public:
118 /*! \brief The type of derived object class */
119 using TSelf = PackedFuncSubObj<TCallable>;
120 /*!
121 * \brief Derived object class for constructing PackedFuncObj.
122 * \param callable The type-erased callable object.
123 */
124 explicit PackedFuncSubObj(TCallable callable)
125 : PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {}
126 /*! \brief Type-erased filed for storing callable object*/
127 mutable TStorage callable_;
128};
129
130/*!
131 * \brief Packed function is a type-erased function.
132 * The arguments are passed by packed format.
133 *
134 * This is an useful unified interface to call generated functions,
135 * It is the unified function function type of TVM.
136 * It corresponds to TVMFunctionHandle in C runtime API.
137 */
138class PackedFunc : public ObjectRef {
139 public:
140 /*! \brief Constructor from null */
141 PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*)
142 /*!
143 * \brief Constructing a packed function from a callable type
144 * whose signature is consistent with `PackedFunc`
145 * \param data the internal container of packed function.
146 */
147 template <typename TCallable,
148 typename = std::enable_if_t<
149 std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
150 !std::is_base_of<TCallable, PackedFunc>::value>>
151 explicit PackedFunc(TCallable data) {
152 using ObjType = PackedFuncSubObj<TCallable>;
153 data_ = make_object<ObjType>(std::forward<TCallable>(data));
154 }
155 /*!
156 * \brief Call packed function by directly passing in unpacked format.
157 * \param args Arguments to be passed.
158 * \tparam Args arguments to be passed.
159 *
160 * \code
161 * // Example code on how to call packed function
162 * void CallPacked(PackedFunc f) {
163 * // call like normal functions by pass in arguments
164 * // return value is automatically converted back
165 * int rvalue = f(1, 2.0);
166 * }
167 * \endcode
168 */
169 template <typename... Args>
170 inline TVMRetValue operator()(Args&&... args) const;
171 /*!
172 * \brief Call the function in packed format.
173 * \param args The arguments
174 * \param rv The return value.
175 */
176 TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
177 /*! \return Whether the packed function is nullptr */
178 bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
179 /*! \return Whether the packed function is not nullptr */
180 bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
181
182 TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj);
183};
184
/*! \brief Type of a static function that returns a TypedPackedFunc's signature as a string. */
using FSig = std::string();

/*!
 * \brief Primary template; only declared here. Please refer to
 *  \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>" for the usable specialization.
 */
template <typename FType>
class TypedPackedFunc;
193
194/*!
195 * \anchor TypedPackedFuncAnchor
196 * \brief A PackedFunc wrapper to provide typed function signature.
197 * It is backed by a PackedFunc internally.
198 *
199 * TypedPackedFunc enables compile time type checking.
200 * TypedPackedFunc works with the runtime system:
201 * - It can be passed as an argument of PackedFunc.
202 * - It can be assigned to TVMRetValue.
203 * - It can be directly converted to a type-erased PackedFunc.
204 *
205 * Developers should prefer TypedPackedFunc over PackedFunc in C++ code
206 * as it enables compile time checking.
207 * We can construct a TypedPackedFunc from a lambda function
208 * with the same signature.
209 *
210 * \code
211 * // user defined lambda function.
212 * auto addone = [](int x)->int {
213 * return x + 1;
214 * };
215 * // We can directly convert
216 * // lambda function to TypedPackedFunc
217 * TypedPackedFunc<int(int)> ftyped(addone);
218 * // invoke the function.
219 * int y = ftyped(1);
220 * // Can be directly converted to PackedFunc
221 * PackedFunc packed = ftype;
222 * \endcode
223 * \tparam R The return value of the function.
224 * \tparam Args The argument signature of the function.
225 */
226template <typename R, typename... Args>
227class TypedPackedFunc<R(Args...)> {
228 public:
229 /*! \brief short hand for this function type */
230 using TSelf = TypedPackedFunc<R(Args...)>;
231 /*! \brief default constructor */
232 TypedPackedFunc() {}
233 /*! \brief constructor from null */
234 TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
235 /*!
236 * \brief construct by wrap a PackedFunc
237 *
238 * Example usage:
239 * \code
240 * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
241 * int x = args[0];
242 * *rv = x + 1;
243 * });
244 * // construct from packed function
245 * TypedPackedFunc<int(int)> ftyped(packed);
246 * // call the typed version.
247 * ICHECK_EQ(ftyped(1), 2);
248 * \endcode
249 *
250 * \param packed The packed function
251 */
252 inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
253 /*!
254 * \brief constructor from TVMRetValue
255 * \param value The TVMRetValue
256 */
257 inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
258 /*!
259 * \brief constructor from TVMArgValue
260 * \param value The TVMArgValue
261 */
262 inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
263 /*!
264 * \brief constructor from TVMMovableArgValue_
265 * \param value The TVMMovableArgValue_
266 */
267 inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*)
268 /*!
269 * \brief construct from a lambda function with the same signature.
270 *
271 * Example usage:
272 * \code
273 * auto typed_lambda = [](int x)->int { return x + 1; }
274 * // construct from packed function
275 * TypedPackedFunc<int(int)> ftyped(typed_lambda, "add_one");
276 * // call the typed version.
277 * ICHECK_EQ(ftyped(1), 2);
278 * \endcode
279 *
280 * \param typed_lambda typed lambda function.
281 * \param name the name of the lambda function.
282 * \tparam FLambda the type of the lambda function.
283 */
284 template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
285 FLambda, std::function<R(Args...)>>::value>::type>
286 TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*)
287 this->AssignTypedLambda(typed_lambda, name);
288 }
289 /*!
290 * \brief construct from a lambda function with the same signature.
291 *
292 * This version does not take a name. It is highly recommend you use the
293 * version that takes a name for the lambda.
294 *
295 * Example usage:
296 * \code
297 * auto typed_lambda = [](int x)->int { return x + 1; }
298 * // construct from packed function
299 * TypedPackedFunc<int(int)> ftyped(typed_lambda);
300 * // call the typed version.
301 * ICHECK_EQ(ftyped(1), 2);
302 * \endcode
303 *
304 * \param typed_lambda typed lambda function.
305 * \tparam FLambda the type of the lambda function.
306 */
307 template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
308 FLambda, std::function<R(Args...)>>::value>::type>
309 TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
310 this->AssignTypedLambda(typed_lambda);
311 }
312 /*!
313 * \brief copy assignment operator from typed lambda
314 *
315 * Example usage:
316 * \code
317 * // construct from packed function
318 * TypedPackedFunc<int(int)> ftyped;
319 * ftyped = [](int x) { return x + 1; }
320 * // call the typed version.
321 * ICHECK_EQ(ftyped(1), 2);
322 * \endcode
323 *
324 * \param typed_lambda typed lambda function.
325 * \tparam FLambda the type of the lambda function.
326 * \returns reference to self.
327 */
328 template <typename FLambda, typename = typename std::enable_if<
329 std::is_convertible<FLambda,
330 std::function<R(Args...)>>::value>::type>
331 TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
332 this->AssignTypedLambda(typed_lambda);
333 return *this;
334 }
335 /*!
336 * \brief copy assignment operator from PackedFunc.
337 * \param packed The packed function.
338 * \returns reference to self.
339 */
340 TSelf& operator=(PackedFunc packed) {
341 packed_ = packed;
342 return *this;
343 }
344 /*!
345 * \brief Invoke the operator.
346 * \param args The arguments
347 * \returns The return value.
348 */
349 TVM_ALWAYS_INLINE R operator()(Args... args) const;
350 /*!
351 * \brief convert to PackedFunc
352 * \return the internal PackedFunc
353 */
354 operator PackedFunc() const { return packed(); }
355 /*!
356 * \return reference the internal PackedFunc
357 */
358 const PackedFunc& packed() const { return packed_; }
359 /*! \return Whether the packed function is nullptr */
360 bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
361 /*! \return Whether the packed function is not nullptr */
362 bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
363
364 private:
365 friend class TVMRetValue;
366 /*! \brief The internal packed function */
367 PackedFunc packed_;
368 /*!
369 * \brief Assign the packed field using a typed lambda function.
370 *
371 * \param flambda The lambda function.
372 * \param name The name associated with this lambda.
373 * \tparam FLambda The lambda function type.
374 * \note We capture the lambda when possible for maximum efficiency.
375 */
376 template <typename FLambda>
377 inline void AssignTypedLambda(FLambda flambda, std::string name);
378 /*!
379 * \brief Assign the packed field using a typed lambda function. This variant is for functions
380 * without names.
381 *
382 * \param flambda The lambda function.
383 * \tparam FLambda The lambda function type.
384 * \note We capture the lambda when possible for maximum efficiency.
385 */
386 template <typename FLambda>
387 inline void AssignTypedLambda(FLambda flambda);
388};
389
390/*! \brief Arguments into TVM functions. */
391class TVMArgs {
392 public:
393 const TVMValue* values;
394 const int* type_codes;
395 int num_args;
396 /*!
397 * \brief constructor
398 * \param values The argument values
399 * \param type_codes The argument type codes
400 * \param num_args number of arguments.
401 */
402 TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
403 : values(values), type_codes(type_codes), num_args(num_args) {}
404 /*! \return size of the arguments */
405 inline int size() const;
406 /*!
407 * \brief Get i-th argument
408 * \param i the index.
409 * \return the ith argument.
410 */
411 inline TVMArgValue operator[](int i) const;
412};
413
/*!
 * \brief Convert argument type code to string.
 * \param type_code The input type code.
 * \return The corresponding string repr.
 */
inline const char* ArgTypeCode2Str(int type_code);

// Check that the actual type code CODE equals the expected code T. On mismatch,
// the ICHECK failure message names both codes via ArgTypeCode2Str.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
  ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
424
425/*!
426 * \brief Type traits for runtime type check during FFI conversion.
427 * \tparam T the type to be checked.
428 */
429template <typename T>
430struct ObjectTypeChecker {
431 /*!
432 * \brief Check if an object matches the template type and return the
433 * mismatched type if it exists.
434 * \param ptr The object to check the type of.
435 * \return An Optional containing the actual type of the pointer if it does not match the
436 * template type. If the Optional does not contain a value, then the types match.
437 */
438 static Optional<String> CheckAndGetMismatch(const Object* ptr) {
439 using ContainerType = typename T::ContainerType;
440 if (ptr == nullptr) {
441 if (T::_type_is_nullable) {
442 return NullOpt;
443 } else {
444 return String("nullptr");
445 }
446 }
447 if (ptr->IsInstance<ContainerType>()) {
448 return NullOpt;
449 } else {
450 return String(ptr->GetTypeKey());
451 }
452 }
453 /*!
454 * \brief Check if an object matches the template type.
455 * \param ptr The object to check the type of.
456 * \return Whether or not the template type matches the objects type.
457 */
458 static bool Check(const Object* ptr) {
459 using ContainerType = typename T::ContainerType;
460 if (ptr == nullptr) return T::_type_is_nullable;
461 return ptr->IsInstance<ContainerType>();
462 }
463 static std::string TypeName() {
464 using ContainerType = typename T::ContainerType;
465 return ContainerType::_type_key;
466 }
467};
468
469// Additional overloads for PackedFunc checking.
470template <typename T>
471struct ObjectTypeChecker<Array<T>> {
472 static Optional<String> CheckAndGetMismatch(const Object* ptr) {
473 if (ptr == nullptr) {
474 return NullOpt;
475 }
476 if (!ptr->IsInstance<ArrayNode>()) {
477 return String(ptr->GetTypeKey());
478 }
479 const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
480 for (size_t i = 0; i < n->size(); i++) {
481 const ObjectRef& p = (*n)[i];
482 Optional<String> check_subtype = ObjectTypeChecker<T>::CheckAndGetMismatch(p.get());
483 if (check_subtype.defined()) {
484 return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]");
485 }
486 }
487 return NullOpt;
488 }
489 static bool Check(const Object* ptr) {
490 if (ptr == nullptr) return true;
491 if (!ptr->IsInstance<ArrayNode>()) return false;
492 const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
493 for (const ObjectRef& p : *n) {
494 if (!ObjectTypeChecker<T>::Check(p.get())) {
495 return false;
496 }
497 }
498 return true;
499 }
500 static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]"; }
501};
502template <typename K, typename V>
503struct ObjectTypeChecker<Map<K, V>> {
504 static Optional<String> CheckAndGetMismatch(const Object* ptr) {
505 if (ptr == nullptr) return NullOpt;
506 if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());
507 const MapNode* n = static_cast<const MapNode*>(ptr);
508 for (const auto& kv : *n) {
509 Optional<String> key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
510 Optional<String> value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
511 if (key_type.defined() || value_type.defined()) {
512 std::string key_name =
513 key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
514 std::string value_name = value_type.defined() ? std::string(value_type.value())
515 : ObjectTypeChecker<V>::TypeName();
516 return String("Map[" + key_name + ", " + value_name + "]");
517 }
518 }
519 return NullOpt;
520 }
521 static bool Check(const Object* ptr) {
522 if (ptr == nullptr) return true;
523 if (!ptr->IsInstance<MapNode>()) return false;
524 const MapNode* n = static_cast<const MapNode*>(ptr);
525 for (const auto& kv : *n) {
526 if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
527 if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
528 }
529 return true;
530 }
531 static std::string TypeName() {
532 return "Map[" + ObjectTypeChecker<K>::TypeName() + ", " + ObjectTypeChecker<V>::TypeName() +
533 ']';
534 }
535};
536
537/*!
538 * \brief Internal base class to
539 * handle conversion to POD values.
540 */
541class TVMPODValue_ {
542 public:
543 operator double() const {
544 // Allow automatic conversion from int to float
545 // This avoids errors when user pass in int from
546 // the frontend while the API expects a float.
547 if (type_code_ == kDLInt) {
548 return static_cast<double>(value_.v_int64);
549 }
550 TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
551 return value_.v_float64;
552 }
553 operator int64_t() const {
554 TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
555 return value_.v_int64;
556 }
557 operator uint64_t() const {
558 TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
559 return value_.v_int64;
560 }
561 operator int() const {
562 TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
563 ICHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
564 ICHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
565 return static_cast<int>(value_.v_int64);
566 }
567 operator bool() const {
568 TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
569 return value_.v_int64 != 0;
570 }
571 operator void*() const {
572 if (type_code_ == kTVMNullptr) return nullptr;
573 if (type_code_ == kTVMDLTensorHandle) return value_.v_handle;
574 TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle);
575 return value_.v_handle;
576 }
577 operator DLTensor*() const {
578 if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) {
579 return static_cast<DLTensor*>(value_.v_handle);
580 } else {
581 if (type_code_ == kTVMNullptr) return nullptr;
582 LOG(FATAL) << "Expected "
583 << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_);
584 return nullptr;
585 }
586 }
587 operator NDArray() const {
588 if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr));
589 TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
590 return NDArray(NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)));
591 }
592 operator Module() const {
593 if (type_code_ == kTVMNullptr) {
594 return Module(ObjectPtr<Object>(nullptr));
595 }
596 TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
597 return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
598 }
599 operator PackedFunc() const {
600 if (type_code_ == kTVMNullptr) {
601 return PackedFunc(ObjectPtr<Object>(nullptr));
602 }
603 TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
604 return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
605 }
606 operator Device() const {
607 TVM_CHECK_TYPE_CODE(type_code_, kDLDevice);
608 return value_.v_device;
609 }
610 int type_code() const { return type_code_; }
611 /*!
612 * \brief return handle as specific pointer type.
613 * \tparam T the data type.
614 * \return The pointer type.
615 */
616 template <typename T>
617 T* ptr() const {
618 return static_cast<T*>(value_.v_handle);
619 }
620 // ObjectRef handling
621 template <typename TObjectRef,
622 typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
623 inline bool IsObjectRef() const;
624 template <typename TObjectRef>
625 inline TObjectRef AsObjectRef() const;
626
627 protected:
628 friend class TVMArgsSetter;
629 friend class TVMRetValue;
630 friend class TVMMovableArgValue_;
631 TVMPODValue_() : type_code_(kTVMNullptr) {}
632 TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {}
633
634 /*! \brief The value */
635 TVMValue value_;
636 /*! \brief the type code */
637 int type_code_;
638};
639
640/*!
641 * \brief A single argument value to PackedFunc.
642 * Containing both type_code and TVMValue
643 *
644 * Provides utilities to do type cast into other types.
645 */
646class TVMArgValue : public TVMPODValue_ {
647 public:
648 /*! \brief default constructor */
649 TVMArgValue() {}
650 /*!
651 * \brief constructor
652 * \param value of the function
653 * \param type_code The type code.
654 */
655 TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
656 // reuse converter from parent
657 using TVMPODValue_::operator double;
658 using TVMPODValue_::operator int64_t;
659 using TVMPODValue_::operator uint64_t;
660 using TVMPODValue_::operator int;
661 using TVMPODValue_::operator bool;
662 using TVMPODValue_::operator void*;
663 using TVMPODValue_::operator DLTensor*;
664 using TVMPODValue_::operator NDArray;
665 using TVMPODValue_::operator Device;
666 using TVMPODValue_::operator Module;
667 using TVMPODValue_::operator PackedFunc;
668 using TVMPODValue_::AsObjectRef;
669 using TVMPODValue_::IsObjectRef;
670
671 // conversion operator.
672 operator std::string() const {
673 if (type_code_ == kTVMDataType) {
674 return DLDataType2String(operator DLDataType());
675 } else if (type_code_ == kTVMBytes) {
676 TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
677 return std::string(arr->data, arr->size);
678 } else if (type_code_ == kTVMStr) {
679 return std::string(value_.v_str);
680 } else {
681 ICHECK(IsObjectRef<tvm::runtime::String>())
682 << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
683 << " to a string.";
684 return AsObjectRef<tvm::runtime::String>().operator std::string();
685 }
686 }
687 template <typename FType>
688 operator TypedPackedFunc<FType>() const {
689 return TypedPackedFunc<FType>(operator PackedFunc());
690 }
691 const TVMValue& value() const { return value_; }
692
693 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
694 inline operator T() const;
695 inline operator DLDataType() const;
696 inline operator DataType() const;
697};
698
699/*!
700 * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
701 *
702 * We can only construct a movable argument once from a single argument position.
703 * If the argument is passed as RValue reference, the result will be moved.
704 * We should only construct a MovableArg from an argument once,
705 * as the result will can moved.
706 *
707 * \note For internal development purpose only.
708 */
709class TVMMovableArgValue_ : public TVMPODValue_ {
710 public:
711 TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
712 // reuse converter from parent
713 using TVMPODValue_::operator double;
714 using TVMPODValue_::operator int64_t;
715 using TVMPODValue_::operator uint64_t;
716 using TVMPODValue_::operator int;
717 using TVMPODValue_::operator bool;
718 using TVMPODValue_::operator void*;
719 using TVMPODValue_::operator DLTensor*;
720 using TVMPODValue_::operator NDArray;
721 using TVMPODValue_::operator Device;
722 using TVMPODValue_::operator Module;
723 using TVMPODValue_::operator PackedFunc;
724 // reuse conversion rule from ArgValue.
725 operator std::string() const { return AsArgValue().operator std::string(); }
726 template <typename FType>
727 operator TypedPackedFunc<FType>() const {
728 return TypedPackedFunc<FType>(operator PackedFunc());
729 }
730 operator DLDataType() const { return AsArgValue().operator DLDataType(); }
731 operator DataType() const { return AsArgValue().operator DataType(); }
732 operator TVMArgValue() const { return AsArgValue(); }
733 /*!
734 * \brief Helper converter function.
735 * Try to move out an argument if possible,
736 * fall back to normal argument conversion rule otherwise.
737 */
738 template <typename T,
739 typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
740 inline operator T() const;
741
742 private:
743 /*! \return The arg value repr of the value. */
744 TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
745};
746
747/*!
748 * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with
749 * additional context information (function name and argument index) for better error reporting.
750 *
751 * \sa MovableArgValue_
752 * \note For internal development purpose only.
753 */
754class TVMMovableArgValueWithContext_ {
755 public:
756 /*!
757 * \brief move constructor from another return value.
758 * \param value The other return value.
759 * \param type_code The code associated with the type of the value.
760 * \param arg_index In a function call, this argument is at index arg_index (0-indexed).
761 * \param optional_name Name of the function being called. Can be nullptr if the function is not.
762 * \param f_sig Pointer to static function outputting signature of the function being called.
763 * named.
764 */
765 TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index,
766 const std::string* optional_name, FSig* f_sig)
767 : value_(value, type_code),
768 arg_index_(arg_index),
769 optional_name_(optional_name),
770 f_sig_(f_sig) {}
771
772 template <typename T>
773 operator T() const {
774 try {
775 return value_; // implicit conversion happens here
776 } catch (dmlc::Error& e) {
777 LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_)
778 << (f_sig_ == nullptr ? "" : (*f_sig_)()) << ": error while converting argument "
779 << arg_index_ << ": " << e.what();
780 throw; // never reached, LOG(FATAL) throws, but this silences a warning.
781 }
782 }
783
784 private:
785 TVMMovableArgValue_ value_;
786 int arg_index_;
787 const std::string* optional_name_;
788 FSig* f_sig_;
789};
790
/*!
 * \brief Return value container.
 *
 * Unlike TVMArgValue, which only holds a reference and does not delete
 * the underlying container during destruction, TVMRetValue holds the
 * value and manages the underlying container when it stores a complex
 * data type.
 */
799class TVMRetValue : public TVMPODValue_ {
800 public:
801 /*! \brief default constructor */
802 TVMRetValue() {}
803 /*!
804 * \brief move constructor from another return value.
805 * \param other The other return value.
806 */
807 TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) {
808 other.value_.v_handle = nullptr;
809 other.type_code_ = kTVMNullptr;
810 }
811 /*! \brief destructor */
812 ~TVMRetValue() { this->Clear(); }
813 // reuse converter from parent
814 using TVMPODValue_::operator double;
815 using TVMPODValue_::operator int64_t;
816 using TVMPODValue_::operator uint64_t;
817 using TVMPODValue_::operator int;
818 using TVMPODValue_::operator bool;
819 using TVMPODValue_::operator void*;
820 using TVMPODValue_::operator DLTensor*;
821 using TVMPODValue_::operator Device;
822 using TVMPODValue_::operator NDArray;
823 using TVMPODValue_::operator Module;
824 using TVMPODValue_::operator PackedFunc;
825 using TVMPODValue_::AsObjectRef;
826 using TVMPODValue_::IsObjectRef;
827
828 TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); }
829 // conversion operators
830 operator std::string() const {
831 if (type_code_ == kTVMDataType) {
832 return DLDataType2String(operator DLDataType());
833 } else if (type_code_ == kTVMBytes) {
834 return *ptr<std::string>();
835 }
836 TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
837 return *ptr<std::string>();
838 }
839 operator DLDataType() const {
840 if (type_code_ == kTVMStr) {
841 return String2DLDataType(operator std::string());
842 }
843 TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
844 return value_.v_type;
845 }
846 operator DataType() const { return DataType(operator DLDataType()); }
847 template <typename FType>
848 operator TypedPackedFunc<FType>() const {
849 return TypedPackedFunc<FType>(operator PackedFunc());
850 }
851 // Assign operators
852 TVMRetValue& operator=(TVMRetValue&& other) {
853 this->Clear();
854 value_ = other.value_;
855 type_code_ = other.type_code_;
856 other.type_code_ = kTVMNullptr;
857 return *this;
858 }
859 TVMRetValue& operator=(double value) {
860 this->SwitchToPOD(kDLFloat);
861 value_.v_float64 = value;
862 return *this;
863 }
864 TVMRetValue& operator=(std::nullptr_t value) {
865 this->SwitchToPOD(kTVMNullptr);
866 value_.v_handle = value;
867 return *this;
868 }
869 TVMRetValue& operator=(void* value) {
870 this->SwitchToPOD(kTVMOpaqueHandle);
871 value_.v_handle = value;
872 return *this;
873 }
874 TVMRetValue& operator=(int64_t value) {
875 this->SwitchToPOD(kDLInt);
876 value_.v_int64 = value;
877 return *this;
878 }
879 TVMRetValue& operator=(int value) {
880 this->SwitchToPOD(kDLInt);
881 value_.v_int64 = value;
882 return *this;
883 }
884 TVMRetValue& operator=(DLDevice value) {
885 this->SwitchToPOD(kDLDevice);
886 value_.v_device = value;
887 return *this;
888 }
889 TVMRetValue& operator=(DLDataType t) {
890 this->SwitchToPOD(kTVMDataType);
891 value_.v_type = t;
892 return *this;
893 }
894 TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
895 TVMRetValue& operator=(bool value) {
896 this->SwitchToPOD(kDLInt);
897 value_.v_int64 = value;
898 return *this;
899 }
900 TVMRetValue& operator=(std::string value) {
901 this->SwitchToClass(kTVMStr, value);
902 return *this;
903 }
904 TVMRetValue& operator=(TVMByteArray value) {
905 this->SwitchToClass(kTVMBytes, std::string(value.data, value.size));
906 return *this;
907 }
908 TVMRetValue& operator=(NDArray other) {
909 if (other.data_ != nullptr) {
910 this->Clear();
911 type_code_ = kTVMNDArrayHandle;
912 value_.v_handle = NDArray::FFIGetHandle(other);
913 ObjectRef::FFIClearAfterMove(&other);
914 } else {
915 SwitchToPOD(kTVMNullptr);
916 value_.v_handle = nullptr;
917 }
918 return *this;
919 }
920 TVMRetValue& operator=(Module m) {
921 SwitchToObject(kTVMModuleHandle, std::move(m.data_));
922 return *this;
923 }
924 TVMRetValue& operator=(PackedFunc f) {
925 this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_));
926 return *this;
927 }
928 template <typename FType>
929 TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
930 return operator=(f.packed());
931 }
932 TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
933 this->Assign(other);
934 return *this;
935 }
936 TVMRetValue& operator=(const TVMArgValue& other) {
937 this->Assign(other);
938 return *this;
939 }
940 TVMRetValue& operator=(TVMMovableArgValue_&& other) {
941 this->Assign(other);
942 return *this;
943 }
944 /*!
945 * \brief Move the value back to front-end via C API.
946 * This marks the current container as null.
947 * The managed resources are moved to the front-end.
948 * The front end should take charge in managing them.
949 *
950 * \param ret_value The return value.
951 * \param ret_type_code The return type code.
952 */
953 void MoveToCHost(TVMValue* ret_value, int* ret_type_code) {
954 // cannot move str; need specially handle.
955 ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes);
956 *ret_value = value_;
957 *ret_type_code = type_code_;
958 type_code_ = kTVMNullptr;
959 }
960 /*!
961 * \brief Construct a new TVMRetValue by
962 * moving from return value stored via C API.
963 * \param value the value.
964 * \param type_code The type code.
965 * \return The created TVMRetValue.
966 */
967 static TVMRetValue MoveFromCHost(TVMValue value, int type_code) {
968 // Can move POD and everything under the object system.
969 ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle);
970 TVMRetValue ret;
971 ret.value_ = value;
972 ret.type_code_ = type_code;
973 return ret;
974 }
975 /*! \return The value field, if the data is POD */
976 const TVMValue& value() const {
977 ICHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle &&
978 type_code_ != kTVMModuleHandle && type_code_ != kTVMStr)
979 << "TVMRetValue.value can only be used for POD data";
980 return value_;
981 }
// ObjectRef handling
/*!
 * \brief Assign any ObjectRef (sub)class. Declared here and defined out of line.
 * \tparam TObjectRef The reference type; enable_if restricts it to ObjectRef subclasses.
 */
template <typename TObjectRef,
          typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline TVMRetValue& operator=(TObjectRef other);
/*!
 * \brief Convert the held value to class type T. Declared here and defined out of line.
 * \tparam T The target type; enable_if restricts it to class types.
 */
template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
988
989 private:
990 template <typename T>
991 void Assign(const T& other) {
992 switch (other.type_code()) {
993 case kTVMStr: {
994 SwitchToClass<std::string>(kTVMStr, other);
995 break;
996 }
997 case kTVMBytes: {
998 SwitchToClass<std::string>(kTVMBytes, other);
999 break;
1000 }
1001 case kTVMPackedFuncHandle: {
1002 *this = other.operator PackedFunc();
1003 break;
1004 }
1005 case kTVMModuleHandle: {
1006 *this = other.operator Module();
1007 break;
1008 }
1009 case kTVMNDArrayHandle: {
1010 *this = other.operator NDArray();
1011 break;
1012 }
1013 case kTVMObjectHandle: {
1014 // Avoid operator ObjectRef as we already know it is not NDArray/Module
1015 SwitchToObject(kTVMObjectHandle,
1016 GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
1017 break;
1018 }
1019 case kTVMObjectRValueRefArg: {
1020 operator=(other.operator ObjectRef());
1021 break;
1022 }
1023 default: {
1024 SwitchToPOD(other.type_code());
1025 value_ = other.value_;
1026 break;
1027 }
1028 }
1029 }
1030 // get the internal container.
1031 void SwitchToPOD(int type_code) {
1032 if (type_code_ != type_code) {
1033 this->Clear();
1034 type_code_ = type_code;
1035 }
1036 }
1037 template <typename T>
1038 void SwitchToClass(int type_code, T v) {
1039 if (type_code_ != type_code) {
1040 this->Clear();
1041 type_code_ = type_code;
1042 value_.v_handle = new T(v);
1043 } else {
1044 *static_cast<T*>(value_.v_handle) = v;
1045 }
1046 }
1047 void SwitchToObject(int type_code, ObjectPtr<Object> other) {
1048 if (other.data_ != nullptr) {
1049 this->Clear();
1050 type_code_ = type_code;
1051 // move the handle out
1052 value_.v_handle = other.data_;
1053 other.data_ = nullptr;
1054 } else {
1055 SwitchToPOD(kTVMNullptr);
1056 value_.v_handle = nullptr;
1057 }
1058 }
1059 void Clear() {
1060 if (type_code_ == kTVMNullptr) return;
1061 switch (type_code_) {
1062 case kTVMStr:
1063 case kTVMBytes:
1064 delete ptr<std::string>();
1065 break;
1066 case kTVMPackedFuncHandle:
1067 static_cast<Object*>(value_.v_handle)->DecRef();
1068 break;
1069 case kTVMNDArrayHandle: {
1070 NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
1071 break;
1072 }
1073 case kTVMModuleHandle: {
1074 static_cast<Object*>(value_.v_handle)->DecRef();
1075 break;
1076 }
1077 case kTVMObjectHandle: {
1078 static_cast<Object*>(value_.v_handle)->DecRef();
1079 break;
1080 }
1081 }
1082 type_code_ = kTVMNullptr;
1083 }
1084};
1085
1086/*!
1087 * \brief Type trait to specify special value conversion rules from
1088 * TVMArgValue and TVMRetValue.
1089 *
1090 * The trait can be specialized to add type specific conversion logic
1091 * from the TVMArgvalue and TVMRetValue.
1092 *
1093 * \tparam TObjectRef the specific ObjectRefType.
1094 */
1095template <typename TObjectRef>
1096struct PackedFuncValueConverter {
1097 /*!
1098 * \brief Convert a TObjectRef from an argument value.
1099 * \param val The argument value.
1100 * \return the converted result.
1101 */
1102 static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
1103 /*!
1104 * \brief Convert a TObjectRef from a return value.
1105 * \param val The argument value.
1106 * \return the converted result.
1107 */
1108 static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
1109};
1110
/*!
 * \brief Export a function with the PackedFunc signature
 *  as a PackedFunc that can be loaded by LibraryModule.
 *
 * The generated extern "C" symbol returns 0 on success. If a std::exception is
 * thrown, it records the exception message via TVMAPISetLastError and returns -1.
 *
 * \param ExportName The symbol name to be exported.
 * \param Function The function with PackedFunc signature.
 * \sa PackedFunc
 *
 * \code
 *
 * void AddOne_(TVMArgs args, TVMRetValue* rv) {
 *   int value = args[0];
 *   *rv = value + 1;
 * }
 * // Expose the function as "AddOne"
 * TVM_DLL_EXPORT_PACKED_FUNC(AddOne, AddOne_);
 *
 * \endcode
 */
#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function)                                    \
  extern "C" {                                                                            \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle);                      \
  int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value,         \
                 int* out_type_code, void* resource_handle) {                             \
    try {                                                                                 \
      ::tvm::runtime::TVMArgs _tvm_packed_args_(args, type_code, num_args);               \
      ::tvm::runtime::TVMRetValue _tvm_packed_rv_;                                        \
      Function(_tvm_packed_args_, &_tvm_packed_rv_);                                      \
      _tvm_packed_rv_.MoveToCHost(out_value, out_type_code);                              \
      return 0;                                                                           \
    } catch (const ::std::exception& _except_) {                                          \
      TVMAPISetLastError(_except_.what());                                                \
      return -1;                                                                          \
    }                                                                                     \
  }                                                                                       \
  }
1147
/*!
 * \brief Export typed function as a PackedFunc
 *        that can be loaded by LibraryModule.
 *
 * The generated extern "C" symbol unpacks the C arguments according to the
 * signature of Function and stores the result in the C out-parameters. It
 * returns 0 on success. If a std::exception is thrown, it records the exception
 * message via TVMAPISetLastError and returns -1.
 *
 * All identifiers the macro introduces are prefixed with `tvm_export_`. This lets
 * Function be any expression, including a function that happens to be named `f`
 * or `rv`, without clashing with the macro's locals.
 *
 * \param ExportName The symbol name to be exported.
 * \param Function The typed function.
 * \note ExportName and Function must be different,
 *       see code examples below.
 *
 * \sa TypedPackedFunc
 *
 * \code
 *
 * int AddOne_(int x) {
 *   return x + 1;
 * }
 *
 * // Expose the function as "AddOne"
 * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);
 *
 * // Expose the function as "SubOne"
 * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) {
 *   return x - 1;
 * });
 *
 * // The following code will cause compilation error.
 * // Because the same Function and ExportName
 * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_);
 *
 * // The following code is OK, assuming the macro
 * // is in a different namespace from xyz
 * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_);
 *
 * \endcode
 */
#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                                        \
  extern "C" {                                                                                 \
  TVM_DLL int ExportName(TVMValue* tvm_export_args_, int* tvm_export_type_codes_,               \
                         int tvm_export_num_args_, TVMValue* tvm_export_out_value_,             \
                         int* tvm_export_out_type_code_, void* tvm_export_resource_handle_) {  \
    try {                                                                                      \
      auto tvm_export_fn_ = Function;                                                          \
      using tvm_export_ftype_ =                                                                \
          ::tvm::runtime::detail::function_signature<decltype(tvm_export_fn_)>::FType;         \
      ::tvm::runtime::TVMRetValue tvm_export_rv_;                                              \
      ::tvm::runtime::detail::unpack_call_by_signature<tvm_export_ftype_>::run(                \
          tvm_export_fn_,                                                                      \
          ::tvm::runtime::TVMArgs(tvm_export_args_, tvm_export_type_codes_,                    \
                                  tvm_export_num_args_),                                       \
          &tvm_export_rv_);                                                                    \
      tvm_export_rv_.MoveToCHost(tvm_export_out_value_, tvm_export_out_type_code_);            \
      return 0;                                                                                \
    } catch (const ::std::exception& _except_) {                                               \
      TVMAPISetLastError(_except_.what());                                                     \
      return -1;                                                                               \
    }                                                                                          \
  }                                                                                            \
  }
1201
1202inline TVMArgValue TVMArgs::operator[](int i) const {
1203 ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
1204 << " but request arg[" << i << "].";
1205 return TVMArgValue(values[i], type_codes[i]);
1206}
1207
1208inline int TVMArgs::size() const { return num_args; }
1209
1210template <class TPackedFuncSubObj>
1211void PackedFuncObj::Extractor<TPackedFuncSubObj>::Call(const PackedFuncObj* obj, TVMArgs args,
1212 TVMRetValue* rv) {
1213 (static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv);
1214}
1215
1216TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1217 (*f_call_packed_)(this, args, rv);
1218}
1219
1220TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1221 (static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv);
1222}
1223
1224// internal namespace
1225inline const char* ArgTypeCode2Str(int type_code) {
1226 switch (type_code) {
1227 case kDLInt:
1228 return "int";
1229 case kDLUInt:
1230 return "uint";
1231 case kDLFloat:
1232 return "float";
1233 case kTVMStr:
1234 return "str";
1235 case kTVMBytes:
1236 return "bytes";
1237 case kTVMOpaqueHandle:
1238 return "handle";
1239 case kTVMNullptr:
1240 return "NULL";
1241 case kTVMDLTensorHandle:
1242 return "ArrayHandle";
1243 case kTVMDataType:
1244 return "DLDataType";
1245 case kDLDevice:
1246 return "DLDevice";
1247 case kTVMPackedFuncHandle:
1248 return "FunctionHandle";
1249 case kTVMModuleHandle:
1250 return "ModuleHandle";
1251 case kTVMNDArrayHandle:
1252 return "NDArrayContainer";
1253 case kTVMObjectHandle:
1254 return "Object";
1255 case kTVMObjectRValueRefArg:
1256 return "ObjectRValueRefArg";
1257 default:
1258 LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
1259 }
1260}
1261
1262namespace detail {
1263
/*!
 * \brief Recursive helper for for_each: calls f(index, value) for each argument in order.
 * \tparam kDone True once every argument has been visited.
 * \tparam kIndex Position of the current (first remaining) argument.
 * \tparam F The visitor type.
 */
template <bool kDone, std::size_t kIndex, typename F>
struct for_each_dispatcher {
  template <typename THead, typename... TTail>
  static void run(const F& visitor, THead&& head, TTail&&... tail) {  // NOLINT(*)
    visitor(kIndex, std::forward<THead>(head));
    constexpr bool kLastStep = sizeof...(TTail) == 0;
    for_each_dispatcher<kLastStep, kIndex + 1, F>::run(visitor, std::forward<TTail>(tail)...);
  }
};

// Terminal case: no arguments left, so nothing to visit.
template <std::size_t kIndex, typename F>
struct for_each_dispatcher<true, kIndex, F> {
  static void run(const F& /*visitor*/) {}  // NOLINT(*)
};

/*!
 * \brief Call f(i, args_i) for every argument, where i is its zero-based position.
 */
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  using Dispatcher = for_each_dispatcher<sizeof...(Args) == 0, 0, F>;
  Dispatcher::run(f, std::forward<Args>(args)...);
}
1282
namespace parameter_pack {

/*!
 * \brief A parameter pack whose elements have been paired with their positions.
 * \tparam EnumArgs Item types. Each exposes `i` (the position) and `T` (the type).
 */
template <typename... EnumArgs>
struct EnumeratedParamPack {
  struct Invoke {
    /*!
     * \brief Call Functor<i, T>::F(extra_params...) once per element, in order.
     * \note extra_params are passed as lvalues to every call. They are deliberately
     *       not forwarded, because the same arguments are reused for each element.
     */
    template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
    static void F(ExtraParams&&... extra_params) {
      // A pack expansion inside a braced initializer is evaluated left to right.
      // The leading 0 keeps the array non-empty when the pack is empty.
      using TExpander = int[];
      (void)TExpander{
          0,
          (Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
      };
    }
  };
};

/*!
 * \brief Pair each type in Args with its zero-based index, yielding an EnumeratedParamPack.
 */
template <typename... Args>
struct EnumerateImpl {
 private:
  // Parameter names avoid a leading underscore. Names like `_T` are reserved for
  // the implementation, and `_T` is a macro in MSVC's <tchar.h>.
  template <size_t kIndex, typename TArg>
  struct Item {
    static const constexpr size_t i = kIndex;
    using T = TArg;
  };

  template <typename...>
  struct Zipper;

  template <std::size_t... id>
  struct Zipper<std::integer_sequence<std::size_t, id...>> {
    using T = EnumeratedParamPack<Item<id, Args>...>;
  };

 public:
  using T = typename Zipper<std::index_sequence_for<Args...>>::T;
};

template <typename... Args>
using Enumerate = typename EnumerateImpl<Args...>::T;

/*!
 * \brief Compile-time list of types supporting per-element dispatch.
 */
template <typename... Args>
struct ParamPack {
  /*!
   * \brief Invoke Functor<index, type>::F(extra_params...) for each type in Args.
   */
  template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
  static void InvokeWithoutArg(ExtraParams&&... extra_params) {
    Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
        std::forward<ExtraParams>(extra_params)...);
  }
};

}  // namespace parameter_pack
1333
1334/*!
1335 * \brief Template class to get function signature of a function or functor.
1336 * \tparam T The function/functor type.
1337 */
1338template <typename T>
1339struct func_signature_helper {
1340 using FType = void;
1341};
1342
1343template <typename T, typename R, typename... Args>
1344struct func_signature_helper<R (T::*)(Args...)> {
1345 using FType = R(Args...);
1346 using ParamType = parameter_pack::ParamPack<Args...>;
1347 using RetType = R;
1348 static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1349};
1350
1351template <typename T, typename R, typename... Args>
1352struct func_signature_helper<R (T::*)(Args...) const> {
1353 using FType = R(Args...);
1354 using ParamType = parameter_pack::ParamPack<Args...>;
1355 using RetType = R;
1356 static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1357};
1358
1359/*!
1360 * \brief Template class to get function signature of a function or functor.
1361 * \tparam T The function/functor type.
1362 */
1363template <typename T>
1364struct function_signature {
1365 using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
1366 using ParamType = typename func_signature_helper<decltype(&T::operator())>::ParamType;
1367 using RetType = typename func_signature_helper<decltype(&T::operator())>::RetType;
1368};
1369
1370// handle case of function.
1371template <typename R, typename... Args>
1372struct function_signature<R(Args...)> {
1373 using FType = R(Args...);
1374 using ParamType = parameter_pack::ParamPack<Args...>;
1375 using RetType = R;
1376 static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1377};
1378
1379// handle case of function ptr.
1380template <typename R, typename... Args>
1381struct function_signature<R (*)(Args...)> {
1382 using FType = R(Args...);
1383 using ParamType = detail::parameter_pack::ParamPack<Args...>;
1384 using RetType = R;
1385 static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1386};
1387
// Forward declaration. Type2Str<TypedPackedFunc<...>> below refers to
// SignaturePrinter before its full definition appears.
template <typename TSignature>
struct SignaturePrinter;
1390
1391namespace type2str {
1392
1393template <typename T>
1394struct TypeSimplifier;
1395
1396template <typename T>
1397struct Type2Str {
1398 template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1399 static std::string v() {
1400 return T::ContainerType::_type_key;
1401 }
1402};
1403template <>
1404struct Type2Str<int> {
1405 static std::string v() { return "int"; }
1406};
1407template <>
1408struct Type2Str<double> {
1409 static std::string v() { return "double"; }
1410};
1411template <>
1412struct Type2Str<int64_t> {
1413 static std::string v() { return "int64_t"; }
1414};
1415template <>
1416struct Type2Str<uint64_t> {
1417 static std::string v() { return "uint64_t"; }
1418};
1419template <>
1420struct Type2Str<bool> {
1421 static std::string v() { return "bool"; }
1422};
1423template <>
1424struct Type2Str<void> {
1425 static std::string v() { return "void"; }
1426};
1427template <>
1428struct Type2Str<std::basic_string<char>> {
1429 static std::string v() { return "basic_string<char>"; }
1430};
1431template <typename K, typename V>
1432struct Type2Str<Map<K, V>> {
1433 static std::string v() {
1434 return "Map<" + TypeSimplifier<K>::v() + ", " + TypeSimplifier<V>::v() + ">";
1435 }
1436};
1437template <>
1438struct Type2Str<DLDevice> {
1439 static std::string v() { return "DLDevice"; }
1440};
1441template <>
1442struct Type2Str<DLTensor> {
1443 static std::string v() { return "DLTensor"; }
1444};
1445template <>
1446struct Type2Str<DataType> {
1447 static std::string v() { return "DataType"; }
1448};
1449template <>
1450struct Type2Str<DLDataType> {
1451 static std::string v() { return "DLDataType"; }
1452};
1453template <>
1454struct Type2Str<TVMRetValue> {
1455 static std::string v() { return "TVMRetValue"; }
1456};
1457template <>
1458struct Type2Str<TVMArgValue> {
1459 static std::string v() { return "TVMArgValue"; }
1460};
1461template <typename FType>
1462struct Type2Str<TypedPackedFunc<FType>> {
1463 static std::string v() { return SignaturePrinter<function_signature<FType>>::F(); }
1464};
1465template <typename T>
1466struct Type2Str<Array<T>> {
1467 static std::string v() { return "Array<" + TypeSimplifier<T>::v() + ">"; }
1468};
1469
1470/*!
1471 * \brief Template class to remove const, pointer and reference of original type.
1472 * \tparam T The original type.
1473 */
1474template <typename T>
1475struct TypeSimplifier {
1476 static std::string v() {
1477 using U = typename std::remove_cv<
1478 typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1479 return (std::is_const<T>::value ? "const " : "") + Type2Str<U>::v() +
1480 (std::is_pointer<T>::value ? "*" : "") + (std::is_reference<T>::value ? "&" : "");
1481 }
1482};
1483
1484} // namespace type2str
1485
1486/*!
1487 * \brief Template class to generate static function outputting signature of a function or functor.
1488 * \tparam TSignature The function/functor signature type generated by `function_signature`.
1489 */
1490template <typename TSignature>
1491struct SignaturePrinter {
1492 using ParamType = typename TSignature::ParamType;
1493 using RetType = typename TSignature::RetType;
1494
1495 template <size_t i, typename TArgument>
1496 struct PrintParamType {
1497 static void F(std::ostream& os) {
1498 os << (i == 0 ? "" : ", ") << i << ": " << type2str::TypeSimplifier<TArgument>::v();
1499 }
1500 };
1501
1502 static std::string F() {
1503 std::ostringstream oss;
1504 oss << "(";
1505 ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1506 oss << ") -> " << type2str::TypeSimplifier<RetType>::v();
1507 return oss.str();
1508 }
1509};
1510} // namespace detail
1511
1512/* \brief argument settter to PackedFunc */
1513class TVMArgsSetter {
1514 public:
1515 TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
1516 // setters for POD types
1517 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1518 TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
1519 values_[i].v_int64 = static_cast<int64_t>(value);
1520 type_codes_[i] = kDLInt;
1521 }
1522 TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
1523 values_[i].v_int64 = static_cast<int64_t>(value);
1524 ICHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1525 type_codes_[i] = kDLInt;
1526 }
1527 TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
1528 values_[i].v_float64 = value;
1529 type_codes_[i] = kDLFloat;
1530 }
1531 TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
1532 values_[i].v_handle = value;
1533 type_codes_[i] = kTVMNullptr;
1534 }
1535 TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
1536 values_[i] = value.value_;
1537 type_codes_[i] = value.type_code_;
1538 }
1539 TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
1540 values_[i].v_handle = value;
1541 type_codes_[i] = kTVMOpaqueHandle;
1542 }
1543 TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
1544 values_[i].v_handle = value;
1545 type_codes_[i] = kTVMDLTensorHandle;
1546 }
1547 TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const {
1548 values_[i].v_device = value;
1549 type_codes_[i] = kDLDevice;
1550 }
1551 TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
1552 values_[i].v_type = value;
1553 type_codes_[i] = kTVMDataType;
1554 }
1555 TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
1556 operator()(i, dtype.operator DLDataType());
1557 }
1558 TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
1559 values_[i].v_str = value;
1560 type_codes_[i] = kTVMStr;
1561 }
1562 // setters for container types
1563 TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
1564 values_[i].v_str = value.c_str();
1565 type_codes_[i] = kTVMStr;
1566 }
1567 TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
1568 values_[i].v_handle = const_cast<TVMByteArray*>(&value);
1569 type_codes_[i] = kTVMBytes;
1570 }
1571 template <typename FType>
1572 TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
1573 operator()(i, value.packed());
1574 }
1575 void operator()(size_t i, const TVMRetValue& value) const {
1576 if (value.type_code() == kTVMStr) {
1577 values_[i].v_str = value.ptr<std::string>()->c_str();
1578 type_codes_[i] = kTVMStr;
1579 } else {
1580 ICHECK_NE(value.type_code(), kTVMBytes) << "not handled.";
1581 values_[i] = value.value_;
1582 type_codes_[i] = value.type_code();
1583 }
1584 }
1585 // ObjectRef handling
1586 template <typename TObjectRef,
1587 typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1588 TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
1589 this->SetObject(i, value);
1590 }
1591
1592 template <typename TObjectRef,
1593 typename = typename std::enable_if<std::is_base_of<
1594 ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type>
1595 TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
1596 this->SetObject(i, std::forward<TObjectRef>(value));
1597 }
1598
1599 private:
1600 template <typename TObjectRef>
1601 inline void SetObject(size_t i, TObjectRef&& value) const;
1602 /*! \brief The values fields */
1603 TVMValue* values_;
1604 /*! \brief The type code fields */
1605 int* type_codes_;
1606};
1607
1608template <typename... Args>
1609inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
1610 const int kNumArgs = sizeof...(Args);
1611 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1612 TVMValue values[kArraySize];
1613 int type_codes[kArraySize];
1614 detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1615 TVMRetValue rv;
1616 (static_cast<PackedFuncObj*>(data_.get()))
1617 ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
1618 return rv;
1619}
1620
1621namespace detail {
1622template <typename R, int nleft, int index, typename F>
1623struct unpack_call_dispatcher {
1624 template <typename... Args>
1625 TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1626 const TVMArgs& args_pack, TVMRetValue* rv,
1627 Args&&... unpacked_args) {
1628 // construct a movable argument value
1629 // which allows potential move of argument to the input of F.
1630 unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1631 optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1632 TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
1633 optional_name, f_sig));
1634 }
1635};
1636
1637template <typename R, int index, typename F>
1638struct unpack_call_dispatcher<R, 0, index, F> {
1639 template <typename... Args>
1640 TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1641 const TVMArgs& args_pack, TVMRetValue* rv,
1642 Args&&... unpacked_args) {
1643 using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1644 if (std::is_same<RetType, R>::value) {
1645 *rv = f(std::forward<Args>(unpacked_args)...);
1646 } else {
1647 *rv = R(f(std::forward<Args>(unpacked_args)...));
1648 }
1649 }
1650};
1651
1652template <int index, typename F>
1653struct unpack_call_dispatcher<void, 0, index, F> {
1654 template <typename... Args>
1655 TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1656 const TVMArgs& args_pack, TVMRetValue* rv,
1657 Args&&... unpacked_args) {
1658 f(std::forward<Args>(unpacked_args)...);
1659 }
1660};
1661
1662template <typename R, int nargs, typename F>
1663TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f,
1664 const TVMArgs& args, TVMRetValue* rv) {
1665 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1666 CHECK_EQ(nargs, args.size()) << "Function "
1667 << (optional_name == nullptr ? "<anonymous>" : *optional_name)
1668 << (f_sig == nullptr ? "" : (*f_sig)()) << " expects " << nargs
1669 << " arguments but " << args.size() << " were provided";
1670 unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
1671}
1672
1673template <typename FType>
1674struct unpack_call_by_signature {};
1675
1676template <typename R, typename... Args>
1677struct unpack_call_by_signature<R(Args...)> {
1678 template <typename F>
1679 TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
1680 unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv);
1681 }
1682};
1683
1684template <typename R, typename... Args>
1685TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) {
1686 return R(pf(std::forward<Args>(args)...));
1687}
1688
1689template <typename R>
1690struct typed_packed_call_dispatcher {
1691 template <typename... Args>
1692 TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) {
1693 return pf(std::forward<Args>(args)...);
1694 }
1695};
1696
1697template <>
1698struct typed_packed_call_dispatcher<void> {
1699 template <typename... Args>
1700 TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) {
1701 pf(std::forward<Args>(args)...);
1702 }
1703};
1704} // namespace detail
1705
1706template <typename R, typename... Args>
1707TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
1708
1709template <typename R, typename... Args>
1710TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value)
1711 : packed_(value.operator PackedFunc()) {}
1712
1713template <typename R, typename... Args>
1714TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
1715 : packed_(value.operator PackedFunc()) {}
1716
1717template <typename R, typename... Args>
1718TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValueWithContext_&& value)
1719 : packed_(value.operator PackedFunc()) {}
1720
1721template <typename R, typename... Args>
1722template <typename FType>
1723inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1724 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1725 packed_ = PackedFunc([flambda, name, f_sig](const TVMArgs& args, TVMRetValue* rv) {
1726 if (args.size() != sizeof...(Args)) {
1727 LOG(FATAL) << "Function " << name << (f_sig == nullptr ? "" : (*f_sig)()) << " expects "
1728 << sizeof...(Args) << " arguments, but " << args.size() << " were provided.";
1729 }
1730 detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv);
1731 });
1732}
1733
1734template <typename R, typename... Args>
1735template <typename FType>
1736inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
1737 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1738 packed_ = PackedFunc([flambda, f_sig](const TVMArgs& args, TVMRetValue* rv) {
1739 if (args.size() != sizeof...(Args)) {
1740 LOG(FATAL) << "Function <anonymous> " << (*f_sig)() << " expects " << sizeof...(Args)
1741 << " arguments, but " << args.size() << " were provided.";
1742 }
1743 detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv);
1744 });
1745}
1746
1747template <typename R, typename... Args>
1748TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
1749 return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1750}
1751
1752// ObjectRef related conversion handling
1753// Object can have three possible type codes:
1754// kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle
1755//
1756// We use type traits to eliminate un-necessary checks.
1757template <typename T>
1758inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
1759 using ContainerType = typename std::remove_reference<T>::type::ContainerType;
1760 if (value.defined()) {
1761 Object* ptr = value.data_.data_;
1762 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1763 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1764 ptr->IsInstance<NDArray::ContainerType>())) {
1765 values_[i].v_handle = NDArray::FFIGetHandle(value);
1766 type_codes_[i] = kTVMNDArrayHandle;
1767 } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1768 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1769 ptr->IsInstance<Module::ContainerType>())) {
1770 values_[i].v_handle = ptr;
1771 type_codes_[i] = kTVMModuleHandle;
1772 } else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1773 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1774 ptr->IsInstance<PackedFunc::ContainerType>())) {
1775 values_[i].v_handle = ptr;
1776 type_codes_[i] = kTVMPackedFuncHandle;
1777 } else if (std::is_rvalue_reference<decltype(value)>::value) {
1778 values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
1779 type_codes_[i] = kTVMObjectRValueRefArg;
1780 } else {
1781 values_[i].v_handle = value.data_.data_;
1782 type_codes_[i] = kTVMObjectHandle;
1783 }
1784 } else {
1785 type_codes_[i] = kTVMNullptr;
1786 values_[i].v_handle = nullptr;
1787 }
1788}
1789
1790template <typename TObjectRef, typename>
1791inline bool TVMPODValue_::IsObjectRef() const {
1792 using ContainerType = typename TObjectRef::ContainerType;
1793 // NOTE: the following code can be optimized by constant folding.
1794 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1795 return type_code_ == kTVMNDArrayHandle &&
1796 TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle))
1797 ->IsInstance<ContainerType>();
1798 }
1799 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1800 return type_code_ == kTVMModuleHandle &&
1801 static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
1802 }
1803 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1804 return type_code_ == kTVMPackedFuncHandle &&
1805 static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
1806 }
1807 // NOTE: we don't pass NDArray and runtime::Module as RValue ref.
1808 if (type_code_ == kTVMObjectRValueRefArg) {
1809 return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
1810 }
1811 return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1812 type_code_ == kTVMNDArrayHandle) ||
1813 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1814 type_code_ == kTVMModuleHandle) ||
1815 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1816 type_code_ == kTVMPackedFuncHandle) ||
1817 (type_code_ == kTVMObjectHandle &&
1818 ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
1819}
1820
1821template <typename TObjectRef>
1822inline TObjectRef TVMPODValue_::AsObjectRef() const {
1823 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1824 "Conversion only works for ObjectRef");
1825 using ContainerType = typename TObjectRef::ContainerType;
1826
1827 if (type_code_ == kTVMNullptr) {
1828 CHECK(TObjectRef::_type_is_nullable)
1829 << "Expect a not null value of " << ContainerType::_type_key;
1830 return TObjectRef(ObjectPtr<Object>(nullptr));
1831 }
1832 // NOTE: the following code can be optimized by constant folding.
1833 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1834 // Casting to a sub-class of NDArray
1835 TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
1836 ObjectPtr<Object> data =
1837 NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
1838 CHECK(data->IsInstance<ContainerType>())
1839 << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
1840 return TObjectRef(data);
1841 }
1842 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1843 // Casting to a sub-class of Module
1844 TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
1845 ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
1846 CHECK(data->IsInstance<ContainerType>())
1847 << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
1848 return TObjectRef(data);
1849 }
1850 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1851 // Casting to a sub-class of PackedFunc
1852 TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
1853 ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
1854 CHECK(data->IsInstance<ContainerType>())
1855 << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
1856 return TObjectRef(data);
1857 }
1858 if (type_code_ == kTVMObjectHandle) {
1859 // normal object type check.
1860 Object* ptr = static_cast<Object*>(value_.v_handle);
1861 Optional<String> checked_type = ObjectTypeChecker<TObjectRef>::CheckAndGetMismatch(ptr);
1862 ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
1863 << ", but got " << checked_type.value();
1864 return TObjectRef(GetObjectPtr<Object>(ptr));
1865 } else if (type_code_ == kTVMObjectRValueRefArg) {
1866 Object* ptr = *static_cast<Object**>(value_.v_handle);
1867 Optional<String> checked_type = ObjectTypeChecker<TObjectRef>::CheckAndGetMismatch(ptr);
1868 ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
1869 << ", but got " << checked_type.value();
1870 return TObjectRef(GetObjectPtr<Object>(ptr));
1871 } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1872 type_code_ == kTVMNDArrayHandle) {
1873 // Casting to a base class that NDArray can sub-class
1874 ObjectPtr<Object> data =
1875 NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
1876 return TObjectRef(data);
1877 } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1878 type_code_ == kTVMModuleHandle) {
1879 // Casting to a base class that Module can sub-class
1880 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1881 } else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1882 type_code_ == kTVMPackedFuncHandle) {
1883 // Casting to a base class that PackedFunc can sub-class
1884 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1885 } else {
1886 TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
1887 return TObjectRef(ObjectPtr<Object>(nullptr));
1888 }
1889}
1890
1891template <typename TObjectRef, typename>
1892inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
1893 using ContainerType = typename TObjectRef::ContainerType;
1894 const Object* ptr = other.get();
1895 if (ptr != nullptr) {
1896 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1897 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1898 ptr->IsInstance<NDArray::ContainerType>())) {
1899 return operator=(NDArray(std::move(other.data_)));
1900 }
1901 if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1902 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1903 ptr->IsInstance<Module::ContainerType>())) {
1904 return operator=(Module(std::move(other.data_)));
1905 }
1906 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1907 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1908 ptr->IsInstance<PackedFunc::ContainerType>())) {
1909 return operator=(PackedFunc(std::move(other.data_)));
1910 }
1911 SwitchToObject(kTVMObjectHandle, std::move(other.data_));
1912 } else {
1913 SwitchToPOD(kTVMNullptr);
1914 value_.v_handle = nullptr;
1915 }
1916 return *this;
1917}
1918
1919template <typename T, typename>
1920inline TVMArgValue::operator T() const {
1921 return PackedFuncValueConverter<T>::From(*this);
1922}
1923
1924template <typename T, typename>
1925inline TVMMovableArgValue_::operator T() const {
1926 if (type_code_ == kTVMObjectRValueRefArg) {
1927 auto** ref = static_cast<Object**>(value_.v_handle);
1928 if (ObjectTypeChecker<T>::Check(*ref)) {
1929 return T(ObjectPtr<Object>::MoveFromRValueRefArg(ref));
1930 }
1931 }
1932 // fallback
1933 return PackedFuncValueConverter<T>::From(AsArgValue());
1934}
1935
1936template <typename T, typename>
1937inline TVMRetValue::operator T() const {
1938 return PackedFuncValueConverter<T>::From(*this);
1939}
1940
1941inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
1942 return (*this)->GetFunction(name, query_imports);
1943}
1944
1945// specializations of PackedFuncValueConverter
1946template <>
1947struct PackedFuncValueConverter<::tvm::runtime::String> {
1948 static String From(const TVMArgValue& val) {
1949 if (val.IsObjectRef<tvm::runtime::String>()) {
1950 return val.AsObjectRef<tvm::runtime::String>();
1951 } else {
1952 return tvm::runtime::String(val.operator std::string());
1953 }
1954 }
1955
1956 static String From(const TVMRetValue& val) {
1957 if (val.IsObjectRef<tvm::runtime::String>()) {
1958 return val.AsObjectRef<tvm::runtime::String>();
1959 } else {
1960 return tvm::runtime::String(val.operator std::string());
1961 }
1962 }
1963};
1964
1965template <typename T>
1966struct PackedFuncValueConverter<Optional<T>> {
1967 static Optional<T> From(const TVMArgValue& val) {
1968 if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
1969 return PackedFuncValueConverter<T>::From(val);
1970 }
1971 static Optional<T> From(const TVMRetValue& val) {
1972 if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
1973 return PackedFuncValueConverter<T>::From(val);
1974 }
1975};
1976
1977inline bool String::CanConvertFrom(const TVMArgValue& val) {
1978 return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
1979}
1980
1981inline TVMArgValue::operator DLDataType() const {
1982 if (String::CanConvertFrom(*this)) {
1983 return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
1984 }
1985 // None type
1986 if (type_code_ == kTVMNullptr) {
1987 DLDataType t;
1988 t.code = kTVMOpaqueHandle;
1989 t.bits = 0;
1990 t.lanes = 0;
1991 return t;
1992 }
1993 TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
1994 return value_.v_type;
1995}
1996
1997inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
1998
1999} // namespace runtime
2000} // namespace tvm
2001#endif // TVM_RUNTIME_PACKED_FUNC_H_
2002