1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/tir/stmt.h
21 * \brief TIR statements.
22 */
23// Acknowledgement: Many low-level stmts originate from Halide.
24#ifndef TVM_TIR_STMT_H_
25#define TVM_TIR_STMT_H_
26
27#include <tvm/tir/expr.h>
28
29#include <string>
30#include <type_traits>
31#include <utility>
32#include <vector>
33
34namespace tvm {
35namespace tir {
36
37/*! \brief Base node of all statements. */
38class StmtNode : public Object {
39 public:
40 /*!
41 * \brief Span that points to the original source code.
42 * Reserved debug information.
43 */
44 mutable Span span;
45
46 StmtNode() = default;
47 explicit StmtNode(Span span) : span(span) {}
48
49 TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
50
51 static constexpr const char* _type_key = "tir.Stmt";
52 static constexpr const bool _type_has_method_sequal_reduce = true;
53 static constexpr const bool _type_has_method_shash_reduce = true;
54 static constexpr const uint32_t _type_child_slots = 15;
55 TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
56};
57
58/*! \brief Container of all statements */
59class Stmt : public ObjectRef {
60 public:
61 TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode);
62};
63
64/*!
65 * \brief Let binding, bind var to value, then run body.
66 */
67class LetStmtNode : public StmtNode {
68 public:
69 /*! \brief The variable. */
70 Var var;
71 /*! \brief The value to be binded. */
72 PrimExpr value;
73 /*! \brief The body block. */
74 Stmt body;
75
76 void VisitAttrs(AttrVisitor* v) {
77 v->Visit("var", &var);
78 v->Visit("value", &value);
79 v->Visit("body", &body);
80 v->Visit("span", &span);
81 }
82
83 bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
84 return equal.DefEqual(var, other->var) && equal(value, other->value) &&
85 equal(body, other->body);
86 }
87
88 void SHashReduce(SHashReducer hash_reduce) const {
89 hash_reduce.DefHash(var);
90 hash_reduce(value);
91 hash_reduce(body);
92 }
93
94 static constexpr const char* _type_key = "tir.LetStmt";
95 TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
96};
97
98/*!
99 * \brief Managed reference to LetStmtNode.
100 * \sa LetStmtNode
101 */
102class LetStmt : public Stmt {
103 public:
104 TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
105
106 TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
107 TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
108};
109
110/*!
111 * \brief Define certain auxiliary attribute for the body to be a symbolic value.
112 * This provide auxiliary information for IR passes that transforms body.
113 *
114 * In terms of effect, this is equivalent to Block(Evaluate(value), body).
115 *
116 * Examples of possible usage:
117 * - Bound of function, variables.
118 * - Hint which block corresponds to a parallel region.
119 */
120class AttrStmtNode : public StmtNode {
121 public:
122 /*! \brief this is attribute about certain node */
123 ObjectRef node;
124 /*! \brief the type key of the attribute */
125 String attr_key;
126 /*! \brief The attribute value, value is well defined at current scope. */
127 PrimExpr value;
128 /*! \brief The body statement to be executed */
129 Stmt body;
130
131 void VisitAttrs(AttrVisitor* v) {
132 v->Visit("node", &node);
133 v->Visit("attr_key", &attr_key);
134 v->Visit("value", &value);
135 v->Visit("body", &body);
136 v->Visit("span", &span);
137 }
138
139 bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
140 return equal(node, other->node) && equal(attr_key, other->attr_key) &&
141 equal(value, other->value) && equal(body, other->body);
142 }
143
144 void SHashReduce(SHashReducer hash_reduce) const {
145 hash_reduce(node);
146 hash_reduce(attr_key);
147 hash_reduce(value);
148 hash_reduce(body);
149 }
150
151 static constexpr const char* _type_key = "tir.AttrStmt";
152 TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
153};
154
155/*!
156 * \brief Managed reference to AttrStmtNode.
157 * \sa AttrStmtNode
158 */
159class AttrStmt : public Stmt {
160 public:
161 TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
162
163 TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
164 TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
165};
166
167/*!
168 * \brief Assert condition, if an error occurs, return the error message.
169 */
170class AssertStmtNode : public StmtNode {
171 public:
172 /*! \brief Condition to be checked. */
173 PrimExpr condition;
174 /*! \brief Error message when assertion failed. */
175 PrimExpr message;
176 /*!
177 * \brief Body which this assertion holds true.
178 * Will be executed after the assertion.
179 */
180 Stmt body;
181
182 void VisitAttrs(AttrVisitor* v) {
183 v->Visit("condition", &condition);
184 v->Visit("message", &message);
185 v->Visit("body", &body);
186 v->Visit("span", &span);
187 }
188
189 bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
190 return equal(condition, other->condition) && equal(message, other->message) &&
191 equal(body, other->body);
192 }
193
194 void SHashReduce(SHashReducer hash_reduce) const {
195 hash_reduce(condition);
196 hash_reduce(message);
197 hash_reduce(body);
198 }
199
200 static constexpr const char* _type_key = "tir.AssertStmt";
201 TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
202};
203
204/*!
205 * \brief Managed reference to AssertStmtNode.
206 * \sa AssertStmtNode
207 */
208class AssertStmt : public Stmt {
209 public:
210 TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
211
212 TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
213 TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
214};
215
216/*!
217 * \brief Store value to the buffer.
218 *
219 * Equivalent to ((DType*)buffer_var)[index] = value.
220 * where DType is the type specified by type().element_of().
221 *
222 * For example, if type = float32x3, then the store will corresponds to
223 *
224 * \code
225 *
226 * auto buffer = static_cast<float*>(buffer_var);
227 * buffer[index.v0] = value.v0;
228 * buffer[index.v1] = value.v1;
229 * buffer[index.v2] = value.v2;
230 *
231 * \endcode
232 * \sa LoadNode
233 */
234class StoreNode : public StmtNode {
235 public:
236 /*! \brief The buffer variable. */
237 Var buffer_var;
238 /*! \brief The value to be stored. */
239 PrimExpr value;
240 /*! \brief The index locations to be stored. */
241 PrimExpr index;
242 /*! \brief The predicate to mask which lanes would be stored. */
243 PrimExpr predicate;
244
245 void VisitAttrs(AttrVisitor* v) {
246 v->Visit("buffer_var", &buffer_var);
247 v->Visit("value", &value);
248 v->Visit("index", &index);
249 v->Visit("predicate", &predicate);
250 v->Visit("span", &span);
251 }
252
253 bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
254 return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
255 equal(index, other->index) && equal(predicate, other->predicate);
256 }
257
258 void SHashReduce(SHashReducer hash_reduce) const {
259 hash_reduce(buffer_var);
260 hash_reduce(value);
261 hash_reduce(index);
262 hash_reduce(predicate);
263 }
264
265 static constexpr const char* _type_key = "tir.Store";
266 TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
267};
268
269/*!
270 * \brief Managed reference to StoreNode.
271 * \sa StoreNode
272 */
273class Store : public Stmt {
274 public:
275 TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
276 Span span = Span());
277
278 TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
279 TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
280};
281
282/*!
283 * \brief Store value to the high dimension buffer.
284 *
285 * \code
286 *
287 * buffer[i, j] = value;
288 *
289 * \endcode
290 * \sa BufferLoad
291 */
292class BufferStoreNode : public StmtNode {
293 public:
294 /*! \brief The buffer variable. */
295 Buffer buffer;
296 /*! \brief The value to be stored. */
297 PrimExpr value;
298 /*! \brief The indices location to be stored. */
299 Array<PrimExpr> indices;
300
301 void VisitAttrs(AttrVisitor* v) {
302 v->Visit("buffer", &buffer);
303 v->Visit("value", &value);
304 v->Visit("indices", &indices);
305 v->Visit("span", &span);
306 }
307
308 bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
309 return equal(buffer, other->buffer) && equal(value, other->value) &&
310 equal(indices, other->indices);
311 }
312
313 void SHashReduce(SHashReducer hash_reduce) const {
314 hash_reduce(buffer);
315 hash_reduce(value);
316 hash_reduce(indices);
317 }
318
319 static constexpr const char* _type_key = "tir.BufferStore";
320 TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
321};
322
323/*!
324 * \brief Managed reference to BufferStoreNode.
325 * \sa BufferStoreNode
326 */
327class BufferStore : public Stmt {
328 public:
329 TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
330 Span span = Span());
331
332 TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
333 TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
334};
335
336/*!
337 * \brief Annotate the region where the buffer need to
338 * be read and write in the body.
339 * We only need to allocate the space for the corresponding region.
340 *
341 * \note There should be at most one BufferRealize for each buffer.
342 * BufferRealize is not necessary for external buffers,
343 * since they are assumed to be fully allocated.
344 *
345 * \sa BufferLoad, BufferStore
346 */
347class BufferRealizeNode : public StmtNode {
348 public:
349 /*! \brief The buffer variable. */
350 Buffer buffer;
351 /*! \brief Bounds to be realized */
352 Array<Range> bounds;
353 /*! \brief Only realize if condition holds. */
354 PrimExpr condition;
355 /*! \brief The body of realization. */
356 Stmt body;
357
358 void VisitAttrs(AttrVisitor* v) {
359 v->Visit("buffer", &buffer);
360 v->Visit("bounds", &bounds);
361 v->Visit("condition", &condition);
362 v->Visit("body", &body);
363 v->Visit("span", &span);
364 }
365
366 bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
367 return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
368 equal(condition, other->condition) && equal(body, other->body);
369 }
370
371 void SHashReduce(SHashReducer hash_reduce) const {
372 hash_reduce(buffer);
373 hash_reduce(bounds);
374 hash_reduce(condition);
375 hash_reduce(body);
376 }
377
378 BufferRealizeNode() = default;
379 BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
380 Span span = Span())
381 : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {}
382
383 static constexpr const char* _type_key = "tir.BufferRealize";
384 TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
385};
386
387/*!
388 * \brief Managed reference to BufferRealizeNode.
389 * \sa BufferRealizeNode
390 */
391class BufferRealize : public Stmt {
392 public:
393 TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
394 Span span = Span());
395
396 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
397 TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
398};
399
400/*!
401 * \brief Store value into mult-dimensional array that will be read by the consumer
402 * of the producer.
403 *
404 * \note This node only appears in high-level DSLs that are built on top of the TIR.
405 * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
406 * this node before TIR transformations.
407 *
408 * \sa DataProducer
409 */
410class ProducerStoreNode : public StmtNode {
411 public:
412 /*! \brief The producer to store the results into. */
413 DataProducer producer;
414 /*! \brief The value to be stored. */
415 PrimExpr value;
416 /*! \brief The index arguments of the function. */
417 Array<PrimExpr> indices;
418
419 void VisitAttrs(AttrVisitor* v) {
420 v->Visit("producer", &producer);
421 v->Visit("value", &value);
422 v->Visit("indices", &indices);
423 v->Visit("span", &span);
424 }
425
426 bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
427 return equal(producer, other->producer) && equal(value, other->value) &&
428 equal(indices, other->indices);
429 }
430
431 void SHashReduce(SHashReducer hash_reduce) const {
432 hash_reduce(producer);
433 hash_reduce(value);
434 hash_reduce(indices);
435 }
436
437 static constexpr const char* _type_key = "tir.ProducerStore";
438 TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
439};
440
441/*!
442 * \brief Managed reference to ProducerStoreNode.
443 * \sa ProducerStoreNode
444 */
445class ProducerStore : public Stmt {
446 public:
447 TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
448 Span span = Span());
449
450 TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
451 TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
452};
453
454/*!
455 * \brief Annotate the bounds where the data produced by the producer
456 * need to be written and read in body.
457 * We will need to allocate space for the corresponding regions.
458 *
459 * \note This node only appears in high-level DSLs that are built on top of the TIR.
460 * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
461 * this node before TIR transformations.
462 *
463 * \sa DataProducer
464 */
465class ProducerRealizeNode : public StmtNode {
466 public:
467 /*! \brief The producer that produces the data. */
468 DataProducer producer;
469 /*! \brief Bounds to be realized. */
470 Region bounds;
471 /*! \brief Only realize if condition holds. */
472 PrimExpr condition;
473 /*! \brief The body of realization. */
474 Stmt body;
475 /*! \brief The storage scope associated with this realization. */
476 String storage_scope;
477
478 void VisitAttrs(AttrVisitor* v) {
479 v->Visit("producer", &producer);
480 v->Visit("bounds", &bounds);
481 v->Visit("condition", &condition);
482 v->Visit("body", &body);
483 v->Visit("storage_scope", &storage_scope);
484 v->Visit("span", &span);
485 }
486
487 bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
488 return equal(producer, other->producer) && equal(bounds, other->bounds) &&
489 equal(condition, other->condition) && equal(body, other->body) &&
490 equal(storage_scope, other->storage_scope);
491 }
492
493 void SHashReduce(SHashReducer hash_reduce) const {
494 hash_reduce(producer);
495 hash_reduce(bounds);
496 hash_reduce(condition);
497 hash_reduce(body);
498 hash_reduce(storage_scope);
499 }
500
501 static constexpr const char* _type_key = "tir.ProducerRealize";
502 TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
503};
504
505/*!
506 * \brief Managed reference to ProducerRealizeNode.
507 * \sa ProducerRealizeNode
508 */
509class ProducerRealize : public Stmt {
510 public:
511 TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
512 String storage_scope = "", Span span = Span());
513
514 TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
515 TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
516};
517
518/*!
519 * \brief Allocate a buffer that can be used in body.
520 */
521class AllocateNode : public StmtNode {
522 public:
523 /*! \brief The buffer variable. */
524 Var buffer_var;
525 /*! \brief The type of the buffer. */
526 DataType dtype;
527 /*! \brief The extents of the buffer. */
528 Array<PrimExpr> extents;
529 /*! \brief Only allocate buffer when condition is satisfied. */
530 PrimExpr condition;
531 /*! \brief The body to be executed. */
532 Stmt body;
533 /*!
534 * \brief Additional annotations about the allocation.
535 *
536 * These annotations can be used as auxiliary hint
537 * to future transformations.
538 */
539 Map<String, ObjectRef> annotations;
540
541 void VisitAttrs(AttrVisitor* v) {
542 v->Visit("buffer_var", &buffer_var);
543 v->Visit("dtype", &dtype);
544 v->Visit("extents", &extents);
545 v->Visit("condition", &condition);
546 v->Visit("body", &body);
547 v->Visit("annotations", &annotations);
548 v->Visit("span", &span);
549 }
550
551 bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
552 return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
553 equal(extents, other->extents) && equal(condition, other->condition) &&
554 equal(body, other->body) && equal(annotations, other->annotations);
555 }
556
557 void SHashReduce(SHashReducer hash_reduce) const {
558 hash_reduce.DefHash(buffer_var);
559 hash_reduce(dtype);
560 hash_reduce(extents);
561 hash_reduce(condition);
562 hash_reduce(body);
563 hash_reduce(annotations);
564 }
565
566 /*!
567 * \brief If the buffer size is constant, return the size.
568 * Otherwise return 0.
569 * \return The result.
570 */
571 int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
572 /*!
573 * \brief If the buffer size is constant, return the size.
574 * Otherwise return 0.
575 * \param extents The extents of the buffer.
576 * \return The result.
577 */
578 TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
579
580 static constexpr const char* _type_key = "tir.Allocate";
581 static constexpr const bool _type_has_method_sequal_reduce = true;
582 static constexpr const bool _type_has_method_shash_reduce = true;
583 TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
584};
585
586/*!
587 * \brief Managed reference to AllocateNode.
588 * \sa AllocateNode
589 */
590class Allocate : public Stmt {
591 public:
592 TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
593 Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
594 Span span = Span());
595
596 TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
597 TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode);
598};
599
600/*!
601 * \brief Allocate a buffer that can be used in body.
602 */
603class AllocateConstNode : public StmtNode {
604 public:
605 /*! \brief The buffer variable. */
606 Var buffer_var;
607 /*! \brief The optional data associated to the constant.
608 */
609 Optional<runtime::NDArray> data;
610 /*!
611 * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index
612 * to indicate the index within "constants" attribute, that is a Array<NDArray> of IRModule.
613 */
614 Optional<Integer> irmod_storage_idx;
615 /*! \brief The type of the buffer. */
616 DataType dtype;
617 /*! \brief The extents of the buffer. */
618 Array<PrimExpr> extents;
619 /*! \brief The body to be executed. */
620 Stmt body;
621 /*!
622 * \brief Additional annotations about the allocation.
623 *
624 * These annotations can be used as auxiliary hint
625 * to future transformations.
626 */
627 Map<String, ObjectRef> annotations;
628
629 void VisitAttrs(AttrVisitor* v) {
630 v->Visit("buffer_var", &buffer_var);
631 v->Visit("data", &data);
632 v->Visit("irmod_storage_idx", &irmod_storage_idx);
633 v->Visit("dtype", &dtype);
634 v->Visit("extents", &extents);
635 v->Visit("body", &body);
636 v->Visit("annotations", &annotations);
637 v->Visit("span", &span);
638 }
639
640 bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
641 return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
642 equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
643 equal(annotations, other->annotations);
644 }
645
646 void SHashReduce(SHashReducer hash_reduce) const {
647 hash_reduce.DefHash(buffer_var);
648 hash_reduce(dtype);
649 hash_reduce(extents);
650 hash_reduce(body);
651 hash_reduce(annotations);
652 hash_reduce(data);
653 }
654
655 /*!
656 * \brief If the buffer size is constant, return the size.
657 * Otherwise return 0.
658 * \return The result.
659 */
660 int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
661 /*!
662 * \brief If the buffer size is constant, return the size.
663 * Otherwise return 0.
664 * \param extents The extents of the buffer.
665 * \return The result.
666 */
667 TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
668
669 static constexpr const char* _type_key = "tir.AllocateConst";
670 static constexpr const bool _type_has_method_sequal_reduce = true;
671 static constexpr const bool _type_has_method_shash_reduce = true;
672 TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
673};
674
675/*!
676 * \brief Managed reference to AllocateConstNode.
677 * \sa AllocateConstNode
678 */
679class AllocateConst : public Stmt {
680 public:
681 /* The constructor to create a IRNode with constant data
682 * depending on the type of ObjectRef, it will either
683 * create AllocateConstNode with irmod_storage_idx or data
684 */
685 TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
686 ObjectRef data_or_idx, Stmt body,
687 Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
688 Span span = Span());
689 TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
690 TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode);
691};
692
693/*! \brief Declare a buffer that can be used in the body */
694class DeclBufferNode : public StmtNode {
695 public:
696 /*! \brief The buffer being declared */
697 Buffer buffer;
698 /*! \brief The body to be executed */
699 Stmt body;
700
701 void VisitAttrs(AttrVisitor* v) {
702 v->Visit("buffer", &buffer);
703 v->Visit("body", &body);
704 v->Visit("span", &span);
705 }
706
707 bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
708 return equal(buffer, other->buffer) && equal(body, other->body);
709 }
710
711 void SHashReduce(SHashReducer hash_reduce) const {
712 hash_reduce(buffer);
713 hash_reduce(body);
714 }
715
716 static constexpr const char* _type_key = "tir.DeclBuffer";
717 TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode);
718};
719
720/*! \brief Managed reference to DeclBufferNode */
721class DeclBuffer : public Stmt {
722 public:
723 TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
724 TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode);
725 TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode);
726};
727
728/*!
729 * \brief The container of seq statement.
730 * Represent a sequence of statements.
731 */
732class SeqStmtNode : public StmtNode {
733 public:
734 /*! \brief internal sequence content. */
735 Array<Stmt> seq;
736
737 /*! \return get the size of the sequence */
738 size_t size() const { return seq.size(); }
739 /*!
740 * \brief Get the index-th element in the sequence.
741 */
742 Stmt operator[](size_t index) const { return seq[index]; }
743
744 void VisitAttrs(AttrVisitor* v) {
745 v->Visit("seq", &seq);
746 v->Visit("span", &span);
747 }
748
749 bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
750 return equal(seq, other->seq);
751 }
752
753 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
754
755 static constexpr const char* _type_key = "tir.SeqStmt";
756 TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
757};
758
759/*! \brief Sequence statement. */
760class SeqStmt : public Stmt {
761 public:
762 /*!
763 * \brief Construct SeqStmt.
764 * \param seq The sequence.
765 * \param span The location of this object in the source code.
766 */
767 TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
768
769 /*! \return get the size of the sequence */
770 size_t size() const { return operator->()->size(); }
771 /*!
772 * \brief Get the index-th element in the sequence.
773 */
774 Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
775 /*!
776 * \brief Construct a sequence statement by flattening
777 * all the arrays and sequences in the arguments
778 * recursively.
779 *
780 * - When an argument is nullptr, it will be ignored.
781 * - When an argument is an array or a SeqStmt, it will be flattened recursively.
782 * - A normal Stmt will be appended to the end of the sequence.
783 *
784 * \note This function can directly return an element
785 * if it is the only element in the sequence.
786 *
787 * \param seq_args The list of arguments to be flattened.
788 * \tparam Args arguments
789 * \return The constructed statement
790 */
791 template <typename... Args>
792 static Stmt Flatten(Args&&... seq_args) {
793 Array<Stmt> seq;
794 runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
795 if (seq.size() == 1) return seq[0];
796 return SeqStmt(seq);
797 }
798 /*! \brief Helper class to flatten sequence of arguments into Array. */
799 class Flattener {
800 public:
801 explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
802
803 void operator()(size_t i, const Stmt& stmt) const {
804 if (!stmt.defined()) return;
805 if (auto* op = stmt.as<SeqStmtNode>()) {
806 operator()(0, op->seq);
807 } else {
808 seq_->push_back(stmt);
809 }
810 }
811
812 template <typename T>
813 void operator()(size_t i, const T& seq) const {
814 for (auto v : seq) {
815 this->operator()(0, v);
816 }
817 }
818
819 private:
820 Array<Stmt>* seq_;
821 };
822
823 TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
824 TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode);
825};
826
827/*!
828 * \brief IfThenElse statment.
829 */
830class IfThenElseNode : public StmtNode {
831 public:
832 /*! \brief The condition. */
833 PrimExpr condition;
834 /*! \brief The branch to be executed when condition is true. */
835 Stmt then_case;
836 /*! \brief The branch to be executed when condition is false, can be null. */
837 Optional<Stmt> else_case;
838
839 void VisitAttrs(AttrVisitor* v) {
840 v->Visit("condition", &condition);
841 v->Visit("then_case", &then_case);
842 v->Visit("else_case", &else_case);
843 v->Visit("span", &span);
844 }
845
846 bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
847 return equal(condition, other->condition) && equal(then_case, other->then_case) &&
848 equal(else_case, other->else_case);
849 }
850
851 void SHashReduce(SHashReducer hash_reduce) const {
852 hash_reduce(condition);
853 hash_reduce(then_case);
854 hash_reduce(else_case);
855 }
856
857 static constexpr const char* _type_key = "tir.IfThenElse";
858 TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
859};
860
861/*!
862 * \brief Managed reference to IfThenElseNode.
863 * \sa IfThenElseNode
864 */
865class IfThenElse : public Stmt {
866 public:
867 TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
868 Span span = Span());
869
870 TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
871 TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
872};
873
874/*!
875 * \brief Evaluates an expression.
876 * This is mostly used for putting a Call node into Stmt.
877 *
878 * If value do not have side-effect, this node can be safely removed.
879 */
880class EvaluateNode : public StmtNode {
881 public:
882 /*! \brief The expression to be evaluated. */
883 PrimExpr value;
884
885 void VisitAttrs(AttrVisitor* v) {
886 v->Visit("value", &value);
887 v->Visit("span", &span);
888 }
889
890 bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
891 return equal(value, other->value);
892 }
893
894 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
895
896 static constexpr const char* _type_key = "tir.Evaluate";
897 TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
898};
899
900/*!
901 * \brief Managed reference to EvaluateNode.
902 * \sa EvaluateNode
903 */
904class Evaluate : public Stmt {
905 public:
906 TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
907
908 explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
909
910 TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
911 TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
912};
913
/*!
 * \brief The kind of the loop.
 *
 *  ForKind can change the control flow semantics
 *  of the loop, so the kind field needs to be considered
 *  in all TIR passes.
 *
 * \note The underlying type is int and every enumerator has an explicit value.
 */
