1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/expr.h
22 * \brief Base expr nodes in TVM.
23 */
24#ifndef TVM_IR_EXPR_H_
25#define TVM_IR_EXPR_H_
26
27#include <tvm/ir/source_map.h>
28#include <tvm/ir/type.h>
29#include <tvm/node/node.h>
30#include <tvm/runtime/container/string.h>
31#include <tvm/runtime/object.h>
32
33#include <algorithm>
34#include <limits>
35#include <string>
36#include <type_traits>
37
38namespace tvm {
39
40using tvm::runtime::String;
41
// Forward-declare VirtualDevice to avoid circular includes. In code it is only
// needed here as the return type of RelayExprNode::virtual_device().
class VirtualDevice;
44
/*!
 * \brief Base type of all the expressions.
 *
 * PrimExprNode and RelayExprNode in this header both derive from it.
 * \sa BaseExpr
 */
class BaseExprNode : public Object {
 public:
  /*!
   * \brief Span that points to the original source code.
   * Reserved debug information.
   *
   * \note Declared mutable so it can be assigned even through a const node.
   */
  mutable Span span;

  static constexpr const char* _type_key = "BaseExpr";
  // Declare that this node hierarchy supports structural equality and hashing.
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  // NOTE(review): presumably the number of type-index slots reserved for
  // subclasses in the runtime object type hierarchy; see runtime/object.h.
  static constexpr const uint32_t _type_child_slots = 62;
  TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
63
/*!
 * \brief Managed reference to BaseExprNode.
 *
 * Common base of the PrimExpr and RelayExpr reference types.
 * \sa BaseExprNode
 */
class BaseExpr : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
};
72
/*!
 * \brief Base node of all primitive expressions.
 *
 * A primitive expression deals with low-level
 * POD data types and handles without
 * doing life-cycle management for objects.
 *
 * PrimExpr is used in the low-level code
 * optimizations and integer analysis.
 *
 * \sa PrimExpr
 */
class PrimExprNode : public BaseExprNode {
 public:
  /*!
   * \brief The runtime data type of the primitive expression.
   *
   * runtime::DataType(dtype) provides coarse grained type information
   * during compile time and runtime. It is eagerly built in
   * PrimExpr expression construction and can be used for
   * quick type checking.
   *
   * dtype is sufficient to decide the Type of the PrimExpr
   * when it corresponds to POD value types such as i32.
   *
   * When dtype is DataType::Handle(), the expression could correspond to
   * a more fine-grained Type, and we can get the type by running lazy type inference.
   */
  DataType dtype;

  TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

  static constexpr const char* _type_key = "PrimExpr";
  // NOTE(review): presumably the number of type-index slots reserved for subclasses.
  static constexpr const uint32_t _type_child_slots = 38;
  TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};
109
110/*!
111 * \brief Reference to PrimExprNode.
112 * \sa PrimExprNode
113 */
114class PrimExpr : public BaseExpr {
115 public:
116 /*!
117 * \brief construct from integer.
118 * \param value The value to be constructed.
119 */
120 TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
121 /*!
122 * \brief construct from float.
123 * \param value The value to be constructed.
124 */
125 TVM_DLL PrimExpr(float value); // NOLINT(*)
126
127 /*! \return the data type of this expression. */
128 DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; }
129
130 TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
131
132 private:
133 // Internal function for conversion.
134 friend struct runtime::PackedFuncValueConverter<PrimExpr>;
135 TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
136};
137
/*!
 * \brief Addition operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);

/*!
 * \brief Subtraction operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);

/*!
 * \brief Unary negation.
 *
 * \param a The input expression.
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator-(PrimExpr a);

/*!
 * \brief Multiplication operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);

/*!
 * \brief Division operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);

/*!
 * \brief Left shift operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);

/*!
 * \brief Right shift operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);

/*!
 * \brief Greater-than comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression (a PrimExpr, not a C++ bool).
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);

/*!
 * \brief Greater-than-or-equal comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression (a PrimExpr, not a C++ bool).
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);

/*!
 * \brief Less-than comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression (a PrimExpr, not a C++ bool).
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);

/*!
 * \brief Less-than-or-equal comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression (a PrimExpr, not a C++ bool).
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);

/*!
 * \brief Equality comparison.
 *
 * This builds an equality expression; it does not compare the two
 * expression trees in C++ (the return type is PrimExpr, not bool).
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);

/*!
 * \brief Inequality comparison.
 *
 * Like operator==, this builds an expression rather than comparing the
 * expression trees in C++.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
279
/*!
 * \brief Logical and.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);

/*!
 * \brief Logical or.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);

/*!
 * \brief Logical not.
 *
 * \param a The input expression.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator!(PrimExpr a);

/*!
 * \brief Take the bitwise and of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);

/*!
 * \brief Take the bitwise or of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);

/*!
 * \brief Take the bitwise xor of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);

/*!
 * \brief Take the bitwise negation (complement) of a single value.
 *
 * \param a the input expression.
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator~(PrimExpr a);
351
/*!
 * \brief Base node of all non-primitive expressions.
 *
 * RelayExpr supports tensor types, functions and ADT as
 * first class citizens. The life-cycle of the corresponding
 * objects is implicitly managed by the language.
 *
 * \sa RelayExpr
 */
