1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/expr.h |
22 | * \brief TIR expressions. |
23 | */ |
24 | // Acknowledgement: Many low-level IR nodes originate from Halide. |
25 | #ifndef TVM_TIR_EXPR_H_ |
26 | #define TVM_TIR_EXPR_H_ |
27 | |
28 | #include <tvm/ir/expr.h> |
29 | #include <tvm/node/functor.h> |
30 | #include <tvm/node/node.h> |
31 | #include <tvm/runtime/c_runtime_api.h> |
32 | #include <tvm/runtime/container/array.h> |
33 | #include <tvm/runtime/container/map.h> |
34 | #include <tvm/runtime/container/string.h> |
35 | #include <tvm/runtime/data_type.h> |
36 | #include <tvm/tir/buffer.h> |
37 | #include <tvm/tir/var.h> |
38 | |
39 | #include <algorithm> |
40 | #include <iostream> |
41 | #include <limits> |
42 | #include <string> |
43 | #include <unordered_map> |
44 | #include <utility> |
45 | |
46 | namespace tvm { |
47 | namespace tir { |
48 | |
49 | using IntImmNode = tvm::IntImmNode; |
50 | using FloatImmNode = tvm::FloatImmNode; |
51 | |
52 | /*! \brief String constants, only used in asserts. */ |
53 | class StringImmNode : public PrimExprNode { |
54 | public: |
55 | /*! \brief The constant value content. */ |
56 | String value; |
57 | |
58 | void VisitAttrs(AttrVisitor* v) { |
59 | v->Visit("dtype" , &dtype); |
60 | v->Visit("value" , &value); |
61 | v->Visit("span" , &span); |
62 | } |
63 | |
64 | bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { |
65 | return equal(value, other->value); |
66 | } |
67 | |
68 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } |
69 | |
70 | static constexpr const char* _type_key = "tir.StringImm" ; |
71 | TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); |
72 | }; |
73 | |
74 | /*! |
75 | * \brief Managed reference to StringImmNode. |
76 | * \sa StringImmNode |
77 | */ |
78 | class StringImm : public PrimExpr { |
79 | public: |
80 | TVM_DLL StringImm(String value, Span span = Span()); |
81 | TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); |
82 | TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); |
83 | }; |
84 | |
85 | /*! |
86 | * \brief Cast value from one data type to another. |
87 | * \note The lanes of value should keep fixed. |
88 | */ |
89 | class CastNode : public PrimExprNode { |
90 | public: |
91 | /*! \brief Original data type. */ |
92 | PrimExpr value; |
93 | |
94 | void VisitAttrs(AttrVisitor* v) { |
95 | v->Visit("dtype" , &dtype); |
96 | v->Visit("value" , &value); |
97 | v->Visit("span" , &span); |
98 | } |
99 | |
100 | bool SEqualReduce(const CastNode* other, SEqualReducer equal) const { |
101 | return equal(dtype, other->dtype) && equal(value, other->value); |
102 | } |
103 | |
104 | void SHashReduce(SHashReducer hash_reduce) const { |
105 | hash_reduce(dtype); |
106 | hash_reduce(value); |
107 | } |
108 | |
109 | static constexpr const char* _type_key = "tir.Cast" ; |
110 | TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); |
111 | }; |
112 | |
113 | /*! |
114 | * \brief Managed reference to CastNode |
115 | * \sa CastNode |
116 | */ |
117 | class Cast : public PrimExpr { |
118 | public: |
119 | TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); |
120 | TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); |
121 | TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); |
122 | }; |
123 | |
124 | /*! |
125 | * \brief Base template to implement binary ops. |
126 | * \tparam T The type of the child class. |
127 | */ |
128 | template <typename T> |
129 | class BinaryOpNode : public PrimExprNode { |
130 | public: |
131 | /*! \brief The left operand. */ |
132 | PrimExpr a; |
133 | /*! \brief The right operand. */ |
134 | PrimExpr b; |
135 | |
136 | void VisitAttrs(AttrVisitor* v) { |
137 | v->Visit("dtype" , &(this->dtype)); |
138 | v->Visit("a" , &a); |
139 | v->Visit("b" , &b); |
140 | v->Visit("span" , &span); |
141 | } |
142 | |
143 | bool SEqualReduce(const T* other, SEqualReducer equal) const { |
144 | return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
145 | } |
146 | |
147 | void SHashReduce(SHashReducer hash_reduce) const { |
148 | hash_reduce(dtype); |
149 | hash_reduce(a); |
150 | hash_reduce(b); |
151 | } |
152 | |
153 | TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); |
154 | }; |
155 | |
156 | /*! \brief a + b */ |
157 | class AddNode : public BinaryOpNode<AddNode> { |
158 | public: |
159 | static constexpr const char* _type_key = "tir.Add" ; |
160 | }; |
161 | |
162 | /*! |
163 | * \brief Managed reference to AddNode |
164 | * \sa AddNode |
165 | */ |
166 | class Add : public PrimExpr { |
167 | public: |
168 | TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span()); |
169 | TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); |
170 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode); |
171 | }; |
172 | |
173 | /*! \brief a - b */ |
174 | class SubNode : public BinaryOpNode<SubNode> { |
175 | public: |
176 | static constexpr const char* _type_key = "tir.Sub" ; |
177 | }; |
178 | |
179 | /*! |
180 | * \brief Managed reference to SubNode |
181 | * \sa SubNode |
182 | */ |
183 | class Sub : public PrimExpr { |
184 | public: |
185 | TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span()); |
186 | TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); |
187 | TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode); |
188 | }; |
189 | |
190 | /*! \brief a * b */ |
191 | class MulNode : public BinaryOpNode<MulNode> { |
192 | public: |
193 | static constexpr const char* _type_key = "tir.Mul" ; |
194 | }; |
195 | |
196 | /*! |
197 | * \brief Managed reference to MulNode |
198 | * \sa MulNode |
199 | */ |
200 | class Mul : public PrimExpr { |
201 | public: |
202 | TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span()); |
203 | TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); |
204 | TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode); |
205 | }; |
206 | |
207 | /*! |
208 | * \brief a / b in the C semnatics. |
209 | * \note For integer division, C standard uses trunc div. |
210 | */ |
211 | class DivNode : public BinaryOpNode<DivNode> { |
212 | public: |
213 | static constexpr const char* _type_key = "tir.Div" ; |
214 | }; |
215 | |
216 | /*! |
217 | * \brief Managed reference to DivNode |
218 | * \sa DivNode |
219 | */ |
220 | class Div : public PrimExpr { |
221 | public: |
222 | TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span()); |
223 | TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); |
224 | TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode); |
225 | }; |
226 | |
227 | /*! |
228 | * \brief a % b in the C semnatics. |
229 | * \note For integer division, C standard uses trunc div. |
230 | */ |
231 | class ModNode : public BinaryOpNode<ModNode> { |
232 | public: |
233 | static constexpr const char* _type_key = "tir.Mod" ; |
234 | }; |
235 | |
236 | /*! |
237 | * \brief Managed reference to ModNode |
238 | * \sa ModNode |
239 | */ |
240 | class Mod : public PrimExpr { |
241 | public: |
242 | TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span()); |
243 | TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); |
244 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode); |
245 | }; |
246 | |
247 | /*! \brief Floor division, floor(a/b) */ |
248 | class FloorDivNode : public BinaryOpNode<FloorDivNode> { |
249 | public: |
250 | static constexpr const char* _type_key = "tir.FloorDiv" ; |
251 | }; |
252 | |
253 | /*! |
254 | * \brief Managed reference to FloorDivNode |
255 | * \sa FloorDivNode |
256 | */ |
257 | class FloorDiv : public PrimExpr { |
258 | public: |
259 | TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span()); |
260 | TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); |
261 | TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); |
262 | }; |
263 | |
264 | /*! \brief The remainder of the floordiv */ |
265 | class FloorModNode : public BinaryOpNode<FloorModNode> { |
266 | public: |
267 | static constexpr const char* _type_key = "tir.FloorMod" ; |
268 | }; |
269 | |
270 | /*! |
271 | * \brief Managed reference to FloorModNode |
272 | * \sa FloorModNode |
273 | */ |
274 | class FloorMod : public PrimExpr { |
275 | public: |
276 | TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span()); |
277 | TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); |
278 | TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode); |
279 | }; |
280 | |
281 | /*! \brief min(a, b) */ |
282 | class MinNode : public BinaryOpNode<MinNode> { |
283 | public: |
284 | static constexpr const char* _type_key = "tir.Min" ; |
285 | }; |
286 | |
287 | /*! |
288 | * \brief Managed reference to MinNode |
289 | * \sa MinNode |
290 | */ |
291 | class Min : public PrimExpr { |
292 | public: |
293 | TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span()); |
294 | TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); |
295 | TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode); |
296 | }; |
297 | |
298 | /*! \brief max(a, b) */ |
299 | class MaxNode : public BinaryOpNode<MaxNode> { |
300 | public: |
301 | static constexpr const char* _type_key = "tir.Max" ; |
302 | }; |
303 | |
304 | /*! |
305 | * \brief Managed reference to MaxNode |
306 | * \sa MaxNode |
307 | */ |
308 | class Max : public PrimExpr { |
309 | public: |
310 | TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span()); |
311 | TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); |
312 | TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode); |
313 | }; |
314 | |
315 | /*! |
316 | * \brief Base template to implement comparison ops. |
317 | * \tparam T The type of the child class. |
318 | */ |
319 | template <typename T> |
320 | class CmpOpNode : public PrimExprNode { |
321 | public: |
322 | /*! \brief The left operand. */ |
323 | PrimExpr a; |
324 | /*! \brief The right operand. */ |
325 | PrimExpr b; |
326 | |
327 | void VisitAttrs(AttrVisitor* v) { |
328 | v->Visit("dtype" , &(this->dtype)); |
329 | v->Visit("a" , &a); |
330 | v->Visit("b" , &b); |
331 | v->Visit("span" , &span); |
332 | } |
333 | |
334 | bool SEqualReduce(const T* other, SEqualReducer equal) const { |
335 | return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
336 | } |
337 | |
338 | void SHashReduce(SHashReducer hash_reduce) const { |
339 | hash_reduce(dtype); |
340 | hash_reduce(a); |
341 | hash_reduce(b); |
342 | } |
343 | |
344 | TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); |
345 | }; |
346 | |
347 | /*! \brief a == b */ |
348 | class EQNode : public CmpOpNode<EQNode> { |
349 | public: |
350 | static constexpr const char* _type_key = "tir.EQ" ; |
351 | }; |
352 | |
353 | /*! |
354 | * \brief Managed reference to EQNode |
355 | * \sa EQNode |
356 | */ |
357 | class EQ : public PrimExpr { |
358 | public: |
359 | TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span()); |
360 | TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); |
361 | TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode); |
362 | }; |
363 | |
364 | /*! \brief a != b */ |
365 | class NENode : public CmpOpNode<NENode> { |
366 | public: |
367 | static constexpr const char* _type_key = "tir.NE" ; |
368 | }; |
369 | |
370 | /*! |
371 | * \brief Managed reference to NENode |
372 | * \sa NENode |
373 | */ |
374 | class NE : public PrimExpr { |
375 | public: |
376 | TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span()); |
377 | TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); |
378 | TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode); |
379 | }; |
380 | |
381 | /*! \brief a < b */ |
382 | class LTNode : public CmpOpNode<LTNode> { |
383 | public: |
384 | static constexpr const char* _type_key = "tir.LT" ; |
385 | }; |
386 | |
387 | /*! |
388 | * \brief Managed reference to LTNode |
389 | * \sa LTNode |
390 | */ |
391 | class LT : public PrimExpr { |
392 | public: |
393 | TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span()); |
394 | TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); |
395 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode); |
396 | }; |
397 | |
398 | /*! \brief a <= b */ |
399 | struct LENode : public CmpOpNode<LENode> { |
400 | public: |
401 | static constexpr const char* _type_key = "tir.LE" ; |
402 | }; |
403 | |
404 | /*! |
405 | * \brief Managed reference to LENode |
406 | * \sa LENode |
407 | */ |
408 | class LE : public PrimExpr { |
409 | public: |
410 | TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span()); |
411 | TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); |
412 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode); |
413 | }; |
414 | |
415 | /*! \brief a > b */ |
416 | class GTNode : public CmpOpNode<GTNode> { |
417 | public: |
418 | static constexpr const char* _type_key = "tir.GT" ; |
419 | }; |
420 | |
421 | /*! |
422 | * \brief Managed reference to GTNode |
423 | * \sa GTNode |
424 | */ |
425 | class GT : public PrimExpr { |
426 | public: |
427 | TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span()); |
428 | TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); |
429 | TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode); |
430 | }; |
431 | |
432 | /*! \brief a >= b */ |
433 | class GENode : public CmpOpNode<GENode> { |
434 | public: |
435 | static constexpr const char* _type_key = "tir.GE" ; |
436 | }; |
437 | |
438 | /*! |
439 | * \brief Managed reference to GENode |
440 | * \sa GENode |
441 | */ |
442 | class GE : public PrimExpr { |
443 | public: |
444 | TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span()); |
445 | TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); |
446 | TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode); |
447 | }; |
448 | |
449 | /*! \brief a && b */ |
450 | class AndNode : public PrimExprNode { |
451 | public: |
452 | /*! \brief The left operand. */ |
453 | PrimExpr a; |
454 | /*! \brief The right operand. */ |
455 | PrimExpr b; |
456 | |
457 | void VisitAttrs(AttrVisitor* v) { |
458 | v->Visit("dtype" , &(this->dtype)); |
459 | v->Visit("a" , &a); |
460 | v->Visit("b" , &b); |
461 | v->Visit("span" , &span); |
462 | } |
463 | |
464 | bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { |
465 | return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
466 | } |
467 | |
468 | void SHashReduce(SHashReducer hash_reduce) const { |
469 | hash_reduce(dtype); |
470 | hash_reduce(a); |
471 | hash_reduce(b); |
472 | } |
473 | |
474 | static constexpr const char* _type_key = "tir.And" ; |
475 | TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); |
476 | }; |
477 | |
478 | /*! |
479 | * \brief Managed reference to AndNode |
480 | * \sa AndNode |
481 | */ |
482 | class And : public PrimExpr { |
483 | public: |
484 | TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span()); |
485 | TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); |
486 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode); |
487 | }; |
488 | |
489 | /*! \brief a || b */ |
490 | class OrNode : public PrimExprNode { |
491 | public: |
492 | /*! \brief The left operand. */ |
493 | PrimExpr a; |
494 | /*! \brief The right operand. */ |
495 | PrimExpr b; |
496 | |
497 | void VisitAttrs(AttrVisitor* v) { |
498 | v->Visit("dtype" , &dtype); |
499 | v->Visit("a" , &a); |
500 | v->Visit("b" , &b); |
501 | v->Visit("span" , &span); |
502 | } |
503 | |
504 | bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { |
505 | return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
506 | } |
507 | |
508 | void SHashReduce(SHashReducer hash_reduce) const { |
509 | hash_reduce(dtype); |
510 | hash_reduce(a); |
511 | hash_reduce(b); |
512 | } |
513 | |
514 | static constexpr const char* _type_key = "tir.Or" ; |
515 | TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); |
516 | }; |
517 | |
518 | /*! |
519 | * \brief Managed reference to OrNode |
520 | * \sa OrNode |
521 | */ |
522 | class Or : public PrimExpr { |
523 | public: |
524 | TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span()); |
525 | TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); |
526 | TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode); |
527 | }; |
528 | |
529 | /*! \brief !a */ |
530 | class NotNode : public PrimExprNode { |
531 | public: |
532 | /*! \brief The input operand. */ |
533 | PrimExpr a; |
534 | |
535 | void VisitAttrs(AttrVisitor* v) { |
536 | v->Visit("dtype" , &dtype); |
537 | v->Visit("a" , &a); |
538 | v->Visit("span" , &span); |
539 | } |
540 | |
541 | bool SEqualReduce(const NotNode* other, SEqualReducer equal) const { |
542 | return equal(dtype, other->dtype) && equal(a, other->a); |
543 | } |
544 | |
545 | void SHashReduce(SHashReducer hash_reduce) const { |
546 | hash_reduce(dtype); |
547 | hash_reduce(a); |
548 | } |
549 | |
550 | static constexpr const char* _type_key = "tir.Not" ; |
551 | TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); |
552 | }; |
553 | |
554 | /*! |
555 | * \brief Managed reference to NotNode |
556 | * \sa NotNode |
557 | */ |
558 | class Not : public PrimExpr { |
559 | public: |
560 | TVM_DLL Not(PrimExpr a, Span span = Span()); |
561 | TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); |
562 | TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode); |
563 | }; |
564 | |
565 | /*! |
566 | * \brief return true_value if condition is true, otherwise return false_value. |
567 | * \note Both true_value and false_value could be evaluated |
568 | * regardless of the condition value. |
569 | * Do not use it to guard against out of bound access, |
570 | * please use if_then_else instead. |
571 | */ |
572 | class SelectNode : public PrimExprNode { |
573 | public: |
574 | /*! \brief The condition */ |
575 | PrimExpr condition; |
576 | /*! \brief value to be returned when condition is true. */ |
577 | PrimExpr true_value; |
578 | /*! \brief value to be returned when condition is false. */ |
579 | PrimExpr false_value; |
580 | |
581 | void VisitAttrs(AttrVisitor* v) { |
582 | v->Visit("dtype" , &dtype); |
583 | v->Visit("condition" , &condition); |
584 | v->Visit("true_value" , &true_value); |
585 | v->Visit("false_value" , &false_value); |
586 | v->Visit("span" , &span); |
587 | } |
588 | |
589 | bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { |
590 | return equal(dtype, other->dtype) && equal(condition, other->condition) && |
591 | equal(true_value, other->true_value) && equal(false_value, other->false_value); |
592 | } |
593 | |
594 | void SHashReduce(SHashReducer hash_reduce) const { |
595 | hash_reduce(dtype); |
596 | hash_reduce(condition); |
597 | hash_reduce(true_value); |
598 | hash_reduce(false_value); |
599 | } |
600 | |
601 | static constexpr const char* _type_key = "tir.Select" ; |
602 | TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); |
603 | }; |
604 | |
605 | /*! |
606 | * \brief Managed reference to SelectNode |
607 | * \sa SelectNode |
608 | */ |
609 | class Select : public PrimExpr { |
610 | public: |
611 | TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span()); |
612 | |
613 | TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); |
614 | TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode); |
615 | }; |
616 | |
617 | /*! |
618 | * \brief Load value from the high dimension buffer. |
619 | * |
620 | * \code |
621 | * |
622 | * value = buffer[i, j]; |
623 | * |
624 | * \endcode |
625 | * \sa BufferStore |
626 | */ |
627 | class BufferLoadNode : public PrimExprNode { |
628 | public: |
629 | /*! \brief The buffer variable. */ |
630 | Buffer buffer; |
631 | /*! \brief The indices location to be loaded. */ |
632 | Array<PrimExpr> indices; |
633 | |
634 | void VisitAttrs(AttrVisitor* v) { |
635 | v->Visit("dtype" , &(this->dtype)); |
636 | v->Visit("buffer" , &buffer); |
637 | v->Visit("indices" , &indices); |
638 | v->Visit("span" , &span); |
639 | } |
640 | |
641 | bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { |
642 | return equal(dtype, other->dtype) && equal(buffer, other->buffer) && |
643 | equal(indices, other->indices); |
644 | } |
645 | |
646 | void SHashReduce(SHashReducer hash_reduce) const { |
647 | hash_reduce(dtype); |
648 | hash_reduce(buffer); |
649 | hash_reduce(indices); |
650 | } |
651 | |
652 | static constexpr const char* _type_key = "tir.BufferLoad" ; |
653 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); |
654 | |
655 | private: |
656 | /*! \brief Set the dtype based on the buffer/indices |
657 | * |
658 | * Usually, the BufferLoad's dtype will be the same dtype as the |
659 | * buffer. This may have a different number of lanes than the |
660 | * buffer's dtype if index values have more than 1 lane. |
661 | * |
662 | * This function should only be called during construction and after |
663 | * CopyOnWrite. Friend class used here to restrict usage. |
664 | */ |
665 | void LegalizeDType(); |
666 | friend class BufferLoad; |
667 | friend class CustomDatatypesLowerer; |
668 | friend class VectorTypeRewriter; |
669 | friend class Vectorizer; |
670 | }; |
671 | |
672 | /*! |
673 | * \brief Managed reference to BufferLoadNode. |
674 | * \sa BufferLoadNode |
675 | */ |
676 | class BufferLoad : public PrimExpr { |
677 | public: |
678 | TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span()); |
679 | TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); |
680 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); |
681 | }; |
682 | |
683 | /*! |
684 | * \brief Load value from the result produced by the producer. |
685 | * |
686 | * \note This node only appears in high-level DSLs that are built on top of the TIR. |
687 | * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower |
688 | * this node before TIR transformations. |
689 | * |
690 | * \sa ProducerLoad, DataProducerNode |
691 | */ |
692 | class ProducerLoadNode : public PrimExprNode { |
693 | public: |
694 | /*! \brief The buffer producer. */ |
695 | DataProducer producer; |
696 | /*! \brief The location arguments. */ |
697 | Array<PrimExpr> indices; |
698 | |
699 | void VisitAttrs(AttrVisitor* v) { |
700 | v->Visit("dtype" , &(this->dtype)); |
701 | v->Visit("producer" , &producer); |
702 | v->Visit("indices" , &indices); |
703 | v->Visit("span" , &span); |
704 | } |
705 | |
706 | bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { |
707 | return equal(dtype, other->dtype) && equal(producer, other->producer) && |
708 | equal(indices, other->indices); |
709 | } |
710 | |
711 | void SHashReduce(SHashReducer hash_reduce) const { |
712 | hash_reduce(dtype); |
713 | hash_reduce(producer); |
714 | hash_reduce(indices); |
715 | } |
716 | |
717 | static constexpr const char* _type_key = "tir.ProducerLoad" ; |
718 | TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); |
719 | }; |
720 | |
721 | /*! |
722 | * \brief Managed reference to ProducerLoadNode. |
723 | * \sa ProducerLoadNode |
724 | */ |
725 | class ProducerLoad : public PrimExpr { |
726 | public: |
727 | TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span()); |
728 | |
729 | TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); |
730 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode); |
731 | }; |
732 | |
733 | /*! |
734 | * \brief Load the value from buffer_var. |
735 | * |
736 | * Equivalent to ((DType*)buffer_var)[index] |
737 | * where DType is the type specified by type().element_of(). |
738 | * |
739 | * For example, if type = float32x3, then the load will corresponds to |
740 | * |
741 | * \code |
742 | * |
743 | * auto buffer = static_cast<float*>(buffer_var); |
744 | * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]); |
745 | * |
746 | * \endcode |
747 | */ |
748 | class LoadNode : public PrimExprNode { |
749 | public: |
750 | /*! \brief The buffer variable. */ |
751 | Var buffer_var; |
752 | /*! \brief The index locations to be loaded. */ |
753 | PrimExpr index; |
754 | /*! \brief The predicate to mask which lanes would be loaded. */ |
755 | PrimExpr predicate; |
756 | |
757 | void VisitAttrs(AttrVisitor* v) { |
758 | v->Visit("dtype" , &dtype); |
759 | v->Visit("buffer_var" , &buffer_var); |
760 | v->Visit("index" , &index); |
761 | v->Visit("predicate" , &predicate); |
762 | v->Visit("span" , &span); |
763 | } |
764 | |
765 | bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { |
766 | return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) && |
767 | equal(index, other->index) && equal(predicate, other->predicate); |
768 | } |
769 | |
770 | void SHashReduce(SHashReducer hash_reduce) const { |
771 | hash_reduce(dtype); |
772 | hash_reduce(buffer_var); |
773 | hash_reduce(index); |
774 | hash_reduce(predicate); |
775 | } |
776 | |
777 | static constexpr const char* _type_key = "tir.Load" ; |
778 | TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); |
779 | }; |
780 | |
781 | /*! |
782 | * \brief Managed reference to LoadNode |
783 | * \sa LoadNode |
784 | */ |
785 | class Load : public PrimExpr { |
786 | public: |
787 | TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, |
788 | Span span = Span()); |
789 | TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); |
790 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode); |
791 | }; |
792 | |
793 | /*! |
794 | * \brief Construct a vector with lanes elements |
795 | * where its i-th element equals base + i * stride. |
796 | * This is useful to construct a index for a continuous vector load. |
797 | * |
798 | * Examples: |
799 | * - ramp(0, 1, 3) = [0, 1, 2] |
800 | * - ramp(1, 2, 4) = [1, 3, 5, 7] |
801 | */ |
802 | class RampNode : public PrimExprNode { |
803 | public: |
804 | /*! \brief The base value. */ |
805 | PrimExpr base; |
806 | /*! \brief The stride of each step. */ |
807 | PrimExpr stride; |
808 | /*! \brief Total number of lanes. */ |
809 | int lanes; |
810 | |
811 | void VisitAttrs(AttrVisitor* v) { |
812 | v->Visit("dtype" , &dtype); |
813 | v->Visit("base" , &base); |
814 | v->Visit("stride" , &stride); |
815 | v->Visit("lanes" , &lanes); |
816 | v->Visit("span" , &span); |
817 | } |
818 | |
819 | bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { |
820 | return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && |
821 | equal(lanes, other->lanes); |
822 | } |
823 | |
824 | void SHashReduce(SHashReducer hash_reduce) const { |
825 | hash_reduce(dtype); |
826 | hash_reduce(base); |
827 | hash_reduce(stride); |
828 | hash_reduce(lanes); |
829 | } |
830 | |
831 | static constexpr const char* _type_key = "tir.Ramp" ; |
832 | TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); |
833 | }; |
834 | |
835 | /*! |
836 | * \brief Managed reference to RampNode |
837 | * \sa RampNode |
838 | */ |
839 | class Ramp : public PrimExpr { |
840 | public: |
841 | TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); |
842 | TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); |
843 | TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); |
844 | }; |
845 | |
846 | /*! \brief Create a vector where all the elements are value. */ |
847 | class BroadcastNode : public PrimExprNode { |
848 | public: |
849 | /*! \brief The base value. */ |
850 | PrimExpr value; |
851 | /*! \brief The number of lanes. */ |
852 | int lanes; |
853 | |
854 | void VisitAttrs(AttrVisitor* v) { |
855 | v->Visit("dtype" , &dtype); |
856 | v->Visit("value" , &value); |
857 | v->Visit("lanes" , &lanes); |
858 | v->Visit("span" , &span); |
859 | } |
860 | |
861 | bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { |
862 | return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); |
863 | } |
864 | |
865 | void SHashReduce(SHashReducer hash_reduce) const { |
866 | hash_reduce(dtype); |
867 | hash_reduce(value); |
868 | hash_reduce(lanes); |
869 | } |
870 | |
871 | static constexpr const char* _type_key = "tir.Broadcast" ; |
872 | TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); |
873 | }; |
874 | |
875 | /*! |
876 | * \brief Managed reference to BroadcastNode |
877 | * \sa BroadcastNode |
878 | */ |
879 | class Broadcast : public PrimExpr { |
880 | public: |
881 | TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); |
882 | TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); |
883 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); |
884 | }; |
885 | |
886 | /*! |
887 | * \brief Let binding. Bind var to value then evaluate body. |
888 | */ |
889 | class LetNode : public PrimExprNode { |
890 | public: |
891 | /*! \brief The variable. */ |
892 | Var var; |
893 | /*! \brief The value to be binded. */ |
894 | PrimExpr value; |
895 | /*! \brief The result expression. */ |
896 | PrimExpr body; |
897 | |
898 | void VisitAttrs(AttrVisitor* v) { |
899 | v->Visit("dtype" , &dtype); |
900 | v->Visit("var" , &var); |
901 | v->Visit("value" , &value); |
902 | v->Visit("body" , &body); |
903 | v->Visit("span" , &span); |
904 | } |
905 | |
906 | bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { |
907 | return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && |
908 | equal(value, other->value) && equal(body, other->body); |
909 | } |
910 | |
911 | void SHashReduce(SHashReducer hash_reduce) const { |
912 | hash_reduce(dtype); |
913 | hash_reduce.DefHash(var); |
914 | hash_reduce(value); |
915 | hash_reduce(body); |
916 | } |
917 | |
918 | static constexpr const char* _type_key = "tir.Let" ; |
919 | TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); |
920 | }; |
921 | |
922 | /*! |
923 | * \brief Managed reference to LetNode |
924 | * \sa LetNode |
925 | */ |
926 | class Let : public PrimExpr { |
927 | public: |
928 | TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span()); |
929 | TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); |
930 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); |
931 | }; |
932 | |
933 | /*! |
934 | * \brief Call node. |
935 | */ |
936 | class CallNode : public PrimExprNode { |
937 | public: |
938 | /*! |
939 | * \brief The operator(function) being invoked |
940 | * |
941 | * - It can be tvm::Op which corresponds to the primitive operators(intrinsics). |
942 | * - It can also be another function in the IRModule (GlobalVar). |
943 | */ |
944 | RelayExpr op; |
945 | |
946 | /*! \brief The arguments. */ |
947 | Array<PrimExpr> args; |
948 | void VisitAttrs(AttrVisitor* v) { |
949 | v->Visit("dtype" , &dtype); |
950 | v->Visit("op" , &op); |
951 | v->Visit("args" , &args); |
952 | v->Visit("span" , &span); |
953 | } |
954 | |
955 | bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { |
956 | return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args); |
957 | } |
958 | |
959 | void SHashReduce(SHashReducer hash_reduce) const { |
960 | hash_reduce(dtype); |
961 | hash_reduce(op); |
962 | hash_reduce(args); |
963 | } |
964 | |
965 | static constexpr const char* _type_key = "tir.Call" ; |
966 | TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); |
967 | }; |
968 | |
969 | /*! |
970 | * \brief Managed reference to CallNode |
971 | * \sa CallNode |
972 | */ |
class Call : public PrimExpr {
public:
/*!
 * \brief Construct a Call expression, creating a new CallNode.
 * \param dtype The data type of the call expression.
 * \param op The operator or function being invoked (see CallNode::op).
 * \param args The arguments passed to \p op.
 * \param span Source span for debug information; defaults to an empty Span.
 */
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
// Reference-type boilerplate tying Call (a PrimExpr) to its CallNode container.
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
// Copy-on-write access to the underlying CallNode (per the macro name).
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};
979 | |
980 | /*! |
981 | * \brief Shuffle instruction. |
982 | * vec = concat(vectors) |
983 | * result = (vec[indices[0]], vec[indices[1]] ...) |
984 | */ |
class ShuffleNode : public PrimExprNode {
public:
/*! \brief The input vectors, conceptually concatenated into one vector. */
Array<PrimExpr> vectors;
/*! \brief For each result element, its index into the concatenation of vectors. */
Array<PrimExpr> indices;
// dtype and span are not declared here; they are inherited from the PrimExprNode base.

/*! \brief Reflection hook: visits dtype, vectors, indices and span, in that order. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype" , &dtype);
v->Visit("vectors" , &vectors);
v->Visit("indices" , &indices);
v->Visit("span" , &span);
}

/*!
 * \brief Structural equality: compares dtype, vectors and indices (span is ignored).
 * \param other The node to compare against.
 * \param equal The reducer used to compare individual fields.
 * \return True when dtype, vectors and indices are all structurally equal.
 */
bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
equal(indices, other->indices);
}