enum class ForKind : int {
  /*! \brief Default semantics -- serial execution. */
  kSerial = 0,
  /*! \brief Parallel execution on CPU. */
  kParallel = 1,
  /*!
   * \brief Vector SIMD loop.
   *  The loop body will be vectorized.
   */
  kVectorized = 2,
  /*! \brief The loop body must be unrolled. */
  kUnrolled = 3,
  /*!
   * \brief The loop variable is bound to a thread in
   *  an environment. In the final stage of lowering,
   *  the loop is simply removed and the loop variable is
   *  mapped to the corresponding context thread.
   */
  kThreadBinding = 4
};
941
942/*!
943 * \brief A for loop, with poissible type annotations.
944 *
945 * \code
946 *
947 * for (loop_var = min; loop_var < min + extent; ++loop_var) {
948 * // body
949 * }
950 * \endcode
951 */
952class ForNode : public StmtNode {
953 public:
954 /*! \brief The loop variable. */
955 Var loop_var;
956 /*! \brief The minimum value of iteration. */
957 PrimExpr min;
958 /*! \brief The extent of the iteration. */
959 PrimExpr extent;
960 /*! \brief The kind of the for loop. */
961 ForKind kind;
962 /*! \brief The body of the for loop. */
963 Stmt body;
964 /*!
965 * \brief Only valid when kind == ForKind::kThreadBinding
966 * The context thread that this loop variable bounds to.
967 */
968 Optional<IterVar> thread_binding;
969 /*!
970 * \brief Additional annotations about the loop.
971 *
972 * These annotations can be used as auxiliary hint
973 * to future transformations. An annotation should
974 * not change the control flow semantics of the loop
975 * and can be ignored in most passes.
976 */
977 Map<String, ObjectRef> annotations;
978
979 void VisitAttrs(AttrVisitor* v) {
980 v->Visit("loop_var", &loop_var);
981 v->Visit("min", &min);
982 v->Visit("extent", &extent);
983 v->Visit("kind", &kind);
984 v->Visit("body", &body);
985 v->Visit("thread_binding", &thread_binding);
986 v->Visit("annotations", &annotations);
987 v->Visit("span", &span);
988 }
989
990 bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
991 return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
992 equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
993 equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
994 }
995
996 void SHashReduce(SHashReducer hash_reduce) const {
997 hash_reduce.DefHash(loop_var);
998 hash_reduce(min);
999 hash_reduce(extent);
1000 hash_reduce(kind);
1001 hash_reduce(body);
1002 hash_reduce(thread_binding);
1003 hash_reduce(annotations);
1004 }
1005
1006 static constexpr const char* _type_key = "tir.For";
1007 TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
1008};
1009
1010/*!
1011 * \brief Managed reference to ForNode.
1012 * \sa ForNode
1013 */
1014class For : public Stmt {
1015 public:
1016 TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
1017 Optional<IterVar> thread_binding = NullOpt,
1018 Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
1019
1020 TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
1021 TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
1022};
1023
1024/*!
1025 * \brief A While loop
1026 *
1027 * \code
1028 *
1029 * while (condition)
1030 * body
1031 *
1032 * \endcode
1033 */
1034class WhileNode : public StmtNode {
1035 public:
1036 /*! \brief The termination condition. */
1037 PrimExpr condition;
1038 /*! \brief The body of the while loop. */
1039 Stmt body;
1040
1041 void VisitAttrs(AttrVisitor* v) {
1042 v->Visit("condition", &condition);
1043 v->Visit("body", &body);
1044 v->Visit("span", &span);
1045 }
1046
1047 bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
1048 return equal(condition, other->condition) && equal(body, other->body);
1049 }
1050
1051 void SHashReduce(SHashReducer hash_reduce) const {
1052 hash_reduce(condition);
1053 hash_reduce(body);
1054 }
1055
1056 static constexpr const char* _type_key = "tir.While";
1057 TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode);
1058};
1059
1060/*!
1061 * \brief Managed reference to WhileNode.
1062 * \sa WhileNode
1063 */
1064class While : public Stmt {
1065 public:
1066 TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1067
1068 TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
1069 TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
1070};
1071
1072/*!
1073 * \brief A prefetch hint for a buffer
1074 */
1075class PrefetchNode : public StmtNode {
1076 public:
1077 /*! \brief The function to be prefetched. */
1078 Buffer buffer;
1079 /*! \brief Bounds to be prefetched. */
1080 Array<Range> bounds;
1081
1082 void VisitAttrs(AttrVisitor* v) {
1083 v->Visit("buffer", &buffer);
1084 v->Visit("bounds", &bounds);
1085 v->Visit("span", &span);
1086 }
1087
1088 bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1089 return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1090 }
1091
1092 void SHashReduce(SHashReducer hash_reduce) const {
1093 hash_reduce(buffer);
1094 hash_reduce(bounds);
1095 }
1096
1097 PrefetchNode() = default;
1098 PrefetchNode(Buffer buffer, Array<Range> bounds, Span span = Span())
1099 : StmtNode(span), buffer(buffer), bounds(bounds) {}
1100
1101 static constexpr const char* _type_key = "tir.Prefetch";
1102 TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
1103};
1104
1105/*!
1106 * \brief Managed reference to PrefetchNode.
1107 * \sa PrefetchNode
1108 */
1109class Prefetch : public Stmt {
1110 public:
1111 TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1112
1113 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
1114 TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
1115};
1116
1117/*!
1118 * \brief Representing the region of multi-dimensional buffer access.
1119 */
1120class BufferRegionNode : public Object {
1121 public:
1122 /*! \brief The buffer of the buffer region. */
1123 Buffer buffer;
1124 /*! \brief The region array of the buffer region. */
1125 Array<Range> region;
1126
1127 void VisitAttrs(AttrVisitor* v) {
1128 v->Visit("buffer", &buffer);
1129 v->Visit("region", &region);
1130 }
1131
1132 bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1133 return equal(buffer, other->buffer) && equal(region, other->region);
1134 }
1135
1136 void SHashReduce(SHashReducer hash_reduce) const {
1137 hash_reduce(buffer);
1138 hash_reduce(region);
1139 }
1140
1141 static constexpr const char* _type_key = "tir.BufferRegion";
1142 static constexpr const bool _type_has_method_sequal_reduce = true;
1143 static constexpr const bool _type_has_method_shash_reduce = true;
1144 TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object);
1145};
1146
1147/*!
1148 * \brief Managed reference to BufferRegionNode.
1149 * \sa BufferRegionNode
1150 */
1151class BufferRegion : public ObjectRef {
1152 public:
1153 TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1154
1155 /*!
1156 * \brief Create a BufferRegion which is full region of the given buffer.
1157 * \param buffer The buffer to generate full BufferRegion.
1158 * \return The BufferRegion which covers all region of the given buffer
1159 */
1160 TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1161
1162 /*!
1163 * \brief Create a BufferRegion which is a single point of the given buffer.
1164 * \param buffer The buffer to generate single point BufferRegion.
1165 * \param indices The access point indices of the buffer
1166 * \return The BufferRegion which is the single point of the given buffer.
1167 */
1168 TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1169
1170 TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
1171 TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
1172};
1173
1174/*!
1175 * \brief Match introduces a constraint that the source buffer region can be remapped to the data
1176 * layout specified by the buffer field. The constraint can be checked in later part of lowering (or
1177 * optionally during runtime).
1178 *
1179 * MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in
1180 * low-level hardware primitives in the IR and defer the check after the sequence of
1181 * transformations.
1182 */
1183class MatchBufferRegionNode : public Object {
1184 public:
1185 /*! \brief The target buffer. */
1186 Buffer buffer;
1187 /*! \brief The source buffer region. */
1188 BufferRegion source;
1189
1190 void VisitAttrs(AttrVisitor* v) {
1191 v->Visit("buffer", &buffer);
1192 v->Visit("source", &source);
1193 }
1194
1195 bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const {
1196 return equal(buffer, other->buffer) && equal(source, other->source);
1197 }
1198
1199 void SHashReduce(SHashReducer hash_reduce) const {
1200 hash_reduce(buffer);
1201 hash_reduce(source);
1202 }
1203
1204 static constexpr const char* _type_key = "tir.MatchBufferRegion";
1205 static constexpr const bool _type_has_method_sequal_reduce = true;
1206 static constexpr const bool _type_has_method_shash_reduce = true;
1207 TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
1208};
1209
1210/*!
1211 * \brief Managed reference to MatchBufferRegionNode.
1212 * \sa MatchBufferRegionNode
1213 */
1214class MatchBufferRegion : public ObjectRef {
1215 public:
1216 TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1217
1218 TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
1219 TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode);
1220};
1221
1222/*!
1223 * \brief A block is a basic schedule unit in TIR.
1224 * \note Block's body is parameterized by iter vars.
1225 * \code
1226 *
1227 * with T.block(name):
1228 * v0 = T.axis.S(domain, value0)
1229 * v1 = T.axis.R(domain, value1)
1230 * ...
1231 * T.reads([buffer0[start:end, ...], ...])
1232 * T.writes([buffer1[start:end, ...], ...])
1233 * T.where(predicate)
1234 * buffer2 = T.alloc_buffer(shape, dtype)
1235 * buffer3 = T.match_buffer(source_buffer[start:end, ...])
1236 * T.attr({attr_key: attr_value, ...})
1237 * with T.init():
1238 * // init body
1239 * // body
1240 *
1241 * \endcode
1242 */
1243class BlockNode : public StmtNode {
1244 public:
1245 /*! \brief The variables of the block. */
1246 Array<IterVar> iter_vars;
1247 /*! \brief The read buffer regions of the block. */
1248 Array<BufferRegion> reads;
1249 /*! \brief The write buffer regions of the block. */
1250 Array<BufferRegion> writes;
1251 /*! \brief The name_hint of the block. */
1252 String name_hint;
1253 /*! \brief The body of the block. */
1254 Stmt body;
1255 /*!
1256 * \brief The init statement is executed during the first iteration of reduction loops in a
1257 * reduction block. The optional init field allows us to represent initialization and
1258 * reduction update in a single block and transform them collectively.
1259 * We also provide primitives to decompose the init into a separate block during scheduling.
1260 * Init field is `NullOpt` if there is no reduction iter_vars
1261 */
1262 Optional<Stmt> init;
1263 /*! \brief The buffer allocated in the block. */
1264 Array<Buffer> alloc_buffers;
1265 /*! \brief The match buffer regions. */
1266 Array<MatchBufferRegion> match_buffers;
1267 /*! \brief The annotation of the block. */
1268 Map<String, ObjectRef> annotations;
1269
1270 void VisitAttrs(AttrVisitor* v) {
1271 v->Visit("iter_vars", &iter_vars);
1272 v->Visit("reads", &reads);
1273 v->Visit("writes", &writes);
1274 v->Visit("name_hint", &name_hint);
1275 v->Visit("body", &body);
1276 v->Visit("init", &init);
1277 v->Visit("alloc_buffers", &alloc_buffers);
1278 v->Visit("match_buffers", &match_buffers);
1279 v->Visit("annotations", &annotations);
1280 }
1281
1282 bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1283 // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1284 return equal.DefEqual(iter_vars, other->iter_vars) &&
1285 equal(alloc_buffers, other->alloc_buffers) &&
1286 equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1287 equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1288 equal(annotations, other->annotations);
1289 }
1290
1291 void SHashReduce(SHashReducer hash_reduce) const {
1292 hash_reduce.DefHash(iter_vars);
1293 hash_reduce(alloc_buffers);
1294 hash_reduce(match_buffers);
1295 hash_reduce(reads);
1296 hash_reduce(writes);
1297 hash_reduce(body);
1298 hash_reduce(init);
1299 hash_reduce(annotations);
1300 }
1301
1302 static constexpr const char* _type_key = "tir.Block";
1303 TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode);
1304};
1305
1306/*!
1307 * \brief Managed reference to BlockNode.
1308 * \sa BlockNode
1309 */
1310class Block : public Stmt {
1311 public:
1312 TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1313 Array<BufferRegion> writes, String name_hint, Stmt body,
1314 Optional<Stmt> init = NullOpt,
1315 Array<Buffer> alloc_buffers = Array<Buffer>(),
1316 Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
1317 Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
1318 Span span = Span());
1319
1320 TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
1321 TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
1322};
1323
1324/*!
1325 * \brief A block realization node represents execution of the block at the binding values.
1326 */
1327class BlockRealizeNode : public StmtNode {
1328 public:
1329 /*! \brief The corresponding values of the iter vars. */
1330 Array<PrimExpr> iter_values;
1331 /*!
1332 * \brief The predicate of the block realization, the block will only be executed when the
1333 * predicate is true.
1334 */
1335 PrimExpr predicate;
1336 /*! \brief The block to be realized. */
1337 Block block;
1338
1339 void VisitAttrs(AttrVisitor* v) {
1340 v->Visit("iter_values", &iter_values);
1341 v->Visit("predicate", &predicate);
1342 v->Visit("block", &block);
1343 }
1344
1345 bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1346 return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1347 equal(block, other->block);
1348 }
1349
1350 void SHashReduce(SHashReducer hash_reduce) const {
1351 hash_reduce(iter_values);
1352 hash_reduce(predicate);
1353 hash_reduce(block);
1354 }
1355
1356 static constexpr const char* _type_key = "tir.BlockRealize";
1357 TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode);
1358};
1359
1360/*!
1361 * \brief Managed reference to BlockRealizeNode
1362 * \sa BlockRealizeNode
1363 */
1364class BlockRealize : public Stmt {
1365 public:
1366 TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1367 Span span = Span());
1368
1369 TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode);
1370 TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
1371};
1372
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
// NOTE(review): no attr precedes the remark above inside this namespace; it looks stale.
/*! \brief Mark launching extent of thread, used by device API. */
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark region is processed by a co-processor */
constexpr const char* coproc_scope = "coproc_scope";
/*!
 * \brief Mark region creates coprocessor micro ops,
 *  can be reused if corresponding variable is independent.
 */
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
 * \brief Mark the scope as generated by extern primitive.
 *  Such a scope can contain an arbitrary ir program and we need to be careful
 *  when making certain assumptions about the structure of the program.
 */
