1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/expr.h |
22 | * \brief Base expr nodes in TVM. |
23 | */ |
24 | #ifndef TVM_IR_EXPR_H_ |
25 | #define TVM_IR_EXPR_H_ |
26 | |
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/node/node.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>

#include <algorithm>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>
37 | |
38 | namespace tvm { |
39 | |
40 | using tvm::runtime::String; |
41 | |
42 | // Forward-declare VirtualDevice to avoid circular imports. |
43 | class VirtualDevice; |
44 | |
45 | /*! |
46 | * \brief Base type of all the expressions. |
47 | * \sa Expr |
48 | */ |
49 | class BaseExprNode : public Object { |
50 | public: |
51 | /*! |
52 | * \brief Span that points to the original source code. |
53 | * Reserved debug information. |
54 | */ |
55 | mutable Span span; |
56 | |
57 | static constexpr const char* _type_key = "BaseExpr" ; |
58 | static constexpr const bool _type_has_method_sequal_reduce = true; |
59 | static constexpr const bool _type_has_method_shash_reduce = true; |
60 | static constexpr const uint32_t _type_child_slots = 62; |
61 | TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); |
62 | }; |
63 | |
64 | /*! |
65 | * \brief Managed reference to BaseExprNode. |
66 | * \sa BaseExprNode |
67 | */ |
68 | class BaseExpr : public ObjectRef { |
69 | public: |
70 | TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode); |
71 | }; |
72 | |
73 | /*! |
74 | * \brief Base node of all primitive expressions. |
75 | * |
76 | * A primitive expression deals with low-level |
77 | * POD data types and handles without |
78 | * doing life-cycle management for objects. |
79 | * |
80 | * PrimExpr is used in the low-level code |
81 | * optimizations and integer analysis. |
82 | * |
83 | * \sa PrimExpr |
84 | */ |
85 | class PrimExprNode : public BaseExprNode { |
86 | public: |
87 | /*! |
88 | * \brief The runtime data type of the primitive expression. |
89 | * |
90 | * runtime::DataType(dtype) provides coarse grained type information |
91 | * during compile time and runtime. It is eagerly built in |
92 | * PrimExpr expression construction and can be used for |
93 | * quick type checking. |
94 | * |
95 | * dtype is sufficient to decide the Type of the PrimExpr |
96 | * when it corresponds to POD value types such as i32. |
97 | * |
98 | * When dtype is DataType::Handle(), the expression could corresponds to |
99 | * a more fine-grained Type, and we can get the type by running lazy type inference. |
100 | */ |
101 | DataType dtype; |
102 | |
103 | TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); |
104 | |
105 | static constexpr const char* _type_key = "PrimExpr" ; |
106 | static constexpr const uint32_t _type_child_slots = 38; |
107 | TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); |
108 | }; |
109 | |
110 | /*! |
111 | * \brief Reference to PrimExprNode. |
112 | * \sa PrimExprNode |
113 | */ |
114 | class PrimExpr : public BaseExpr { |
115 | public: |
116 | /*! |
117 | * \brief construct from integer. |
118 | * \param value The value to be constructed. |
119 | */ |
120 | TVM_DLL PrimExpr(int32_t value); // NOLINT(*) |
121 | /*! |
122 | * \brief construct from float. |
123 | * \param value The value to be constructed. |
124 | */ |
125 | TVM_DLL PrimExpr(float value); // NOLINT(*) |
126 | |
127 | /*! \return the data type of this expression. */ |
128 | DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; } |
129 | |
130 | TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); |
131 | |
132 | private: |
133 | // Internal function for conversion. |
134 | friend struct runtime::PackedFuncValueConverter<PrimExpr>; |
135 | TVM_DLL static PrimExpr FromObject_(ObjectRef ref); |
136 | }; |
137 | |
138 | /*! |
139 | * \brief add operator |
140 | * |
141 | * \param a left operand |
142 | * \param b right operand |
143 | * \return The result expression. |
144 | * \note this function does eager constant folding for |
145 | * index types(int32, int64) when possible. |
146 | */ |
147 | TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); |
148 | |
149 | /*! |
150 | * \brief subtraction operator |
151 | * |
152 | * \param a left operand |
153 | * \param b right operand |
154 | * \return The result expression. |
155 | * \note this function does eager constant folding for |
156 | * index types(int32, int64) when possible. |
157 | */ |
158 | TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); |
159 | |
160 | /*! |
161 | * \brief negation. |
162 | * |
163 | * \param a input. |
164 | * \return The result expression. |
165 | * \note this function does eager constant folding for |
166 | * index types(int32, int64) when possible. |
167 | */ |
168 | TVM_DLL PrimExpr operator-(PrimExpr a); |
169 | |
170 | /*! |
171 | * \brief multiplication operator |
172 | * |
173 | * \param a left operand |
174 | * \param b right operand |
175 | * \return The result expression. |
176 | * \note this function does eager constant folding for |
177 | * index types(int32, int64) when possible. |
178 | */ |
179 | TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b); |
180 | |
181 | /*! |
182 | * \brief division operator |
183 | * |
184 | * \param a left operand |
185 | * \param b right operand |
186 | * \return The result expression. |
187 | * \note this function does eager constant folding for |
188 | * index types(int32, int64) when possible. |
189 | */ |
190 | TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b); |
191 | |
192 | /*! |
193 | * \brief left shift operator |
194 | * |
195 | * \param a left operand |
196 | * \param b right operand |
197 | * \return The result expression. |
198 | * \note this function does eager constant folding for |
199 | * index types(int32, int64) when possible. |
200 | */ |
201 | TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b); |
202 | |
203 | /*! |
204 | * \brief right shift operator |
205 | * |
206 | * \param a left operand |
207 | * \param b right operand |
208 | * \return The result expression. |
209 | * \note this function does eager constant folding for |
210 | * index types(int32, int64) when possible. |
211 | */ |
212 | TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b); |
213 | |
214 | /*! |
215 | * \brief greater |
216 | * |
217 | * \param a left operand |
218 | * \param b right operand |
219 | * \return The result expression. |
220 | * \note this function does eager constant folding for |
221 | * index types(int32, int64) when possible. |
222 | */ |
223 | TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b); |
224 | |
225 | /*! |
226 | * \brief greater_equal |
227 | * |
228 | * \param a left operand |
229 | * \param b right operand |
230 | * \return The result expression. |
231 | * \note this function does eager constant folding for |
232 | * index types(int32, int64) when possible. |
233 | */ |
234 | TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b); |
235 | |
236 | /*! |
237 | * \brief less |
238 | * |
239 | * \param a left operand |
240 | * \param b right operand |
241 | * \return The result expression. |
242 | * \note this function does eager constant folding for |
243 | * index types(int32, int64) when possible. |
244 | */ |
245 | TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b); |
246 | |
247 | /*! |
248 | * \brief less_equal |
249 | * |
250 | * \param a left operand |
251 | * \param b right operand |
252 | * \return The result expression. |
253 | * \note this function does eager constant folding for |
254 | * index types(int32, int64) when possible. |
255 | */ |
256 | TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b); |
257 | |
258 | /*! |
259 | * \brief equal |
260 | * |
261 | * \param a left operand |
262 | * \param b right operand |
263 | * \return The result expression. |
264 | * \note this function does eager constant folding for |
265 | * index types(int32, int64) when possible. |
266 | */ |
267 | TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b); |
268 | |
269 | /*! |
270 | * \brief not_equal |
271 | * |
272 | * \param a left operand |
273 | * \param b right operand |
274 | * \return The result expression. |
275 | * \note this function does eager constant folding for |
276 | * index types(int32, int64) when possible. |
277 | */ |
278 | TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b); |
279 | |
280 | /*! |
281 | * \brief and |
282 | * |
283 | * \param a left operand |
284 | * \param b right operand |
285 | * \return The result expression. |
286 | * \note This operator does eager constant folding. |
287 | */ |
288 | TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b); |
289 | |
290 | /*! |
291 | * \brief or |
292 | * |
293 | * \param a left operand |
294 | * \param b right operand |
295 | * \return The result expression. |
296 | * \note This operator does eager constant folding. |
297 | */ |
298 | TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b); |
299 | |
300 | /*! |
301 | * \brief not |
302 | * |
303 | * \param a left operand |
304 | * \return The result expression. |
305 | * \note This operator does eager constant folding. |
306 | */ |
307 | TVM_DLL PrimExpr operator!(PrimExpr a); |
308 | |
309 | /*! |
310 | * \brief take bitwise and of two values |
311 | * |
312 | * \param a left operand |
313 | * \param b right operand |
314 | * \return The result expression. |
315 | * \note this function does eager constant folding for |
316 | * index types(int32, int64) when possible. |
317 | */ |
318 | TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b); |
319 | |
320 | /*! |
321 | * \brief take bitwise or of two values |
322 | * |
323 | * \param a left operand |
324 | * \param b right operand |
325 | * \return The result expression. |
326 | * \note this function does eager constant folding for |
327 | * index types(int32, int64) when possible. |
328 | */ |
329 | TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b); |
330 | |
331 | /*! |
332 | * \brief take bitwise xor of two values |
333 | * |
334 | * \param a left operand |
335 | * \param b right operand |
336 | * \return The result expression. |
337 | * \note this function does eager constant folding for |
338 | * index types(int32, int64) when possible. |
339 | */ |
340 | TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b); |
341 | |
342 | /*! |
343 | * \brief take bitwise negation of two values |
344 | * |
345 | * \param a the input expression. |
346 | * \return The result expression. |
347 | * \note this function does eager constant folding for |
348 | * index types(int32, int64) when possible. |
349 | */ |
350 | TVM_DLL PrimExpr operator~(PrimExpr a); |
351 | |
352 | /*! |
353 | * \brief Base node of all non-primitive expressions. |
354 | * |
355 | * RelayExpr supports tensor types, functions and ADT as |
356 | * first class citizens. The life-cycle of the corresponding |
357 | * objects are implicitly managed by the language. |
358 | * |
359 | * \sa RelayExpr |
360 | */ |
361 | class RelayExprNode : public BaseExprNode { |
362 | public: |
363 | /*! |
364 | * \brief Stores the result of type inference(type checking). |
365 | * |
366 | * \note This can be undefined before type inference. |
367 | * This value is discarded during serialization. |
368 | */ |
369 | mutable Type checked_type_ = Type(nullptr); |
370 | /*! |
371 | * \return The checked_type |
372 | */ |
373 | inline const Type& checked_type() const; |
374 | /*! |
375 | * \brief Check if the inferred(checked) type of the Expr |
376 | * is backed by a TTypeNode and return it. |
377 | * |
378 | * \note This function will thrown an error if the node type |
379 | * of this Expr is not TTypeNode. |
380 | * |
381 | * \return The corresponding TTypeNode pointer. |
382 | * \tparam The specific TypeNode we look for. |
383 | */ |
384 | template <typename TTypeNode> |
385 | inline const TTypeNode* type_as() const; |
386 | |
387 | /*! |
388 | * \brief The virtual device (VirtualDevice) for this node (the result of device planning). |
389 | * For first-order expressions (non functions), this describes where the result of evaluating the |
390 | * expression should be stored. Note that currently, all composite first-order values (tuples, |
391 | * references, ADTs) must be stored on the same virtual device. This means that it is not possible |
392 | * to store two tuple fields on different devices, so we only need one virtual device for these |
393 | * types. |
394 | * |
395 | * For expressions that have the function type, the virtual device describes where the result of |
396 | * the call to the function or closure is stored (instead of where the function itself is stored). |
397 | * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where |
398 | * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual |
399 | * device of body. For more details, see the documentation in |
400 | * src/relay/transforms/device_planner.cc. |
401 | * |
402 | * The VirtualDevice's Target field describes how the body of the function should be compiled. |
403 | * |
404 | * Set to VirtualDevice::FullyUnconstrained by default. |
405 | * |
406 | * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular |
407 | * import. |
408 | */ |
409 | mutable ObjectRef virtual_device_; |
410 | |
411 | /*! |
412 | * \return The virtual device (VirtualDevice). |
413 | * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). |
414 | * Note that for function types, the virtual device is the device where the result of a |
415 | * call to the function is stored, not where the function itself lives. |
416 | * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where |
417 | * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual |
418 | * device of body. |
419 | * |
420 | * See the documentation of the virtual_device_ field (above) for more details. |
421 | */ |
422 | VirtualDevice virtual_device() const; |
423 | |
424 | static constexpr const char* _type_key = "RelayExpr" ; |
425 | static constexpr const uint32_t _type_child_slots = 22; |
426 | TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); |
427 | }; |
428 | |
429 | /*! |
430 | * \brief Managed reference to RelayExprNode. |
431 | * \sa RelayExprNode |
432 | */ |
433 | class RelayExpr : public BaseExpr { |
434 | public: |
435 | TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode); |
436 | }; |
437 | |
438 | class GlobalVar; |
439 | /*! |
440 | * \brief Global variable that lives in the top-level module. |
441 | * |
442 | * A GlobalVar only refers to function definitions. |
443 | * This is used to enable recursive calls between function. |
444 | * |
445 | * \sa GlobalVarNode |
446 | */ |
447 | class GlobalVarNode : public RelayExprNode { |
448 | public: |
449 | /*! \brief The name of the variable, this only acts as a hint. */ |
450 | String name_hint; |
451 | |
452 | void VisitAttrs(AttrVisitor* v) { |
453 | v->Visit("name_hint" , &name_hint); |
454 | v->Visit("virtual_device_" , &virtual_device_); |
455 | v->Visit("span" , &span); |
456 | v->Visit("_checked_type_" , &checked_type_); |
457 | } |
458 | |
459 | bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { |
460 | // name matters for global var. |
461 | return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); |
462 | } |
463 | |
464 | void SHashReduce(SHashReducer hash_reduce) const { |
465 | hash_reduce(name_hint); |
466 | hash_reduce.FreeVarHashImpl(this); |
467 | } |
468 | |
469 | static constexpr const char* _type_key = "GlobalVar" ; |
470 | TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); |
471 | }; |
472 | |
473 | /*! |
474 | * \brief Managed reference to GlobalVarNode. |
475 | * \sa GlobalVarNode |
476 | */ |
477 | class GlobalVar : public RelayExpr { |
478 | public: |
479 | TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); |
480 | |
481 | TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); |
482 | TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); |
483 | }; |
484 | |
485 | // PrimExprs that are useful as runtime containers. |
486 | // |
487 | /*! |
488 | * \brief Constant integer literals in the program. |
489 | * \sa IntImm |
490 | */ |
491 | class IntImmNode : public PrimExprNode { |
492 | public: |
493 | /*! \brief the Internal value. */ |
494 | int64_t value; |
495 | |
496 | void VisitAttrs(AttrVisitor* v) { |
497 | v->Visit("dtype" , &dtype); |
498 | v->Visit("value" , &value); |
499 | v->Visit("span" , &span); |
500 | } |
501 | |
502 | bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { |
503 | return equal(dtype, other->dtype) && equal(value, other->value); |
504 | } |
505 | |
506 | void SHashReduce(SHashReducer hash_reduce) const { |
507 | hash_reduce(dtype); |
508 | hash_reduce(value); |
509 | } |
510 | |
511 | static constexpr const char* _type_key = "IntImm" ; |
512 | TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); |
513 | }; |
514 | |
515 | /*! |
516 | * \brief Managed reference class to IntImmNode. |
517 | * |
518 | * \sa IntImmNode |
519 | */ |
520 | class IntImm : public PrimExpr { |
521 | public: |
522 | /*! |
523 | * \brief Constructor. |
524 | * \param dtype The data type of the value. |
525 | * \param value The internal value. |
526 | * \param span The location of this object in the source code. |
527 | */ |
528 | TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); |
529 | |
530 | TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); |
531 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); |
532 | }; |
533 | |
534 | /*! |
535 | * \brief Constant floating point literals in the program. |
536 | * \sa FloatImm |
537 | */ |
538 | class FloatImmNode : public PrimExprNode { |
539 | public: |
540 | /*! \brief The constant value content. */ |
541 | double value; |
542 | |
543 | void VisitAttrs(AttrVisitor* v) { |
544 | v->Visit("dtype" , &dtype); |
545 | v->Visit("value" , &value); |
546 | v->Visit("span" , &span); |
547 | } |
548 | |
549 | bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const { |
550 | return equal(dtype, other->dtype) && equal(value, other->value); |
551 | } |
552 | |
553 | void SHashReduce(SHashReducer hash_reduce) const { |
554 | hash_reduce(dtype); |
555 | hash_reduce(value); |
556 | } |
557 | |
558 | static constexpr const char* _type_key = "FloatImm" ; |
559 | TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); |
560 | }; |
561 | |
562 | /*! |
563 | * \brief Managed reference class to FloatImmNode. |
564 | * |
565 | * \sa FloatImmNode |
566 | */ |
567 | class FloatImm : public PrimExpr { |
568 | public: |
569 | /*! |
570 | * \brief Constructor. |
571 | * \param dtype The data type of the value. |
572 | * \param value The internal value. |
573 | * \param span The location in the source code. |
574 | */ |
575 | TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); |
576 | |
577 | TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); |
578 | TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); |
579 | }; |
580 | |
581 | /*! |
582 | * \brief Boolean constant. |
583 | * |
584 | * This reference type is useful to add additional compile-time |
585 | * type checks and helper functions for Integer equal comparisons. |
586 | */ |
587 | class Bool : public IntImm { |
588 | public: |
589 | explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {} |
590 | Bool operator!() const { return Bool((*this)->value == 0); } |
591 | operator bool() const { return (*this)->value != 0; } |
592 | |
593 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); |
594 | }; |
595 | |
596 | // Overload operators to make sure we have the most fine grained types. |
597 | inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } |
598 | inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } |
599 | inline Bool operator||(const Bool& a, const Bool& b) { |
600 | return Bool(a.operator bool() || b.operator bool()); |
601 | } |
602 | inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } |
603 | inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } |
604 | inline Bool operator&&(const Bool& a, const Bool& b) { |
605 | return Bool(a.operator bool() && b.operator bool()); |
606 | } |
607 | |
608 | inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; } |
609 | inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); } |
610 | inline bool operator==(const Bool& a, const Bool& b) { |
611 | return a.operator bool() == b.operator bool(); |
612 | } |
613 | |
614 | /*! |
615 | * \brief Container of constant int that adds more constructors. |
616 | * |
617 | * This is used to store and automate type check |
618 | * attributes that must be constant integer. |
619 | * |
620 | * \sa IntImm |
621 | */ |
622 | class Integer : public IntImm { |
623 | public: |
624 | Integer() {} |
625 | /*! |
626 | * \brief constructor from node. |
627 | */ |
628 | explicit Integer(ObjectPtr<Object> node) : IntImm(node) {} |
629 | /*! |
630 | * \brief Construct integer from int value. |
631 | */ |
632 | Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*) |
633 | /*! |
634 | * \brief Construct integer from int imm. |
635 | * \param other The other value. |
636 | */ |
637 | Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) |
638 | /*! |
639 | * \brief Constructor from enum |
640 | * \tparam Enum The enum type. |
641 | * \param value The enum value. |
642 | */ |
643 | template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> |
644 | explicit Integer(Enum value) : Integer(static_cast<int>(value)) { |
645 | static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value, |
646 | "declare enum to be enum int to use visitor" ); |
647 | } |
648 | /*! |
649 | * \brief Assign an expression to integer. |
650 | * \param other another expression. |
651 | */ |
652 | Integer& operator=(const IntImm& other) { |
653 | data_ = ObjectRef::GetDataPtr<Object>(other); |
654 | return *this; |
655 | } |
656 | /*! |
657 | * \brief convert to int64_t |
658 | */ |
659 | int64_t IntValue() const { |
660 | ICHECK(data_ != nullptr) << " Trying to reference a null Integer" ; |
661 | return (*this)->value; |
662 | } |
663 | // comparators |
664 | Bool operator==(int other) const { |
665 | if (data_ == nullptr) return Bool(false); |
666 | return Bool((*this)->value == other); |
667 | } |
668 | Bool operator!=(int other) const { return !(*this == other); } |
669 | template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> |
670 | Bool operator==(Enum other) const { |
671 | return *this == static_cast<int>(other); |
672 | } |
673 | template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> |
674 | Bool operator!=(Enum other) const { |
675 | return *this != static_cast<int>(other); |
676 | } |
677 | }; |
678 | |
679 | /*! \brief range over one dimension */ |
680 | class RangeNode : public Object { |
681 | public: |
682 | /*! \brief beginning of the node */ |
683 | PrimExpr min; |
684 | /*! \brief the extend of range */ |
685 | PrimExpr extent; |
686 | /*! \brief the location of this range in the source */ |
687 | mutable Span span; |
688 | /*! \brief constructor */ |
689 | RangeNode() {} |
690 | RangeNode(PrimExpr min, PrimExpr extent, Span span = Span()) |
691 | : min(min), extent(extent), span(span) {} |
692 | |
693 | void VisitAttrs(AttrVisitor* v) { |
694 | v->Visit("min" , &min); |
695 | v->Visit("extent" , &extent); |
696 | v->Visit("span" , &span); |
697 | } |
698 | |
699 | bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const { |
700 | return equal(min, other->min) && equal(extent, other->extent); |
701 | } |
702 | |
703 | void SHashReduce(SHashReducer hash_reduce) const { |
704 | hash_reduce(min); |
705 | hash_reduce(extent); |
706 | } |
707 | |
708 | static constexpr const char* _type_key = "Range" ; |
709 | static constexpr const bool _type_has_method_sequal_reduce = true; |
710 | static constexpr const bool _type_has_method_shash_reduce = true; |
711 | TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); |
712 | }; |
713 | |
714 | /*! \brief Range constainer */ |
715 | class Range : public ObjectRef { |
716 | public: |
717 | /*! |
718 | * \brief constructor by begin and end |
719 | * \param begin The begin of the range. |
720 | * \param end The end of the range. |
721 | * \param span The location of the Range in the source. |
722 | */ |
723 | TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span()); |
724 | /*! |
725 | * \brief construct a new range with min and extent |
726 | * The corresponding constructor is removed, |
727 | * because that is counter convention of tradition meaning |
728 | * of range(begin, end) |
729 | * |
730 | * \param min The minimum range. |
731 | * \param extent The extent of the range. |
732 | * \param span The location of the Range in the source. |
733 | */ |
734 | static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span()); |
735 | // declare range. |
736 | TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); |
737 | }; |
738 | |
739 | // implementataions |
740 | inline const Type& RelayExprNode::checked_type() const { |
741 | ICHECK(checked_type_.defined()) << "internal error: the type checker has " |
742 | << "not populated the checked_type " |
743 | << "field for " << GetRef<RelayExpr>(this); |
744 | return this->checked_type_; |
745 | } |
746 | |
747 | template <typename TTypeNode> |
748 | inline const TTypeNode* RelayExprNode::type_as() const { |
749 | static_assert(std::is_base_of<TypeNode, TTypeNode>::value, |
750 | "TType must be a special case of type" ); |
751 | ICHECK(checked_type_.defined()) |
752 | << "Type inference for this Expr has not completed. Try to call infer_type pass." ; |
753 | const TTypeNode* node = checked_type_.as<TTypeNode>(); |
754 | ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " |
755 | << checked_type_->GetTypeKey(); |
756 | return node; |
757 | } |
758 | |
759 | } // namespace tvm |
760 | |
761 | namespace tvm { |
762 | namespace runtime { |
763 | // common rule for RetValue and ArgValue |
764 | template <> |
765 | struct PackedFuncValueConverter<PrimExpr> { |
766 | static PrimExpr From(const TVMPODValue_& val) { |
767 | if (val.type_code() == kTVMNullptr) { |
768 | return PrimExpr(ObjectPtr<Object>(nullptr)); |
769 | } |
770 | if (val.type_code() == kDLInt) { |
771 | int64_t value = val.operator int64_t(); |
772 | if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) { |
773 | return IntImm(runtime::DataType::Int(64), value); |
774 | } |
775 | return IntImm(runtime::DataType::Int(32), val.operator int()); |
776 | } |
777 | if (val.type_code() == kDLFloat) { |
778 | return FloatImm(runtime::DataType::Float(32), val.operator double()); |
779 | } |
780 | |
781 | return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>()); |
782 | } |
783 | }; |
784 | |
785 | template <> |
786 | struct PackedFuncValueConverter<tvm::Integer> { |
787 | static tvm::Integer From(const TVMPODValue_& val) { |
788 | if (val.type_code() == kTVMNullptr) { |
789 | return Integer(ObjectPtr<Object>(nullptr)); |
790 | } |
791 | if (val.type_code() == kTVMArgInt) { |
792 | return Integer(val.operator int()); |
793 | } |
794 | return val.AsObjectRef<tvm::Integer>(); |
795 | } |
796 | }; |
797 | |
798 | template <> |
799 | struct PackedFuncValueConverter<tvm::Bool> { |
800 | static tvm::Bool From(const TVMPODValue_& val) { |
801 | if (val.type_code() == kTVMNullptr) { |
802 | return Bool(ObjectPtr<Object>(nullptr)); |
803 | } |
804 | if (val.type_code() == kTVMArgInt) { |
805 | int v = val.operator int(); |
806 | ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; |
807 | return Bool(static_cast<bool>(v)); |
808 | } |
809 | return val.AsObjectRef<tvm::Bool>(); |
810 | } |
811 | }; |
812 | |
813 | } // namespace runtime |
814 | } // namespace tvm |
815 | #endif // TVM_IR_EXPR_H_ |
816 | |