/*! \brief Structural hash over the same fields SEqualReduce compares (span excluded). */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(vectors);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.Shuffle" ;
TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
};
1013 | |
1014 | /*! |
1015 | * \brief Managed reference to ShuffleNode |
1016 | * \sa ShuffleNode |
1017 | */ |
1018 | class Shuffle : public PrimExpr { |
1019 | public: |
1020 | TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span = Span()); |
1021 | TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors, Span span = Span()); |
1022 | TVM_DLL static PrimExpr (PrimExpr vector, int index, Span span = Span()); |
1023 | |
1024 | TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); |
1025 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode); |
1026 | }; |
1027 | |
1028 | // Reduce operator |
1029 | /*! |
1030 | * \brief A commutative reducer node to represent a commutative |
1031 | * binary operator with identity element |
1032 | */ |
class CommReducerNode : public Object {
public:
/*! \brief The left-hand arguments of the reducer. */
Array<Var> lhs;
/*! \brief The right-hand arguments of the reducer. */
Array<Var> rhs;
/*! \brief The result of the reducer, presumably expressed in terms of lhs and rhs. */
Array<PrimExpr> result;
/*!
 * \brief The identity element of the reducer: combining any element with it
 * under this reducer's binary operation leaves that element unchanged.
 */
Array<PrimExpr> identity_element;
/*!
 * \brief Function call operator to combine a and b.
 * NOTE(review): defined out of line; presumably binds a to lhs and b to rhs and
 * evaluates result — confirm against the implementation.
 */
Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
/*!
 * \brief Span that points to the original source code.
 * Reserved debug information.
 * NOTE(review): declared mutable here, unlike the inherited span of PrimExprNode
 * subclasses; the reason is not visible in this header.
 */
mutable Span span;

/*! \brief Reflection hook: visits lhs, rhs, result, identity_element and span, in that order. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("lhs" , &lhs);
v->Visit("rhs" , &rhs);
v->Visit("result" , &result);
v->Visit("identity_element" , &identity_element);
v->Visit("span" , &span);
}

/*!
 * \brief Structural equality (span is ignored).
 *
 * lhs and rhs are the reducer's parameter variables, so they are compared as
 * definitions (DefEqual) before result, which may refer to them.
 * \param other The reducer to compare against.
 * \param equal The reducer used to compare individual fields.
 * \return True when all compared fields are structurally equal.
 */
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
equal(result, other->result) && equal(identity_element, other->identity_element);
}