constexpr const char* extern_scope = "extern_scope";
/*!
 * \brief Mark the scope where computation starts to happen.
 *  This can hint some code generator to create a new function for compute.
 */
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation device for global malloc in host. */
constexpr const char* device_id = "device_id";
/*! \brief The device type. */
constexpr const char* device_type = "device_type";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Pragma: auto-unroll, max_step */
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
/*! \brief Pragma: unroll explicit */
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*! \brief Mark region is guarded by the pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import C source or file into the final code gen module */
constexpr const char* pragma_import_c = "pragma_import_c";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
/*!
 * \brief Mark of prefetch scope, value=offset,
 *  run prefetch of Tensor on the current loop scope
 */
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
 * \brief Marks the layout transforms to be used for a tensor.
 *
 * Only applies to a DataProducer, as it should be made part of the
 * PrimFunc attributes for TIR.
 */
constexpr const char* layout_transforms = "layout_transforms";
/*!
 * \brief Marks the physical axis separators
 *
 * Only applies to a DataProducer, as it should be made part of the
 * Buffer definition in a PrimFunc.  See `BufferNode::axis_separators`
 * for more details.
 */
constexpr const char* axis_separators = "axis_separators";
/*!
 * \brief Marks production of double buffer data
 */
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
 * \brief Marks region used by double buffer write
 */
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark realization for rolling buffer optimization */
constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
 * \brief Mark alignment of buffer dimension
 *  stmt.node is Tensor
 *  stmt.value is tvm_tuple(dim, align, offset)
 *  This gives hint to require stride of dim to be k * align + offset.
 */
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds. */
constexpr const char* buffer_bound = "buffer_bound";
/*!
 * \brief Bind the buffer specification to the region of the op
 *  When this scope occurs, the stmt.node is an Array<NodeRef> = [buffer, tensor]
 *  stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
 *  The scope represents that we need to bind the storage region of tensor to buffer.
 *  This will affect replacement of some variables inside the scope that
 *  corresponds to field of buffer to be the actual expressions of tensor during
 *  storage flattening phase.
 */
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_read_advance = "channel_read_advance";
/*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_write_advance = "channel_write_advance";
/*! \brief pipeline stage scope, implies always execution */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";

