1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/expr.h
22 * \brief TIR expressions.
23 */
24// Acknowledgement: Many low-level IR nodes originate from Halide.
25#ifndef TVM_TIR_EXPR_H_
26#define TVM_TIR_EXPR_H_
27
28#include <tvm/ir/expr.h>
29#include <tvm/node/functor.h>
30#include <tvm/node/node.h>
31#include <tvm/runtime/c_runtime_api.h>
32#include <tvm/runtime/container/array.h>
33#include <tvm/runtime/container/map.h>
34#include <tvm/runtime/container/string.h>
35#include <tvm/runtime/data_type.h>
36#include <tvm/tir/buffer.h>
37#include <tvm/tir/var.h>
38
39#include <algorithm>
40#include <iostream>
41#include <limits>
42#include <string>
43#include <unordered_map>
44#include <utility>
45
46namespace tvm {
47namespace tir {
48
// Re-export the IR-level immediate node types so that code in tvm::tir can
// refer to them as tir::IntImmNode and tir::FloatImmNode.
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
51
52/*! \brief String constants, only used in asserts. */
53class StringImmNode : public PrimExprNode {
54 public:
55 /*! \brief The constant value content. */
56 String value;
57
58 void VisitAttrs(AttrVisitor* v) {
59 v->Visit("dtype", &dtype);
60 v->Visit("value", &value);
61 v->Visit("span", &span);
62 }
63
64 bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
65 return equal(value, other->value);
66 }
67
68 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
69
70 static constexpr const char* _type_key = "tir.StringImm";
71 TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
72};
73
74/*!
75 * \brief Managed reference to StringImmNode.
76 * \sa StringImmNode
77 */
78class StringImm : public PrimExpr {
79 public:
80 TVM_DLL StringImm(String value, Span span = Span());
81 TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
82 TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
83};
84
85/*!
86 * \brief Cast value from one data type to another.
87 * \note The lanes of value should keep fixed.
88 */
89class CastNode : public PrimExprNode {
90 public:
91 /*! \brief Original data type. */
92 PrimExpr value;
93
94 void VisitAttrs(AttrVisitor* v) {
95 v->Visit("dtype", &dtype);
96 v->Visit("value", &value);
97 v->Visit("span", &span);
98 }
99
100 bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
101 return equal(dtype, other->dtype) && equal(value, other->value);
102 }
103
104 void SHashReduce(SHashReducer hash_reduce) const {
105 hash_reduce(dtype);
106 hash_reduce(value);
107 }
108
109 static constexpr const char* _type_key = "tir.Cast";
110 TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
111};
112
113/*!
114 * \brief Managed reference to CastNode
115 * \sa CastNode
116 */
117class Cast : public PrimExpr {
118 public:
119 TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
120 TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
121 TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
122};
123
124/*!
125 * \brief Base template to implement binary ops.
126 * \tparam T The type of the child class.
127 */
128template <typename T>
129class BinaryOpNode : public PrimExprNode {
130 public:
131 /*! \brief The left operand. */
132 PrimExpr a;
133 /*! \brief The right operand. */
134 PrimExpr b;
135
136 void VisitAttrs(AttrVisitor* v) {
137 v->Visit("dtype", &(this->dtype));
138 v->Visit("a", &a);
139 v->Visit("b", &b);
140 v->Visit("span", &span);
141 }
142
143 bool SEqualReduce(const T* other, SEqualReducer equal) const {
144 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
145 }
146
147 void SHashReduce(SHashReducer hash_reduce) const {
148 hash_reduce(dtype);
149 hash_reduce(a);
150 hash_reduce(b);
151 }
152
153 TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
154};
155
156/*! \brief a + b */
157class AddNode : public BinaryOpNode<AddNode> {
158 public:
159 static constexpr const char* _type_key = "tir.Add";
160};
161
162/*!
163 * \brief Managed reference to AddNode
164 * \sa AddNode
165 */
166class Add : public PrimExpr {
167 public:
168 TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
169 TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
170 TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
171};
172
173/*! \brief a - b */
174class SubNode : public BinaryOpNode<SubNode> {
175 public:
176 static constexpr const char* _type_key = "tir.Sub";
177};
178
179/*!
180 * \brief Managed reference to SubNode
181 * \sa SubNode
182 */
183class Sub : public PrimExpr {
184 public:
185 TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
186 TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
187 TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
188};
189
190/*! \brief a * b */
191class MulNode : public BinaryOpNode<MulNode> {
192 public:
193 static constexpr const char* _type_key = "tir.Mul";
194};
195
196/*!
197 * \brief Managed reference to MulNode
198 * \sa MulNode
199 */
200class Mul : public PrimExpr {
201 public:
202 TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
203 TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
204 TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
205};
206
207/*!
208 * \brief a / b in the C semnatics.
209 * \note For integer division, C standard uses trunc div.
210 */
211class DivNode : public BinaryOpNode<DivNode> {
212 public:
213 static constexpr const char* _type_key = "tir.Div";
214};
215
216/*!
217 * \brief Managed reference to DivNode
218 * \sa DivNode
219 */
220class Div : public PrimExpr {
221 public:
222 TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
223 TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
224 TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
225};
226
227/*!
228 * \brief a % b in the C semnatics.
229 * \note For integer division, C standard uses trunc div.
230 */
231class ModNode : public BinaryOpNode<ModNode> {
232 public:
233 static constexpr const char* _type_key = "tir.Mod";
234};
235
236/*!
237 * \brief Managed reference to ModNode
238 * \sa ModNode
239 */
240class Mod : public PrimExpr {
241 public:
242 TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
243 TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
244 TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
245};
246
247/*! \brief Floor division, floor(a/b) */
248class FloorDivNode : public BinaryOpNode<FloorDivNode> {
249 public:
250 static constexpr const char* _type_key = "tir.FloorDiv";
251};
252
253/*!
254 * \brief Managed reference to FloorDivNode
255 * \sa FloorDivNode
256 */
257class FloorDiv : public PrimExpr {
258 public:
259 TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
260 TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
261 TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
262};
263
264/*! \brief The remainder of the floordiv */
265class FloorModNode : public BinaryOpNode<FloorModNode> {
266 public:
267 static constexpr const char* _type_key = "tir.FloorMod";
268};
269
270/*!
271 * \brief Managed reference to FloorModNode
272 * \sa FloorModNode
273 */
274class FloorMod : public PrimExpr {
275 public:
276 TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
277 TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
278 TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
279};
280
281/*! \brief min(a, b) */
282class MinNode : public BinaryOpNode<MinNode> {
283 public:
284 static constexpr const char* _type_key = "tir.Min";
285};
286
287/*!
288 * \brief Managed reference to MinNode
289 * \sa MinNode
290 */
291class Min : public PrimExpr {
292 public:
293 TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
294 TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
295 TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
296};
297
298/*! \brief max(a, b) */
299class MaxNode : public BinaryOpNode<MaxNode> {
300 public:
301 static constexpr const char* _type_key = "tir.Max";
302};
303
304/*!
305 * \brief Managed reference to MaxNode
306 * \sa MaxNode
307 */
308class Max : public PrimExpr {
309 public:
310 TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
311 TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
312 TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
313};
314
315/*!
316 * \brief Base template to implement comparison ops.
317 * \tparam T The type of the child class.
318 */
319template <typename T>
320class CmpOpNode : public PrimExprNode {
321 public:
322 /*! \brief The left operand. */
323 PrimExpr a;
324 /*! \brief The right operand. */
325 PrimExpr b;
326
327 void VisitAttrs(AttrVisitor* v) {
328 v->Visit("dtype", &(this->dtype));
329 v->Visit("a", &a);
330 v->Visit("b", &b);
331 v->Visit("span", &span);
332 }
333
334 bool SEqualReduce(const T* other, SEqualReducer equal) const {
335 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
336 }
337
338 void SHashReduce(SHashReducer hash_reduce) const {
339 hash_reduce(dtype);
340 hash_reduce(a);
341 hash_reduce(b);
342 }
343
344 TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
345};
346
347/*! \brief a == b */
348class EQNode : public CmpOpNode<EQNode> {
349 public:
350 static constexpr const char* _type_key = "tir.EQ";
351};
352
353/*!
354 * \brief Managed reference to EQNode
355 * \sa EQNode
356 */
357class EQ : public PrimExpr {
358 public:
359 TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
360 TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
361 TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
362};
363
364/*! \brief a != b */
365class NENode : public CmpOpNode<NENode> {
366 public:
367 static constexpr const char* _type_key = "tir.NE";
368};
369
370/*!
371 * \brief Managed reference to NENode
372 * \sa NENode
373 */
374class NE : public PrimExpr {
375 public:
376 TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
377 TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
378 TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
379};
380
381/*! \brief a < b */
382class LTNode : public CmpOpNode<LTNode> {
383 public:
384 static constexpr const char* _type_key = "tir.LT";
385};
386
387/*!
388 * \brief Managed reference to LTNode
389 * \sa LTNode
390 */
391class LT : public PrimExpr {
392 public:
393 TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
394 TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
395 TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
396};
397
398/*! \brief a <= b */
399struct LENode : public CmpOpNode<LENode> {
400 public:
401 static constexpr const char* _type_key = "tir.LE";
402};
403
404/*!
405 * \brief Managed reference to LENode
406 * \sa LENode
407 */
408class LE : public PrimExpr {
409 public:
410 TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
411 TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
412 TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
413};
414
415/*! \brief a > b */
416class GTNode : public CmpOpNode<GTNode> {
417 public:
418 static constexpr const char* _type_key = "tir.GT";
419};
420
421/*!
422 * \brief Managed reference to GTNode
423 * \sa GTNode
424 */
425class GT : public PrimExpr {
426 public:
427 TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
428 TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
429 TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
430};
431
432/*! \brief a >= b */
433class GENode : public CmpOpNode<GENode> {
434 public:
435 static constexpr const char* _type_key = "tir.GE";
436};
437
438/*!
439 * \brief Managed reference to GENode
440 * \sa GENode
441 */
442class GE : public PrimExpr {
443 public:
444 TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
445 TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
446 TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
447};
448
449/*! \brief a && b */
450class AndNode : public PrimExprNode {
451 public:
452 /*! \brief The left operand. */
453 PrimExpr a;
454 /*! \brief The right operand. */
455 PrimExpr b;
456
457 void VisitAttrs(AttrVisitor* v) {
458 v->Visit("dtype", &(this->dtype));
459 v->Visit("a", &a);
460 v->Visit("b", &b);
461 v->Visit("span", &span);
462 }
463
464 bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
465 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
466 }
467
468 void SHashReduce(SHashReducer hash_reduce) const {
469 hash_reduce(dtype);
470 hash_reduce(a);
471 hash_reduce(b);
472 }
473
474 static constexpr const char* _type_key = "tir.And";
475 TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
476};
477
478/*!
479 * \brief Managed reference to AndNode
480 * \sa AndNode
481 */
482class And : public PrimExpr {
483 public:
484 TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
485 TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
486 TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
487};
488
489/*! \brief a || b */
490class OrNode : public PrimExprNode {
491 public:
492 /*! \brief The left operand. */
493 PrimExpr a;
494 /*! \brief The right operand. */
495 PrimExpr b;
496
497 void VisitAttrs(AttrVisitor* v) {
498 v->Visit("dtype", &dtype);
499 v->Visit("a", &a);
500 v->Visit("b", &b);
501 v->Visit("span", &span);
502 }
503
504 bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
505 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
506 }
507
508 void SHashReduce(SHashReducer hash_reduce) const {
509 hash_reduce(dtype);
510 hash_reduce(a);
511 hash_reduce(b);
512 }
513
514 static constexpr const char* _type_key = "tir.Or";
515 TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
516};
517
518/*!
519 * \brief Managed reference to OrNode
520 * \sa OrNode
521 */
522class Or : public PrimExpr {
523 public:
524 TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
525 TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
526 TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
527};
528
529/*! \brief !a */
530class NotNode : public PrimExprNode {
531 public:
532 /*! \brief The input operand. */
533 PrimExpr a;
534
535 void VisitAttrs(AttrVisitor* v) {
536 v->Visit("dtype", &dtype);
537 v->Visit("a", &a);
538 v->Visit("span", &span);
539 }
540
541 bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
542 return equal(dtype, other->dtype) && equal(a, other->a);
543 }
544
545 void SHashReduce(SHashReducer hash_reduce) const {
546 hash_reduce(dtype);
547 hash_reduce(a);
548 }
549
550 static constexpr const char* _type_key = "tir.Not";
551 TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
552};
553
554/*!
555 * \brief Managed reference to NotNode
556 * \sa NotNode
557 */
558class Not : public PrimExpr {
559 public:
560 TVM_DLL Not(PrimExpr a, Span span = Span());
561 TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
562 TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
563};
564
565/*!
566 * \brief return true_value if condition is true, otherwise return false_value.
567 * \note Both true_value and false_value could be evaluated
568 * regardless of the condition value.
569 * Do not use it to guard against out of bound access,
570 * please use if_then_else instead.
571 */
572class SelectNode : public PrimExprNode {
573 public:
574 /*! \brief The condition */
575 PrimExpr condition;
576 /*! \brief value to be returned when condition is true. */
577 PrimExpr true_value;
578 /*! \brief value to be returned when condition is false. */
579 PrimExpr false_value;
580
581 void VisitAttrs(AttrVisitor* v) {
582 v->Visit("dtype", &dtype);
583 v->Visit("condition", &condition);
584 v->Visit("true_value", &true_value);
585 v->Visit("false_value", &false_value);
586 v->Visit("span", &span);
587 }
588
589 bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
590 return equal(dtype, other->dtype) && equal(condition, other->condition) &&
591 equal(true_value, other->true_value) && equal(false_value, other->false_value);
592 }
593
594 void SHashReduce(SHashReducer hash_reduce) const {
595 hash_reduce(dtype);
596 hash_reduce(condition);
597 hash_reduce(true_value);
598 hash_reduce(false_value);
599 }
600
601 static constexpr const char* _type_key = "tir.Select";
602 TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
603};
604
605/*!
606 * \brief Managed reference to SelectNode
607 * \sa SelectNode
608 */
609class Select : public PrimExpr {
610 public:
611 TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
612
613 TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
614 TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
615};
616
617/*!
618 * \brief Load value from the high dimension buffer.
619 *
620 * \code
621 *
622 * value = buffer[i, j];
623 *
624 * \endcode
625 * \sa BufferStore
626 */
627class BufferLoadNode : public PrimExprNode {
628 public:
629 /*! \brief The buffer variable. */
630 Buffer buffer;
631 /*! \brief The indices location to be loaded. */
632 Array<PrimExpr> indices;
633
634 void VisitAttrs(AttrVisitor* v) {
635 v->Visit("dtype", &(this->dtype));
636 v->Visit("buffer", &buffer);
637 v->Visit("indices", &indices);
638 v->Visit("span", &span);
639 }
640
641 bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
642 return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
643 equal(indices, other->indices);
644 }
645
646 void SHashReduce(SHashReducer hash_reduce) const {
647 hash_reduce(dtype);
648 hash_reduce(buffer);
649 hash_reduce(indices);
650 }
651
652 static constexpr const char* _type_key = "tir.BufferLoad";
653 TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
654
655 private:
656 /*! \brief Set the dtype based on the buffer/indices
657 *
658 * Usually, the BufferLoad's dtype will be the same dtype as the
659 * buffer. This may have a different number of lanes than the
660 * buffer's dtype if index values have more than 1 lane.
661 *
662 * This function should only be called during construction and after
663 * CopyOnWrite. Friend class used here to restrict usage.
664 */
665 void LegalizeDType();
666 friend class BufferLoad;
667 friend class CustomDatatypesLowerer;
668 friend class VectorTypeRewriter;
669 friend class Vectorizer;
670};
671
672/*!
673 * \brief Managed reference to BufferLoadNode.
674 * \sa BufferLoadNode
675 */
676class BufferLoad : public PrimExpr {
677 public:
678 TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
679 TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
680 TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
681};
682
683/*!
684 * \brief Load value from the result produced by the producer.
685 *
686 * \note This node only appears in high-level DSLs that are built on top of the TIR.
687 * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
688 * this node before TIR transformations.
689 *
690 * \sa ProducerLoad, DataProducerNode
691 */
692class ProducerLoadNode : public PrimExprNode {
693 public:
694 /*! \brief The buffer producer. */
695 DataProducer producer;
696 /*! \brief The location arguments. */
697 Array<PrimExpr> indices;
698
699 void VisitAttrs(AttrVisitor* v) {
700 v->Visit("dtype", &(this->dtype));
701 v->Visit("producer", &producer);
702 v->Visit("indices", &indices);
703 v->Visit("span", &span);
704 }
705
706 bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
707 return equal(dtype, other->dtype) && equal(producer, other->producer) &&
708 equal(indices, other->indices);
709 }
710
711 void SHashReduce(SHashReducer hash_reduce) const {
712 hash_reduce(dtype);
713 hash_reduce(producer);
714 hash_reduce(indices);
715 }
716
717 static constexpr const char* _type_key = "tir.ProducerLoad";
718 TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
719};
720
721/*!
722 * \brief Managed reference to ProducerLoadNode.
723 * \sa ProducerLoadNode
724 */
725class ProducerLoad : public PrimExpr {
726 public:
727 TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());
728
729 TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
730 TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
731};
732
733/*!
734 * \brief Load the value from buffer_var.
735 *
736 * Equivalent to ((DType*)buffer_var)[index]
737 * where DType is the type specified by type().element_of().
738 *
739 * For example, if type = float32x3, then the load will corresponds to
740 *
741 * \code
742 *
743 * auto buffer = static_cast<float*>(buffer_var);
744 * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
745 *
746 * \endcode
747 */
748class LoadNode : public PrimExprNode {
749 public:
750 /*! \brief The buffer variable. */
751 Var buffer_var;
752 /*! \brief The index locations to be loaded. */
753 PrimExpr index;
754 /*! \brief The predicate to mask which lanes would be loaded. */
755 PrimExpr predicate;
756
757 void VisitAttrs(AttrVisitor* v) {
758 v->Visit("dtype", &dtype);
759 v->Visit("buffer_var", &buffer_var);
760 v->Visit("index", &index);
761 v->Visit("predicate", &predicate);
762 v->Visit("span", &span);
763 }
764
765 bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
766 return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
767 equal(index, other->index) && equal(predicate, other->predicate);
768 }
769
770 void SHashReduce(SHashReducer hash_reduce) const {
771 hash_reduce(dtype);
772 hash_reduce(buffer_var);
773 hash_reduce(index);
774 hash_reduce(predicate);
775 }
776
777 static constexpr const char* _type_key = "tir.Load";
778 TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
779};
780
781/*!
782 * \brief Managed reference to LoadNode
783 * \sa LoadNode
784 */
785class Load : public PrimExpr {
786 public:
787 TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
788 Span span = Span());
789 TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
790 TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
791};
792
793/*!
794 * \brief Construct a vector with lanes elements
795 * where its i-th element equals base + i * stride.
796 * This is useful to construct a index for a continuous vector load.
797 *
798 * Examples:
799 * - ramp(0, 1, 3) = [0, 1, 2]
800 * - ramp(1, 2, 4) = [1, 3, 5, 7]
801 */
802class RampNode : public PrimExprNode {
803 public:
804 /*! \brief The base value. */
805 PrimExpr base;
806 /*! \brief The stride of each step. */
807 PrimExpr stride;
808 /*! \brief Total number of lanes. */
809 int lanes;
810
811 void VisitAttrs(AttrVisitor* v) {
812 v->Visit("dtype", &dtype);
813 v->Visit("base", &base);
814 v->Visit("stride", &stride);
815 v->Visit("lanes", &lanes);
816 v->Visit("span", &span);
817 }
818
819 bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
820 return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
821 equal(lanes, other->lanes);
822 }
823
824 void SHashReduce(SHashReducer hash_reduce) const {
825 hash_reduce(dtype);
826 hash_reduce(base);
827 hash_reduce(stride);
828 hash_reduce(lanes);
829 }
830
831 static constexpr const char* _type_key = "tir.Ramp";
832 TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
833};
834
835/*!
836 * \brief Managed reference to RampNode
837 * \sa RampNode
838 */
839class Ramp : public PrimExpr {
840 public:
841 TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
842 TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
843 TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
844};
845
846/*! \brief Create a vector where all the elements are value. */
847class BroadcastNode : public PrimExprNode {
848 public:
849 /*! \brief The base value. */
850 PrimExpr value;
851 /*! \brief The number of lanes. */
852 int lanes;
853
854 void VisitAttrs(AttrVisitor* v) {
855 v->Visit("dtype", &dtype);
856 v->Visit("value", &value);
857 v->Visit("lanes", &lanes);
858 v->Visit("span", &span);
859 }
860
861 bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
862 return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes);
863 }
864
865 void SHashReduce(SHashReducer hash_reduce) const {
866 hash_reduce(dtype);
867 hash_reduce(value);
868 hash_reduce(lanes);
869 }
870
871 static constexpr const char* _type_key = "tir.Broadcast";
872 TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
873};
874
875/*!
876 * \brief Managed reference to BroadcastNode
877 * \sa BroadcastNode
878 */
879class Broadcast : public PrimExpr {
880 public:
881 TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
882 TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
883 TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
884};
885
886/*!
887 * \brief Let binding. Bind var to value then evaluate body.
888 */
889class LetNode : public PrimExprNode {
890 public:
891 /*! \brief The variable. */
892 Var var;
893 /*! \brief The value to be binded. */
894 PrimExpr value;
895 /*! \brief The result expression. */
896 PrimExpr body;
897
898 void VisitAttrs(AttrVisitor* v) {
899 v->Visit("dtype", &dtype);
900 v->Visit("var", &var);
901 v->Visit("value", &value);
902 v->Visit("body", &body);
903 v->Visit("span", &span);
904 }
905
906 bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
907 return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) &&
908 equal(value, other->value) && equal(body, other->body);
909 }
910
911 void SHashReduce(SHashReducer hash_reduce) const {
912 hash_reduce(dtype);
913 hash_reduce.DefHash(var);
914 hash_reduce(value);
915 hash_reduce(body);
916 }
917
918 static constexpr const char* _type_key = "tir.Let";
919 TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
920};
921
922/*!
923 * \brief Managed reference to LetNode
924 * \sa LetNode
925 */
926class Let : public PrimExpr {
927 public:
928 TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
929 TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
930 TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
931};
932
933/*!
934 * \brief Call node.
935 */
936class CallNode : public PrimExprNode {
937 public:
938 /*!
939 * \brief The operator(function) being invoked
940 *
941 * - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
942 * - It can also be another function in the IRModule (GlobalVar).
943 */
944 RelayExpr op;
945
946 /*! \brief The arguments. */
947 Array<PrimExpr> args;
948 void VisitAttrs(AttrVisitor* v) {
949 v->Visit("dtype", &dtype);
950 v->Visit("op", &op);
951 v->Visit("args", &args);
952 v->Visit("span", &span);
953 }
954
955 bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
956 return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
957 }
958
959 void SHashReduce(SHashReducer hash_reduce) const {
960 hash_reduce(dtype);
961 hash_reduce(op);
962 hash_reduce(args);
963 }
964
965 static constexpr const char* _type_key = "tir.Call";
966 TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
967};
968
969/*!
970 * \brief Managed reference to CallNode
971 * \sa CallNode
972 */
973class Call : public PrimExpr {
974 public:
975 TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
976 TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
977 TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
978};
979
980/*!
981 * \brief Shuffle instruction.
982 * vec = concat(vectors)
983 * result = (vec[indices[0]], vec[indices[1]] ...)
984 */
985class ShuffleNode : public PrimExprNode {
986 public:
987 /*! \brief the input vectors. */
988 Array<PrimExpr> vectors;
989 /*! \brief The indices of each element. */
990 Array<PrimExpr> indices;
991
992 void VisitAttrs(AttrVisitor* v) {
993 v->Visit("dtype", &dtype);
994 v->Visit("vectors", &vectors);
995 v->Visit("indices", &indices);
996 v->Visit("span", &span);
997 }
998
999 bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
1000 return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
1001 equal(indices, other->indices);
1002 }
1003
1004 void SHashReduce(SHashReducer hash_reduce) const {
1005 hash_reduce(dtype);
1006 hash_reduce(vectors);
1007 hash_reduce(indices);
1008 }
1009
1010 static constexpr const char* _type_key = "tir.Shuffle";
1011 TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
1012};
1013
1014/*!
1015 * \brief Managed reference to ShuffleNode
1016 * \sa ShuffleNode
1017 */
1018class Shuffle : public PrimExpr {
1019 public:
1020 TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span = Span());
1021 TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors, Span span = Span());
1022 TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
1023
1024 TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
1025 TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
1026};
1027
1028// Reduce operator
1029/*!
1030 * \brief A commutative reducer node to represent a commutative
1031 * binary operator with identity element
1032 */
1033class CommReducerNode : public Object {
1034 public:
1035 /*! \brief The left argument of reducer */
1036 Array<Var> lhs;
1037 /*! \brief The right argument of reducer */
1038 Array<Var> rhs;
1039 /*! \brief The result of reducer */
1040 Array<PrimExpr> result;
1041 /*!
1042 * \brief The identity element of reducer, which leaves other
1043 * elements unchanged when combined with it, with respect to
1044 * the binary operation of this reducer uses.
1045 */
1046 Array<PrimExpr> identity_element;
1047 /*! \brief Function call operator to combine a and b */
1048 Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
1049 /*!
1050 * \brief Span that points to the original source code.
1051 * Reserved debug information.
1052 */
1053 mutable Span span;
1054
1055 void VisitAttrs(AttrVisitor* v) {
1056 v->Visit("lhs", &lhs);
1057 v->Visit("rhs", &rhs);
1058 v->Visit("result", &result);
1059 v->Visit("identity_element", &identity_element);
1060 v->Visit("span", &span);
1061 }
1062
1063 bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
1064 return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
1065 equal(result, other->result) && equal(identity_element, other->identity_element);
1066 }
1067
1068 void SHashReduce(SHashReducer hash_reduce) const {
1069 hash_reduce.DefHash(lhs);
1070 hash_reduce.DefHash(rhs);
1071 hash_reduce(result);
1072 hash_reduce(identity_element);
1073 }
1074
1075 static constexpr const char* _type_key = "tir.CommReducer";
1076 static constexpr const bool _type_has_method_sequal_reduce = true;
1077 static constexpr const bool _type_has_method_shash_reduce = true;
1078 TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
1079};
1080
1081/*!
1082 * \brief Managed reference to CommReducerNode
1083 * \sa CommReducerNode
1084 */
1085class CommReducer : public ObjectRef {
1086 public:
1087 TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
1088 Array<PrimExpr> identity_element, Span span = Span());
1089
1090 TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode);
1091};
1092
1093/*! \brief Reduction operator operator */
1094class ReduceNode : public PrimExprNode {
1095 public:
1096 /*! \brief The commutative combiner */
1097 CommReducer combiner;
1098 /*! \brief The source operand */
1099 Array<PrimExpr> source;
1100 /*! \brief The init operand */
1101 Array<PrimExpr> init;
1102 /*! \brief The reduction axis */
1103 Array<IterVar> axis;
1104 /*!
1105 * \brief Predicate on the reduction
1106 * Only add the body to reduction if condition is true.
1107 */
1108 PrimExpr condition;
1109 /*! \brief the index of this reduce node */
1110 int value_index;
1111
1112 void VisitAttrs(AttrVisitor* v) {
1113 v->Visit("dtype", &dtype);
1114 v->Visit("combiner", &combiner);
1115 v->Visit("source", &source);
1116 v->Visit("init", &init);
1117 v->Visit("axis", &axis);
1118 v->Visit("condition", &condition);
1119 v->Visit("value_index", &value_index);
1120 v->Visit("span", &span);
1121 }
1122
1123 bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
1124 // check axis first so IterVars can define the necessary variables.
1125 return equal(dtype, other->dtype) && equal(axis, other->axis) &&
1126 equal(combiner, other->combiner) && equal(source, other->source) &&
1127 equal(init, other->init) && equal(condition, other->condition) &&
1128 equal(value_index, other->value_index);
1129 }
1130
1131 void SHashReduce(SHashReducer hash_reduce) const {
1132 hash_reduce(dtype);
1133 hash_reduce(axis);
1134 hash_reduce(combiner);
1135 hash_reduce(source);
1136 hash_reduce(init);
1137 hash_reduce(condition);
1138 hash_reduce(value_index);
1139 }
1140
1141 static constexpr const char* _type_key = "tir.Reduce";
1142 TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
1143};
1144
1145/*!
1146 * \brief Managed reference to ReduceNode
1147 * \sa ReduceNode
1148 */
1149class Reduce : public PrimExpr {
1150 public:
1151 TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
1152 int value_index, Array<PrimExpr> init, Span span = Span());
1153
1154 TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
1155 TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
1156};
1157
1158/*! \brief Any shape. */
1159class AnyNode : public PrimExprNode {
1160 public:
1161 void VisitAttrs(AttrVisitor* v) {
1162 v->Visit("dtype", &dtype);
1163 v->Visit("span", &span);
1164 }
1165
1166 bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
1167 return equal(dtype, other->dtype);
1168 }
1169
1170 void SHashReduce(SHashReducer hash_reduce) const {}
1171
1172 /*! \brief Convert to var. */
1173 Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
1174
1175 /*! \brief Convert to SizeVar. */
1176 SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); }
1177
1178 static constexpr const char* _type_key = "tir.Any";
1179 TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
1180};
1181
1182/*!
1183 * \brief Managed reference to AnyNode
1184 * \sa AnyNode
1185 */
1186class Any : public PrimExpr {
1187 public:
1188 TVM_DLL Any(Span span = Span());
1189
1190 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
1191 TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
1192};
1193
1194/*
1195 * \brief Template function to convert Map to unordered_map
1196 * Sometimes useful for API gluing when internal uses unordered_map
1197 * \param dmap The container map
1198 * \return The corresponding unordered_map.
1199 * \tparam K the key of the Map.
1200 * \tparam V the value of the Map.
1201 */
1202template <typename K, typename V>
1203inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
1204 std::unordered_map<K, V> ret;
1205 for (auto kv : dmap) {
1206 ret[kv.first] = kv.second;
1207 }
1208 return ret;
1209}
1210} // namespace tir
1211} // namespace tvm
1212
namespace std {
// std::hash specialization for tir::IterVar, so IterVar can be used directly as a
// key in std::unordered_map / std::unordered_set. All hashing behavior is inherited
// from ::tvm::ObjectPtrHash (presumably hashing by object identity — confirm there).
template <>
struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
}  // namespace std
1217#endif // TVM_TIR_EXPR_H_
1218