/*! \brief Structural hash mirroring SEqualReduce: lhs/rhs hashed as definitions first. */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(lhs);
hash_reduce.DefHash(rhs);
hash_reduce(result);
hash_reduce(identity_element);
}

static constexpr const char* _type_key = "tir.CommReducer" ;
// Opt in to the SEqualReduce/SHashReduce methods above. This node derives directly
// from Object; the PrimExprNode subclasses in this file do not set these flags and
// presumably inherit them from PrimExprNode.
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
};
1080 | |
1081 | /*! |
1082 | * \brief Managed reference to CommReducerNode |
1083 | * \sa CommReducerNode |
1084 | */ |
class CommReducer : public ObjectRef {
public:
/*!
 * \brief Construct a commutative reducer, creating a new CommReducerNode.
 * \param lhs The left-hand argument variables.
 * \param rhs The right-hand argument variables.
 * \param result The combined result, in terms of \p lhs and \p rhs.
 * \param identity_element The identity element(s) of the reduction.
 * \param span Source span for debug information; defaults to an empty Span.
 */
TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element, Span span = Span());

// Reference-type boilerplate; note there is no copy-on-write method for this type.
TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode);
};
1092 | |
/*! \brief Reduction expression: reduces source over axis using the combiner. */
class ReduceNode : public PrimExprNode {
public:
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Array<PrimExpr> source;
/*! \brief The init operand */
Array<PrimExpr> init;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
 * \brief Predicate on the reduction.
 * An element contributes to the reduction only if condition is true.
 */
PrimExpr condition;
/*!
 * \brief The index of this reduce node.
 * NOTE(review): presumably selects which of the combiner's (possibly several)
 * results this expression yields — confirm.
 */
int value_index;
// dtype and span are not declared here; they are inherited from the PrimExprNode base.

/*! \brief Reflection hook: visits all fields, in declaration-independent fixed order. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype" , &dtype);
v->Visit("combiner" , &combiner);
v->Visit("source" , &source);
v->Visit("init" , &init);
v->Visit("axis" , &axis);
v->Visit("condition" , &condition);
v->Visit("value_index" , &value_index);
v->Visit("span" , &span);
}

/*!
 * \brief Structural equality over every field except span.
 * \param other The node to compare against.
 * \param equal The reducer used to compare individual fields.
 * \return True when all compared fields are structurally equal.
 */
bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
// check axis first so IterVars can define the necessary variables.
return equal(dtype, other->dtype) && equal(axis, other->axis) &&
equal(combiner, other->combiner) && equal(source, other->source) &&
equal(init, other->init) && equal(condition, other->condition) &&
equal(value_index, other->value_index);
}