/*!
 * \brief Mark that it is in the device scope.
 */
constexpr const char* device_scope = "device_scope";

/*!
 * \brief Mark that the attached statement runs asynchronously.
 */
constexpr const char* async_scope = "async_scope";

/*!
 * \brief Annotations for invoking and synchronizing asynchronous operations.
 *
 * Synchronization is done in terms of "queue": It is an abstract entity associated
 * with each asynchronous unit, and it tracks invocations and completions of asynchronous
 * operations in the FIFO order.
 *
 * Similarly to PTX instructions commit_group and wait_group, these annotations express
 * synchronization by "counting":
 *
 * async_commit_queue(i): Group one or more invocations of async operations in the given scope,
 * and "commit" (or push) them to the queue i. A group of operations committed together is
 * awaited as one chunk. Groups committed to the same queue complete in the FIFO order.
 *
 * async_wait_queue(i, N): Block until only N most recent committed groups are still in-flight at
 * the queue i. N does not have to be a constant, but some backends may require a constant count.
 */
constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

/*!
 * \brief Mark the shape of a TensorCore fragment
 */
constexpr const char* fragment_shape = "fragment_shape";

/*!
 * \brief Mark the layout of a TensorCore fragment
 */
constexpr const char* fragment_layout = "fragment_layout";