class RelayExprNode : public BaseExprNode {
 public:
  /*!
   * \brief Stores the result of type inference (type checking).
   *
   * \note This can be undefined before type inference.
   * This value is discarded during serialization.
   */
  mutable Type checked_type_ = Type(nullptr);
  /*!
   * \return The checked_type.
   * \note Fails an ICHECK if checked_type_ has not been populated yet.
   */
  inline const Type& checked_type() const;
  /*!
   * \brief Check if the inferred (checked) type of the Expr
   * is backed by a TTypeNode and return it.
   *
   * \note This function will throw an error if type inference has not
   * populated the type yet, or if the type node of this Expr is not a TTypeNode.
   *
   * \return The corresponding TTypeNode pointer.
   * \tparam TTypeNode The specific TypeNode we look for.
   */
  template <typename TTypeNode>
  inline const TTypeNode* type_as() const;

  /*!
   * \brief The virtual device (VirtualDevice) for this node (the result of device planning).
   * For first-order expressions (non functions), this describes where the result of evaluating the
   * expression should be stored. Note that currently, all composite first-order values (tuples,
   * references, ADTs) must be stored on the same virtual device. This means that it is not possible
   * to store two tuple fields on different devices, so we only need one virtual device for these
   * types.
   *
   * For expressions that have the function type, the virtual device describes where the result of
   * the call to the function or closure is stored (instead of where the function itself is stored).
   * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
   * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
   * device of body. For more details, see the documentation in
   * src/relay/transforms/device_planner.cc.
   *
   * The VirtualDevice's Target field describes how the body of the function should be compiled.
   *
   * Set to VirtualDevice::FullyUnconstrained by default.
   *
   * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
   * include.
   */
  mutable ObjectRef virtual_device_;

  /*!
   * \return The virtual device (VirtualDevice).
   * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained().
   * Note that for function types, the virtual device is the device where the result of a
   * call to the function is stored, not where the function itself lives.
   * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
   * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
   * device of body.
   *
   * See the documentation of the virtual_device_ field (above) for more details.
   */
  VirtualDevice virtual_device() const;

  static constexpr const char* _type_key = "RelayExpr";
  // NOTE(review): presumably the number of type-index slots reserved for subclasses.
  static constexpr const uint32_t _type_child_slots = 22;
  TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};
428
/*!
 * \brief Managed reference to RelayExprNode.
 *
 * Base of the non-primitive expression reference types (e.g. GlobalVar).
 * \sa RelayExprNode
 */
class RelayExpr : public BaseExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode);
};
437
class GlobalVar;
/*!
 * \brief Global variable that lives in the top-level module.
 *
 * A GlobalVar only refers to function definitions.
 * This is used to enable recursive calls between functions.
 *
 * \sa GlobalVar
 */
class GlobalVarNode : public RelayExprNode {
 public:
  /*! \brief The name of the variable, this only acts as a hint. */
  String name_hint;

  // Reflection: exposes this node's own field and the inherited virtual_device_,
  // span and checked_type_ fields to attribute visitors under the given keys.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name_hint", &name_hint);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Structural equality: requires equal names AND a successful free-variable
  // match via the reducer's FreeVarEqualImpl.
  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
    // name matters for global var.
    return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
  }

  // Structural hash: mirrors SEqualReduce (name, then free-variable identity).
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name_hint);
    hash_reduce.FreeVarHashImpl(this);
  }

  static constexpr const char* _type_key = "GlobalVar";
  TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
472
/*!
 * \brief Managed reference to GlobalVarNode.
 * \sa GlobalVarNode
 */
class GlobalVar : public RelayExpr {
 public:
  /*!
   * \brief Construct a global variable.
   * \param name_hint The name of the variable; it only acts as a hint.
   * \param type Optional type of the variable. NOTE(review): presumably stored as the
   *        initial checked type; confirm against the definition in expr.cc.
   * \param span The location of this object in the source code.
   */
  TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});

  TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
};
484
// PrimExprs that are useful as runtime containers.
//
/*!
 * \brief Constant integer literals in the program.
 * \sa IntImm
 */