/*! \brief Structural hash; visits fields in the same order as SEqualReduce (axis early). */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(axis);
hash_reduce(combiner);
hash_reduce(source);
hash_reduce(init);
hash_reduce(condition);
hash_reduce(value_index);
}

static constexpr const char* _type_key = "tir.Reduce" ;
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
1144 | |
1145 | /*! |
1146 | * \brief Managed reference to ReduceNode |
1147 | * \sa ReduceNode |
1148 | */ |
class Reduce : public PrimExpr {
public:
/*!
 * \brief Construct a reduction expression, creating a new ReduceNode.
 *
 * Note that the parameter order differs from ReduceNode's field order:
 * \p init comes after \p value_index.
 * \param combiner The commutative combiner (ReduceNode::combiner).
 * \param src The source operands (cf. ReduceNode::source).
 * \param rdom The reduction axes (cf. ReduceNode::axis).
 * \param condition The reduction predicate (ReduceNode::condition).
 * \param value_index See ReduceNode::value_index.
 * \param init The init operands (ReduceNode::init).
 * \param span Source span for debug information; defaults to an empty Span.
 */
TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
int value_index, Array<PrimExpr> init, Span span = Span());

// Reference-type boilerplate tying Reduce (a PrimExpr) to its ReduceNode container.
TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
// Copy-on-write access to the underlying ReduceNode (per the macro name).
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};
1157 | |
1158 | /*! \brief Any shape. */ |
class AnyNode : public PrimExprNode {
public:
/*! \brief Reflection hook: visits dtype and span (the node has no other fields). */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype" , &dtype);
v->Visit("span" , &span);
}