/*!
 * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
 */
constexpr const char* hand_threaded = "hand_threaded";

/*!
 * \brief Mark whether the script-completer needs to fill in missing access regions
 *        during script parsing.
 * \note The result should be an integer mask with range [0, 4).
 *       if (mask & 1) the read region should be detected,
 *       if (mask & 2) the write region should be detected.
 */
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";

/*!
 * \brief Mark that the loop should be partitioned.
 */
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the stage of a statement in the software pipeline */
constexpr const char* software_pipeline_stage = "software_pipeline_stage";

/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief List stages in the software pipeline that should run asynchronously
 * \note All statements in the provided stages are assumed to have asynchronous
 *       semantics (e.g. CUDA async global to shared memory copy).
 */
constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";

/*! \brief Mark the buffers that are accessed as constants and whose layout can be transformed. */
constexpr const char* layout_free_buffers = "layout_free_buffers";

/*! \brief Mark that the local stage for the shared memory access should be added. */
constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*!
 * \brief Mark that the loop should be further split and bound to environment threads to enable
 * cooperative fetching.
 */
constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

/*! \brief The allowed range of thread extent in thread bindings (low end of the range) */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
    "meta_schedule.thread_extent_low_inclusive";

/*! \brief The allowed range of thread extent in thread bindings (high end of the range) */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
    "meta_schedule.thread_extent_high_inclusive";