class IntImmNode : public PrimExprNode {
 public:
  /*! \brief The internal value, always stored as int64_t regardless of dtype. */
  int64_t value;

  // Reflection: exposes dtype, value and span to attribute visitors.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("dtype", &dtype);
    v->Visit("value", &value);
    v->Visit("span", &span);
  }

  // Structural equality compares dtype and value only; span is ignored.
  bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
    return equal(dtype, other->dtype) && equal(value, other->value);
  }

  // Structural hash over the same fields as SEqualReduce (dtype, value).
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(dtype);
    hash_reduce(value);
  }

  static constexpr const char* _type_key = "IntImm";
  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
514
/*!
 * \brief Managed reference class to IntImmNode.
 *
 * \sa IntImmNode
 */
class IntImm : public PrimExpr {
 public:
  /*!
   * \brief Constructor.
   * \param dtype The data type of the value.
   * \param value The internal value.
   * \param span The location of this object in the source code.
   */
  TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};
533
/*!
 * \brief Constant floating point literals in the program.
 * \sa FloatImm
 */
class FloatImmNode : public PrimExprNode {
 public:
  /*! \brief The constant value content, always stored as double regardless of dtype. */
  double value;

  // Reflection: exposes dtype, value and span to attribute visitors.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("dtype", &dtype);
    v->Visit("value", &value);
    v->Visit("span", &span);
  }

  // Structural equality compares dtype and value only; span is ignored.
  bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
    return equal(dtype, other->dtype) && equal(value, other->value);
  }

  // Structural hash over the same fields as SEqualReduce (dtype, value).
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(dtype);
    hash_reduce(value);
  }

  static constexpr const char* _type_key = "FloatImm";
  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
561
/*!
 * \brief Managed reference class to FloatImmNode.
 *
 * \sa FloatImmNode
 */
class FloatImm : public PrimExpr {
 public:
  /*!
   * \brief Constructor.
   * \param dtype The data type of the value.
   * \param value The internal value.
   * \param span The location in the source code.
   */
  TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};
580
581/*!
582 * \brief Boolean constant.
583 *
584 * This reference type is useful to add additional compile-time
585 * type checks and helper functions for Integer equal comparisons.
586 */
587class Bool : public IntImm {
588 public:
589 explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {}
590 Bool operator!() const { return Bool((*this)->value == 0); }
591 operator bool() const { return (*this)->value != 0; }
592
593 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
594};
595
596// Overload operators to make sure we have the most fine grained types.
597inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
598inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
599inline Bool operator||(const Bool& a, const Bool& b) {
600 return Bool(a.operator bool() || b.operator bool());
601}
602inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
603inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
604inline Bool operator&&(const Bool& a, const Bool& b) {
605 return Bool(a.operator bool() && b.operator bool());
606}
607
608inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; }
609inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); }
610inline bool operator==(const Bool& a, const Bool& b) {
611 return a.operator bool() == b.operator bool();
612}
613
614/*!
615 * \brief Container of constant int that adds more constructors.
616 *
617 * This is used to store and automate type check
618 * attributes that must be constant integer.
619 *
620 * \sa IntImm
621 */
622class Integer : public IntImm {
623 public:
624 Integer() {}
625 /*!
626 * \brief constructor from node.
627 */
628 explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
629 /*!
630 * \brief Construct integer from int value.
631 */
632 Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*)
633 /*!
634 * \brief Construct integer from int imm.
635 * \param other The other value.
636 */
637 Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
638 /*!
639 * \brief Constructor from enum
640 * \tparam Enum The enum type.
641 * \param value The enum value.
642 */
643 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
644 explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
645 static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
646 "declare enum to be enum int to use visitor");
647 }
648 /*!
649 * \brief Assign an expression to integer.
650 * \param other another expression.
651 */
652 Integer& operator=(const IntImm& other) {
653 data_ = ObjectRef::GetDataPtr<Object>(other);
654 return *this;
655 }
656 /*!
657 * \brief convert to int64_t
658 */
659 int64_t IntValue() const {
660 ICHECK(data_ != nullptr) << " Trying to reference a null Integer";
661 return (*this)->value;
662 }
663 // comparators
664 Bool operator==(int other) const {
665 if (data_ == nullptr) return Bool(false);
666 return Bool((*this)->value == other);
667 }
668 Bool operator!=(int other) const { return !(*this == other); }
669 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
670 Bool operator==(Enum other) const {
671 return *this == static_cast<int>(other);
672 }
673 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
674 Bool operator!=(Enum other) const {
675 return *this != static_cast<int>(other);
676 }
677};
678
679/*! \brief range over one dimension */
680class RangeNode : public Object {
681 public:
682 /*! \brief beginning of the node */
683 PrimExpr min;
684 /*! \brief the extend of range */
685 PrimExpr extent;
686 /*! \brief the location of this range in the source */
687 mutable Span span;
688 /*! \brief constructor */
689 RangeNode() {}
690 RangeNode(PrimExpr min, PrimExpr extent, Span span = Span())
691 : min(min), extent(extent), span(span) {}
692
693 void VisitAttrs(AttrVisitor* v) {
694 v->Visit("min", &min);
695 v->Visit("extent", &extent);
696 v->Visit("span", &span);
697 }
698
699 bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
700 return equal(min, other->min) && equal(extent, other->extent);
701 }
702
703 void SHashReduce(SHashReducer hash_reduce) const {
704 hash_reduce(min);
705 hash_reduce(extent);
706 }
707
708 static constexpr const char* _type_key = "Range";
709 static constexpr const bool _type_has_method_sequal_reduce = true;
710 static constexpr const bool _type_has_method_shash_reduce = true;
711 TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
712};
713
/*! \brief Managed reference to RangeNode. */
class Range : public ObjectRef {
 public:
  /*!
   * \brief Construct a range from its begin and end.
   * \param begin The begin of the range.
   * \param end The end of the range.
   * \param span The location of the Range in the source.
   */
  TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
  /*!
   * \brief Construct a new range from its minimum and extent.
   *
   * This is a named factory rather than a constructor on purpose: a
   * two-argument Range(x, y) conventionally means range(begin, end), so a
   * (min, extent) constructor of the same shape would be misleading.
   *
   * \param min The minimum of the range.
   * \param extent The extent of the range.
   * \param span The location of the Range in the source.
   * \return The constructed range.
   */
  static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span());
  // declare range.
  TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