/*! \brief Structural equality: two AnyNodes are equal iff their dtypes are equal. */
bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}

// Hashes no fields: AnyNodes differing only in dtype hash identically. This stays
// consistent with SEqualReduce (equal nodes still hash equally), at the cost of collisions.
void SHashReduce(SHashReducer hash_reduce) const {}

/*! \brief Convert to var: constructs a new int32 Var named "any_dim" on every call. */
Var ToVar() const { return Var("any_dim" , DataType::Int(32)); }

/*! \brief Convert to SizeVar: constructs a new int32 SizeVar named "any_dim" on every call. */
SizeVar ToSizeVar() const { return SizeVar("any_dim" , DataType::Int(32)); }

static constexpr const char* _type_key = "tir.Any" ;
TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};
1181 | |
1182 | /*! |
1183 | * \brief Managed reference to AnyNode |
1184 | * \sa AnyNode |
1185 | */ |
class Any : public PrimExpr {
public:
/*!
 * \brief Construct an Any expression, creating a new AnyNode.
 * \param span Source span for debug information; defaults to an empty Span.
 */
TVM_DLL Any(Span span = Span());

// Non-nullable variant of the reference boilerplate (per the macro name), unlike
// the other reference types in this file.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
// Copy-on-write access to the underlying AnyNode (per the macro name).
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};
1193 | |
1194 | /* |
1195 | * \brief Template function to convert Map to unordered_map |
1196 | * Sometimes useful for API gluing when internal uses unordered_map |
1197 | * \param dmap The container map |
1198 | * \return The corresponding unordered_map. |
1199 | * \tparam K the key of the Map. |
1200 | * \tparam V the value of the Map. |
1201 | */ |
1202 | template <typename K, typename V> |
1203 | inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { |
1204 | std::unordered_map<K, V> ret; |
1205 | for (auto kv : dmap) { |
1206 | ret[kv.first] = kv.second; |
1207 | } |
1208 | return ret; |
1209 | } |
1210 | } // namespace tir |
1211 | } // namespace tvm |
1212 | |
namespace std {
/*!
 * \brief Allows tir::IterVar to be used as a key of std::unordered_map/set
 * (e.g. with as_unordered_map) without an explicit hasher.
 * Hashing is delegated to ObjectPtrHash, i.e. by object identity, not by structure.
 */
template <>
struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
} // namespace std
1217 | #endif // TVM_TIR_EXPR_H_ |
1218 | |