/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

/*! \brief Mark auto-vectorize setting on the block. */
constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*! \brief Mark that a block should be further rewritten using tensorization. */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
/*!
 * \brief Mark that the init statement of a block should be further rewritten using tensorization.
 */
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
 * \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
 * warp size.
 */
constexpr const char* warp_execution = "warp_execution";

/*! \brief Mark that a block is disallowed in auto inline. */
constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";

/*!
 * \brief Check if attr_key is a pragma key extension
 * \param attr_key The attr key to be compared
 * \return true if attr_key starts with pragma_scope_prefix ("pragma_")
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind with pos == 0 can only match at index 0, so this is a pure prefix test that
  // follows pragma_scope_prefix instead of repeating the literal and its length.
  return attr_key.rfind(pragma_scope_prefix, 0) == 0;
}

}  // namespace attr
1631/*!
1632 * \brief Create a type annotation expression
1633 * \param dtype The data type
1634 * \param span The location of this object in the source code.
1635 * \return Expr a expression with dtype.
1636 */
1637TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1638
1639// overload printing of for type.
1640TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1641
1642// inline implementations
1643inline const char* ForKind2String(ForKind t) {
1644 switch (t) {
1645 case ForKind::kSerial:
1646 return "serial";
1647 case ForKind::kParallel:
1648 return "parallel";
1649 case ForKind::kVectorized:
1650 return "vectorized";
1651 case ForKind::kUnrolled:
1652 return "unroll";
1653 case ForKind::kThreadBinding:
1654 return "thread_binding";
1655 }
1656 LOG(FATAL) << "Unknown ForKind" << t;
1657}
1658
1659} // namespace tir
1660} // namespace tvm
1661#endif // TVM_TIR_STMT_H_
1662