738
// Implementations of the inline RelayExprNode accessors.

// Returns the inferred type; fails an ICHECK (reporting this expression) if the
// type checker has not populated checked_type_ yet.
inline const Type& RelayExprNode::checked_type() const {
  ICHECK(checked_type_.defined()) << "internal error: the type checker has "
                                  << "not populated the checked_type "
                                  << "field for " << GetRef<RelayExpr>(this);
  return this->checked_type_;
}

// Returns checked_type_ downcast to TTypeNode. TTypeNode must derive from
// TypeNode (enforced at compile time). At runtime the first ICHECK fails if type
// inference has not run, and the second fails if the inferred type is not a TTypeNode.
template <typename TTypeNode>
inline const TTypeNode* RelayExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  ICHECK(checked_type_.defined())
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
                          << checked_type_->GetTypeKey();
  return node;
}
758
759} // namespace tvm
760
761namespace tvm {
762namespace runtime {
763// common rule for RetValue and ArgValue
764template <>
765struct PackedFuncValueConverter<PrimExpr> {
766 static PrimExpr From(const TVMPODValue_& val) {
767 if (val.type_code() == kTVMNullptr) {
768 return PrimExpr(ObjectPtr<Object>(nullptr));
769 }
770 if (val.type_code() == kDLInt) {
771 int64_t value = val.operator int64_t();
772 if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
773 return IntImm(runtime::DataType::Int(64), value);
774 }
775 return IntImm(runtime::DataType::Int(32), val.operator int());
776 }
777 if (val.type_code() == kDLFloat) {
778 return FloatImm(runtime::DataType::Float(32), val.operator double());
779 }
780
781 return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
782 }
783};
784
785template <>
786struct PackedFuncValueConverter<tvm::Integer> {
787 static tvm::Integer From(const TVMPODValue_& val) {
788 if (val.type_code() == kTVMNullptr) {
789 return Integer(ObjectPtr<Object>(nullptr));
790 }
791 if (val.type_code() == kTVMArgInt) {
792 return Integer(val.operator int());
793 }
794 return val.AsObjectRef<tvm::Integer>();
795 }
796};
797
798template <>
799struct PackedFuncValueConverter<tvm::Bool> {
800 static tvm::Bool From(const TVMPODValue_& val) {
801 if (val.type_code() == kTVMNullptr) {
802 return Bool(ObjectPtr<Object>(nullptr));
803 }
804 if (val.type_code() == kTVMArgInt) {
805 int v = val.operator int();
806 ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
807 return Bool(static_cast<bool>(v));
808 }
809 return val.AsObjectRef<tvm::Bool>();
810 }
811};
812
813} // namespace runtime
814} // namespace tvm
815#endif // TVM_IR_EXPR_H_
816