1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/tir/stmt.h |
21 | * \brief TIR statements. |
22 | */ |
23 | // Acknowledgement: Many low-level stmts originate from Halide. |
24 | #ifndef TVM_TIR_STMT_H_ |
25 | #define TVM_TIR_STMT_H_ |
26 | |
27 | #include <tvm/tir/expr.h> |
28 | |
29 | #include <string> |
30 | #include <type_traits> |
31 | #include <utility> |
32 | #include <vector> |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | /*! \brief Base node of all statements. */ |
38 | class StmtNode : public Object { |
39 | public: |
40 | /*! |
41 | * \brief Span that points to the original source code. |
42 | * Reserved debug information. |
43 | */ |
44 | mutable Span span; |
45 | |
46 | StmtNode() = default; |
47 | explicit StmtNode(Span span) : span(span) {} |
48 | |
49 | TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); |
50 | |
51 | static constexpr const char* _type_key = "tir.Stmt" ; |
52 | static constexpr const bool _type_has_method_sequal_reduce = true; |
53 | static constexpr const bool _type_has_method_shash_reduce = true; |
54 | static constexpr const uint32_t _type_child_slots = 15; |
55 | TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); |
56 | }; |
57 | |
58 | /*! \brief Container of all statements */ |
59 | class Stmt : public ObjectRef { |
60 | public: |
61 | TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode); |
62 | }; |
63 | |
64 | /*! |
65 | * \brief Let binding, bind var to value, then run body. |
66 | */ |
67 | class LetStmtNode : public StmtNode { |
68 | public: |
69 | /*! \brief The variable. */ |
70 | Var var; |
71 | /*! \brief The value to be binded. */ |
72 | PrimExpr value; |
73 | /*! \brief The body block. */ |
74 | Stmt body; |
75 | |
76 | void VisitAttrs(AttrVisitor* v) { |
77 | v->Visit("var" , &var); |
78 | v->Visit("value" , &value); |
79 | v->Visit("body" , &body); |
80 | v->Visit("span" , &span); |
81 | } |
82 | |
83 | bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { |
84 | return equal.DefEqual(var, other->var) && equal(value, other->value) && |
85 | equal(body, other->body); |
86 | } |
87 | |
88 | void SHashReduce(SHashReducer hash_reduce) const { |
89 | hash_reduce.DefHash(var); |
90 | hash_reduce(value); |
91 | hash_reduce(body); |
92 | } |
93 | |
94 | static constexpr const char* _type_key = "tir.LetStmt" ; |
95 | TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); |
96 | }; |
97 | |
98 | /*! |
99 | * \brief Managed reference to LetStmtNode. |
100 | * \sa LetStmtNode |
101 | */ |
102 | class LetStmt : public Stmt { |
103 | public: |
104 | TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); |
105 | |
106 | TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); |
107 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode); |
108 | }; |
109 | |
110 | /*! |
111 | * \brief Define certain auxiliary attribute for the body to be a symbolic value. |
112 | * This provide auxiliary information for IR passes that transforms body. |
113 | * |
114 | * In terms of effect, this is equivalent to Block(Evaluate(value), body). |
115 | * |
116 | * Examples of possible usage: |
117 | * - Bound of function, variables. |
118 | * - Hint which block corresponds to a parallel region. |
119 | */ |
120 | class AttrStmtNode : public StmtNode { |
121 | public: |
122 | /*! \brief this is attribute about certain node */ |
123 | ObjectRef node; |
124 | /*! \brief the type key of the attribute */ |
125 | String attr_key; |
126 | /*! \brief The attribute value, value is well defined at current scope. */ |
127 | PrimExpr value; |
128 | /*! \brief The body statement to be executed */ |
129 | Stmt body; |
130 | |
131 | void VisitAttrs(AttrVisitor* v) { |
132 | v->Visit("node" , &node); |
133 | v->Visit("attr_key" , &attr_key); |
134 | v->Visit("value" , &value); |
135 | v->Visit("body" , &body); |
136 | v->Visit("span" , &span); |
137 | } |
138 | |
139 | bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { |
140 | return equal(node, other->node) && equal(attr_key, other->attr_key) && |
141 | equal(value, other->value) && equal(body, other->body); |
142 | } |
143 | |
144 | void SHashReduce(SHashReducer hash_reduce) const { |
145 | hash_reduce(node); |
146 | hash_reduce(attr_key); |
147 | hash_reduce(value); |
148 | hash_reduce(body); |
149 | } |
150 | |
151 | static constexpr const char* _type_key = "tir.AttrStmt" ; |
152 | TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); |
153 | }; |
154 | |
155 | /*! |
156 | * \brief Managed reference to AttrStmtNode. |
157 | * \sa AttrStmtNode |
158 | */ |
159 | class AttrStmt : public Stmt { |
160 | public: |
161 | TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); |
162 | |
163 | TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); |
164 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); |
165 | }; |
166 | |
167 | /*! |
168 | * \brief Assert condition, if an error occurs, return the error message. |
169 | */ |
170 | class AssertStmtNode : public StmtNode { |
171 | public: |
172 | /*! \brief Condition to be checked. */ |
173 | PrimExpr condition; |
174 | /*! \brief Error message when assertion failed. */ |
175 | PrimExpr message; |
176 | /*! |
177 | * \brief Body which this assertion holds true. |
178 | * Will be executed after the assertion. |
179 | */ |
180 | Stmt body; |
181 | |
182 | void VisitAttrs(AttrVisitor* v) { |
183 | v->Visit("condition" , &condition); |
184 | v->Visit("message" , &message); |
185 | v->Visit("body" , &body); |
186 | v->Visit("span" , &span); |
187 | } |
188 | |
189 | bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { |
190 | return equal(condition, other->condition) && equal(message, other->message) && |
191 | equal(body, other->body); |
192 | } |
193 | |
194 | void SHashReduce(SHashReducer hash_reduce) const { |
195 | hash_reduce(condition); |
196 | hash_reduce(message); |
197 | hash_reduce(body); |
198 | } |
199 | |
200 | static constexpr const char* _type_key = "tir.AssertStmt" ; |
201 | TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); |
202 | }; |
203 | |
204 | /*! |
205 | * \brief Managed reference to AssertStmtNode. |
206 | * \sa AssertStmtNode |
207 | */ |
208 | class AssertStmt : public Stmt { |
209 | public: |
210 | TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span()); |
211 | |
212 | TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); |
213 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode); |
214 | }; |
215 | |
216 | /*! |
217 | * \brief Store value to the buffer. |
218 | * |
219 | * Equivalent to ((DType*)buffer_var)[index] = value. |
220 | * where DType is the type specified by type().element_of(). |
221 | * |
222 | * For example, if type = float32x3, then the store will corresponds to |
223 | * |
224 | * \code |
225 | * |
226 | * auto buffer = static_cast<float*>(buffer_var); |
227 | * buffer[index.v0] = value.v0; |
228 | * buffer[index.v1] = value.v1; |
229 | * buffer[index.v2] = value.v2; |
230 | * |
231 | * \endcode |
232 | * \sa LoadNode |
233 | */ |
234 | class StoreNode : public StmtNode { |
235 | public: |
236 | /*! \brief The buffer variable. */ |
237 | Var buffer_var; |
238 | /*! \brief The value to be stored. */ |
239 | PrimExpr value; |
240 | /*! \brief The index locations to be stored. */ |
241 | PrimExpr index; |
242 | /*! \brief The predicate to mask which lanes would be stored. */ |
243 | PrimExpr predicate; |
244 | |
245 | void VisitAttrs(AttrVisitor* v) { |
246 | v->Visit("buffer_var" , &buffer_var); |
247 | v->Visit("value" , &value); |
248 | v->Visit("index" , &index); |
249 | v->Visit("predicate" , &predicate); |
250 | v->Visit("span" , &span); |
251 | } |
252 | |
253 | bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const { |
254 | return equal(buffer_var, other->buffer_var) && equal(value, other->value) && |
255 | equal(index, other->index) && equal(predicate, other->predicate); |
256 | } |
257 | |
258 | void SHashReduce(SHashReducer hash_reduce) const { |
259 | hash_reduce(buffer_var); |
260 | hash_reduce(value); |
261 | hash_reduce(index); |
262 | hash_reduce(predicate); |
263 | } |
264 | |
265 | static constexpr const char* _type_key = "tir.Store" ; |
266 | TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); |
267 | }; |
268 | |
269 | /*! |
270 | * \brief Managed reference to StoreNode. |
271 | * \sa StoreNode |
272 | */ |
273 | class Store : public Stmt { |
274 | public: |
275 | TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, |
276 | Span span = Span()); |
277 | |
278 | TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode); |
279 | TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode); |
280 | }; |
281 | |
282 | /*! |
283 | * \brief Store value to the high dimension buffer. |
284 | * |
285 | * \code |
286 | * |
287 | * buffer[i, j] = value; |
288 | * |
289 | * \endcode |
290 | * \sa BufferLoad |
291 | */ |
292 | class BufferStoreNode : public StmtNode { |
293 | public: |
294 | /*! \brief The buffer variable. */ |
295 | Buffer buffer; |
296 | /*! \brief The value to be stored. */ |
297 | PrimExpr value; |
298 | /*! \brief The indices location to be stored. */ |
299 | Array<PrimExpr> indices; |
300 | |
301 | void VisitAttrs(AttrVisitor* v) { |
302 | v->Visit("buffer" , &buffer); |
303 | v->Visit("value" , &value); |
304 | v->Visit("indices" , &indices); |
305 | v->Visit("span" , &span); |
306 | } |
307 | |
308 | bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { |
309 | return equal(buffer, other->buffer) && equal(value, other->value) && |
310 | equal(indices, other->indices); |
311 | } |
312 | |
313 | void SHashReduce(SHashReducer hash_reduce) const { |
314 | hash_reduce(buffer); |
315 | hash_reduce(value); |
316 | hash_reduce(indices); |
317 | } |
318 | |
319 | static constexpr const char* _type_key = "tir.BufferStore" ; |
320 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); |
321 | }; |
322 | |
323 | /*! |
324 | * \brief Managed reference to BufferStoreNode. |
325 | * \sa BufferStoreNode |
326 | */ |
327 | class BufferStore : public Stmt { |
328 | public: |
329 | TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, |
330 | Span span = Span()); |
331 | |
332 | TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); |
333 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); |
334 | }; |
335 | |
336 | /*! |
337 | * \brief Annotate the region where the buffer need to |
338 | * be read and write in the body. |
339 | * We only need to allocate the space for the corresponding region. |
340 | * |
341 | * \note There should be at most one BufferRealize for each buffer. |
342 | * BufferRealize is not necessary for external buffers, |
343 | * since they are assumed to be fully allocated. |
344 | * |
345 | * \sa BufferLoad, BufferStore |
346 | */ |
347 | class BufferRealizeNode : public StmtNode { |
348 | public: |
349 | /*! \brief The buffer variable. */ |
350 | Buffer buffer; |
351 | /*! \brief Bounds to be realized */ |
352 | Array<Range> bounds; |
353 | /*! \brief Only realize if condition holds. */ |
354 | PrimExpr condition; |
355 | /*! \brief The body of realization. */ |
356 | Stmt body; |
357 | |
358 | void VisitAttrs(AttrVisitor* v) { |
359 | v->Visit("buffer" , &buffer); |
360 | v->Visit("bounds" , &bounds); |
361 | v->Visit("condition" , &condition); |
362 | v->Visit("body" , &body); |
363 | v->Visit("span" , &span); |
364 | } |
365 | |
366 | bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { |
367 | return equal(buffer, other->buffer) && equal(bounds, other->bounds) && |
368 | equal(condition, other->condition) && equal(body, other->body); |
369 | } |
370 | |
371 | void SHashReduce(SHashReducer hash_reduce) const { |
372 | hash_reduce(buffer); |
373 | hash_reduce(bounds); |
374 | hash_reduce(condition); |
375 | hash_reduce(body); |
376 | } |
377 | |
378 | BufferRealizeNode() = default; |
379 | BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body, |
380 | Span span = Span()) |
381 | : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {} |
382 | |
383 | static constexpr const char* _type_key = "tir.BufferRealize" ; |
384 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); |
385 | }; |
386 | |
387 | /*! |
388 | * \brief Managed reference to BufferRealizeNode. |
389 | * \sa BufferRealizeNode |
390 | */ |
391 | class BufferRealize : public Stmt { |
392 | public: |
393 | TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body, |
394 | Span span = Span()); |
395 | |
396 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); |
397 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); |
398 | }; |
399 | |
400 | /*! |
401 | * \brief Store value into mult-dimensional array that will be read by the consumer |
402 | * of the producer. |
403 | * |
404 | * \note This node only appears in high-level DSLs that are built on top of the TIR. |
405 | * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower |
406 | * this node before TIR transformations. |
407 | * |
408 | * \sa DataProducer |
409 | */ |
410 | class ProducerStoreNode : public StmtNode { |
411 | public: |
412 | /*! \brief The producer to store the results into. */ |
413 | DataProducer producer; |
414 | /*! \brief The value to be stored. */ |
415 | PrimExpr value; |
416 | /*! \brief The index arguments of the function. */ |
417 | Array<PrimExpr> indices; |
418 | |
419 | void VisitAttrs(AttrVisitor* v) { |
420 | v->Visit("producer" , &producer); |
421 | v->Visit("value" , &value); |
422 | v->Visit("indices" , &indices); |
423 | v->Visit("span" , &span); |
424 | } |
425 | |
426 | bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const { |
427 | return equal(producer, other->producer) && equal(value, other->value) && |
428 | equal(indices, other->indices); |
429 | } |
430 | |
431 | void SHashReduce(SHashReducer hash_reduce) const { |
432 | hash_reduce(producer); |
433 | hash_reduce(value); |
434 | hash_reduce(indices); |
435 | } |
436 | |
437 | static constexpr const char* _type_key = "tir.ProducerStore" ; |
438 | TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); |
439 | }; |
440 | |
441 | /*! |
442 | * \brief Managed reference to ProducerStoreNode. |
443 | * \sa ProducerStoreNode |
444 | */ |
445 | class ProducerStore : public Stmt { |
446 | public: |
447 | TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices, |
448 | Span span = Span()); |
449 | |
450 | TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); |
451 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode); |
452 | }; |
453 | |
454 | /*! |
455 | * \brief Annotate the bounds where the data produced by the producer |
456 | * need to be written and read in body. |
457 | * We will need to allocate space for the corresponding regions. |
458 | * |
459 | * \note This node only appears in high-level DSLs that are built on top of the TIR. |
460 | * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower |
461 | * this node before TIR transformations. |
462 | * |
463 | * \sa DataProducer |
464 | */ |
465 | class ProducerRealizeNode : public StmtNode { |
466 | public: |
467 | /*! \brief The producer that produces the data. */ |
468 | DataProducer producer; |
469 | /*! \brief Bounds to be realized. */ |
470 | Region bounds; |
471 | /*! \brief Only realize if condition holds. */ |
472 | PrimExpr condition; |
473 | /*! \brief The body of realization. */ |
474 | Stmt body; |
475 | /*! \brief The storage scope associated with this realization. */ |
476 | String storage_scope; |
477 | |
478 | void VisitAttrs(AttrVisitor* v) { |
479 | v->Visit("producer" , &producer); |
480 | v->Visit("bounds" , &bounds); |
481 | v->Visit("condition" , &condition); |
482 | v->Visit("body" , &body); |
483 | v->Visit("storage_scope" , &storage_scope); |
484 | v->Visit("span" , &span); |
485 | } |
486 | |
487 | bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { |
488 | return equal(producer, other->producer) && equal(bounds, other->bounds) && |
489 | equal(condition, other->condition) && equal(body, other->body) && |
490 | equal(storage_scope, other->storage_scope); |
491 | } |
492 | |
493 | void SHashReduce(SHashReducer hash_reduce) const { |
494 | hash_reduce(producer); |
495 | hash_reduce(bounds); |
496 | hash_reduce(condition); |
497 | hash_reduce(body); |
498 | hash_reduce(storage_scope); |
499 | } |
500 | |
501 | static constexpr const char* _type_key = "tir.ProducerRealize" ; |
502 | TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); |
503 | }; |
504 | |
505 | /*! |
506 | * \brief Managed reference to ProducerRealizeNode. |
507 | * \sa ProducerRealizeNode |
508 | */ |
509 | class ProducerRealize : public Stmt { |
510 | public: |
511 | TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, |
512 | String storage_scope = "" , Span span = Span()); |
513 | |
514 | TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); |
515 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode); |
516 | }; |
517 | |
518 | /*! |
519 | * \brief Allocate a buffer that can be used in body. |
520 | */ |
521 | class AllocateNode : public StmtNode { |
522 | public: |
523 | /*! \brief The buffer variable. */ |
524 | Var buffer_var; |
525 | /*! \brief The type of the buffer. */ |
526 | DataType dtype; |
527 | /*! \brief The extents of the buffer. */ |
528 | Array<PrimExpr> extents; |
529 | /*! \brief Only allocate buffer when condition is satisfied. */ |
530 | PrimExpr condition; |
531 | /*! \brief The body to be executed. */ |
532 | Stmt body; |
533 | /*! |
534 | * \brief Additional annotations about the allocation. |
535 | * |
536 | * These annotations can be used as auxiliary hint |
537 | * to future transformations. |
538 | */ |
539 | Map<String, ObjectRef> annotations; |
540 | |
541 | void VisitAttrs(AttrVisitor* v) { |
542 | v->Visit("buffer_var" , &buffer_var); |
543 | v->Visit("dtype" , &dtype); |
544 | v->Visit("extents" , &extents); |
545 | v->Visit("condition" , &condition); |
546 | v->Visit("body" , &body); |
547 | v->Visit("annotations" , &annotations); |
548 | v->Visit("span" , &span); |
549 | } |
550 | |
551 | bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { |
552 | return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && |
553 | equal(extents, other->extents) && equal(condition, other->condition) && |
554 | equal(body, other->body) && equal(annotations, other->annotations); |
555 | } |
556 | |
557 | void SHashReduce(SHashReducer hash_reduce) const { |
558 | hash_reduce.DefHash(buffer_var); |
559 | hash_reduce(dtype); |
560 | hash_reduce(extents); |
561 | hash_reduce(condition); |
562 | hash_reduce(body); |
563 | hash_reduce(annotations); |
564 | } |
565 | |
566 | /*! |
567 | * \brief If the buffer size is constant, return the size. |
568 | * Otherwise return 0. |
569 | * \return The result. |
570 | */ |
571 | int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); } |
572 | /*! |
573 | * \brief If the buffer size is constant, return the size. |
574 | * Otherwise return 0. |
575 | * \param extents The extents of the buffer. |
576 | * \return The result. |
577 | */ |
578 | TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents); |
579 | |
580 | static constexpr const char* _type_key = "tir.Allocate" ; |
581 | static constexpr const bool _type_has_method_sequal_reduce = true; |
582 | static constexpr const bool _type_has_method_shash_reduce = true; |
583 | TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); |
584 | }; |
585 | |
586 | /*! |
587 | * \brief Managed reference to AllocateNode. |
588 | * \sa AllocateNode |
589 | */ |
590 | class Allocate : public Stmt { |
591 | public: |
592 | TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition, |
593 | Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), |
594 | Span span = Span()); |
595 | |
596 | TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); |
597 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode); |
598 | }; |
599 | |
600 | /*! |
601 | * \brief Allocate a buffer that can be used in body. |
602 | */ |
603 | class AllocateConstNode : public StmtNode { |
604 | public: |
605 | /*! \brief The buffer variable. */ |
606 | Var buffer_var; |
607 | /*! \brief The optional data associated to the constant. |
608 | */ |
609 | Optional<runtime::NDArray> data; |
610 | /*! |
611 | * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index |
612 | * to indicate the index within "constants" attribute, that is a Array<NDArray> of IRModule. |
613 | */ |
614 | Optional<Integer> irmod_storage_idx; |
615 | /*! \brief The type of the buffer. */ |
616 | DataType dtype; |
617 | /*! \brief The extents of the buffer. */ |
618 | Array<PrimExpr> extents; |
619 | /*! \brief The body to be executed. */ |
620 | Stmt body; |
621 | /*! |
622 | * \brief Additional annotations about the allocation. |
623 | * |
624 | * These annotations can be used as auxiliary hint |
625 | * to future transformations. |
626 | */ |
627 | Map<String, ObjectRef> annotations; |
628 | |
629 | void VisitAttrs(AttrVisitor* v) { |
630 | v->Visit("buffer_var" , &buffer_var); |
631 | v->Visit("data" , &data); |
632 | v->Visit("irmod_storage_idx" , &irmod_storage_idx); |
633 | v->Visit("dtype" , &dtype); |
634 | v->Visit("extents" , &extents); |
635 | v->Visit("body" , &body); |
636 | v->Visit("annotations" , &annotations); |
637 | v->Visit("span" , &span); |
638 | } |
639 | |
640 | bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const { |
641 | return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && |
642 | equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) && |
643 | equal(annotations, other->annotations); |
644 | } |
645 | |
646 | void SHashReduce(SHashReducer hash_reduce) const { |
647 | hash_reduce.DefHash(buffer_var); |
648 | hash_reduce(dtype); |
649 | hash_reduce(extents); |
650 | hash_reduce(body); |
651 | hash_reduce(annotations); |
652 | hash_reduce(data); |
653 | } |
654 | |
655 | /*! |
656 | * \brief If the buffer size is constant, return the size. |
657 | * Otherwise return 0. |
658 | * \return The result. |
659 | */ |
660 | int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); } |
661 | /*! |
662 | * \brief If the buffer size is constant, return the size. |
663 | * Otherwise return 0. |
664 | * \param extents The extents of the buffer. |
665 | * \return The result. |
666 | */ |
667 | TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents); |
668 | |
669 | static constexpr const char* _type_key = "tir.AllocateConst" ; |
670 | static constexpr const bool _type_has_method_sequal_reduce = true; |
671 | static constexpr const bool _type_has_method_shash_reduce = true; |
672 | TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); |
673 | }; |
674 | |
675 | /*! |
676 | * \brief Managed reference to AllocateConstNode. |
677 | * \sa AllocateConstNode |
678 | */ |
679 | class AllocateConst : public Stmt { |
680 | public: |
681 | /* The constructor to create a IRNode with constant data |
682 | * depending on the type of ObjectRef, it will either |
683 | * create AllocateConstNode with irmod_storage_idx or data |
684 | */ |
685 | TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents, |
686 | ObjectRef data_or_idx, Stmt body, |
687 | Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), |
688 | Span span = Span()); |
689 | TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); |
690 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode); |
691 | }; |
692 | |
693 | /*! \brief Declare a buffer that can be used in the body */ |
694 | class DeclBufferNode : public StmtNode { |
695 | public: |
696 | /*! \brief The buffer being declared */ |
697 | Buffer buffer; |
698 | /*! \brief The body to be executed */ |
699 | Stmt body; |
700 | |
701 | void VisitAttrs(AttrVisitor* v) { |
702 | v->Visit("buffer" , &buffer); |
703 | v->Visit("body" , &body); |
704 | v->Visit("span" , &span); |
705 | } |
706 | |
707 | bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const { |
708 | return equal(buffer, other->buffer) && equal(body, other->body); |
709 | } |
710 | |
711 | void SHashReduce(SHashReducer hash_reduce) const { |
712 | hash_reduce(buffer); |
713 | hash_reduce(body); |
714 | } |
715 | |
716 | static constexpr const char* _type_key = "tir.DeclBuffer" ; |
717 | TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode); |
718 | }; |
719 | |
720 | /*! \brief Managed reference to DeclBufferNode */ |
721 | class DeclBuffer : public Stmt { |
722 | public: |
723 | TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span()); |
724 | TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode); |
725 | TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode); |
726 | }; |
727 | |
728 | /*! |
729 | * \brief The container of seq statement. |
730 | * Represent a sequence of statements. |
731 | */ |
732 | class SeqStmtNode : public StmtNode { |
733 | public: |
734 | /*! \brief internal sequence content. */ |
735 | Array<Stmt> seq; |
736 | |
737 | /*! \return get the size of the sequence */ |
738 | size_t size() const { return seq.size(); } |
739 | /*! |
740 | * \brief Get the index-th element in the sequence. |
741 | */ |
742 | Stmt operator[](size_t index) const { return seq[index]; } |
743 | |
744 | void VisitAttrs(AttrVisitor* v) { |
745 | v->Visit("seq" , &seq); |
746 | v->Visit("span" , &span); |
747 | } |
748 | |
749 | bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { |
750 | return equal(seq, other->seq); |
751 | } |
752 | |
753 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } |
754 | |
755 | static constexpr const char* _type_key = "tir.SeqStmt" ; |
756 | TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); |
757 | }; |
758 | |
759 | /*! \brief Sequence statement. */ |
760 | class SeqStmt : public Stmt { |
761 | public: |
762 | /*! |
763 | * \brief Construct SeqStmt. |
764 | * \param seq The sequence. |
765 | * \param span The location of this object in the source code. |
766 | */ |
767 | TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span()); |
768 | |
769 | /*! \return get the size of the sequence */ |
770 | size_t size() const { return operator->()->size(); } |
771 | /*! |
772 | * \brief Get the index-th element in the sequence. |
773 | */ |
774 | Stmt operator[](size_t index) const { return (*(operator->()))[index]; } |
775 | /*! |
776 | * \brief Construct a sequence statement by flattening |
777 | * all the arrays and sequences in the arguments |
778 | * recursively. |
779 | * |
780 | * - When an argument is nullptr, it will be ignored. |
781 | * - When an argument is an array or a SeqStmt, it will be flattened recursively. |
782 | * - A normal Stmt will be appended to the end of the sequence. |
783 | * |
784 | * \note This function can directly return an element |
785 | * if it is the only element in the sequence. |
786 | * |
787 | * \param seq_args The list of arguments to be flattened. |
788 | * \tparam Args arguments |
789 | * \return The constructed statement |
790 | */ |
791 | template <typename... Args> |
792 | static Stmt Flatten(Args&&... seq_args) { |
793 | Array<Stmt> seq; |
794 | runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...); |
795 | if (seq.size() == 1) return seq[0]; |
796 | return SeqStmt(seq); |
797 | } |
798 | /*! \brief Helper class to flatten sequence of arguments into Array. */ |
799 | class Flattener { |
800 | public: |
801 | explicit Flattener(Array<Stmt>* seq) : seq_(seq) {} |
802 | |
803 | void operator()(size_t i, const Stmt& stmt) const { |
804 | if (!stmt.defined()) return; |
805 | if (auto* op = stmt.as<SeqStmtNode>()) { |
806 | operator()(0, op->seq); |
807 | } else { |
808 | seq_->push_back(stmt); |
809 | } |
810 | } |
811 | |
812 | template <typename T> |
813 | void operator()(size_t i, const T& seq) const { |
814 | for (auto v : seq) { |
815 | this->operator()(0, v); |
816 | } |
817 | } |
818 | |
819 | private: |
820 | Array<Stmt>* seq_; |
821 | }; |
822 | |
823 | TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); |
824 | TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode); |
825 | }; |
826 | |
827 | /*! |
828 | * \brief IfThenElse statment. |
829 | */ |
830 | class IfThenElseNode : public StmtNode { |
831 | public: |
832 | /*! \brief The condition. */ |
833 | PrimExpr condition; |
834 | /*! \brief The branch to be executed when condition is true. */ |
835 | Stmt then_case; |
836 | /*! \brief The branch to be executed when condition is false, can be null. */ |
837 | Optional<Stmt> else_case; |
838 | |
839 | void VisitAttrs(AttrVisitor* v) { |
840 | v->Visit("condition" , &condition); |
841 | v->Visit("then_case" , &then_case); |
842 | v->Visit("else_case" , &else_case); |
843 | v->Visit("span" , &span); |
844 | } |
845 | |
846 | bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { |
847 | return equal(condition, other->condition) && equal(then_case, other->then_case) && |
848 | equal(else_case, other->else_case); |
849 | } |
850 | |
851 | void SHashReduce(SHashReducer hash_reduce) const { |
852 | hash_reduce(condition); |
853 | hash_reduce(then_case); |
854 | hash_reduce(else_case); |
855 | } |
856 | |
857 | static constexpr const char* _type_key = "tir.IfThenElse" ; |
858 | TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); |
859 | }; |
860 | |
861 | /*! |
862 | * \brief Managed reference to IfThenElseNode. |
863 | * \sa IfThenElseNode |
864 | */ |
865 | class IfThenElse : public Stmt { |
866 | public: |
867 | TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt, |
868 | Span span = Span()); |
869 | |
870 | TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); |
871 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode); |
872 | }; |
873 | |
874 | /*! |
875 | * \brief Evaluates an expression. |
876 | * This is mostly used for putting a Call node into Stmt. |
877 | * |
878 | * If value do not have side-effect, this node can be safely removed. |
879 | */ |
880 | class EvaluateNode : public StmtNode { |
881 | public: |
882 | /*! \brief The expression to be evaluated. */ |
883 | PrimExpr value; |
884 | |
885 | void VisitAttrs(AttrVisitor* v) { |
886 | v->Visit("value" , &value); |
887 | v->Visit("span" , &span); |
888 | } |
889 | |
890 | bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { |
891 | return equal(value, other->value); |
892 | } |
893 | |
894 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } |
895 | |
896 | static constexpr const char* _type_key = "tir.Evaluate" ; |
897 | TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); |
898 | }; |
899 | |
900 | /*! |
901 | * \brief Managed reference to EvaluateNode. |
902 | * \sa EvaluateNode |
903 | */ |
904 | class Evaluate : public Stmt { |
905 | public: |
906 | TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span()); |
907 | |
908 | explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} |
909 | |
910 | TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); |
911 | TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode); |
912 | }; |
913 | |
/*!
 * \brief The kind of a for loop.
 *
 *  The kind can change the control flow semantics of the loop,
 *  so every TIR pass has to take the kind field into account.
 */
enum class ForKind : int {
  /*! \brief Serial execution; this is the default semantics. */
  kSerial = 0,
  /*! \brief Parallel execution on CPU. */
  kParallel = 1,
  /*!
   * \brief Vector SIMD loop.
   *  The body of the loop will be vectorized.
   */
  kVectorized = 2,
  /*! \brief The body of the loop must be unrolled. */
  kUnrolled = 3,
  /*!
   * \brief The loop variable is bound to a thread in an environment.
   *  In the final stage of lowering the loop itself is removed and
   *  the loop variable is mapped to the corresponding context thread.
   */
  kThreadBinding = 4
};
941 | |
942 | /*! |
943 | * \brief A for loop, with poissible type annotations. |
944 | * |
945 | * \code |
946 | * |
947 | * for (loop_var = min; loop_var < min + extent; ++loop_var) { |
948 | * // body |
949 | * } |
950 | * \endcode |
951 | */ |
952 | class ForNode : public StmtNode { |
953 | public: |
954 | /*! \brief The loop variable. */ |
955 | Var loop_var; |
956 | /*! \brief The minimum value of iteration. */ |
957 | PrimExpr min; |
958 | /*! \brief The extent of the iteration. */ |
959 | PrimExpr extent; |
960 | /*! \brief The kind of the for loop. */ |
961 | ForKind kind; |
962 | /*! \brief The body of the for loop. */ |
963 | Stmt body; |
964 | /*! |
965 | * \brief Only valid when kind == ForKind::kThreadBinding |
966 | * The context thread that this loop variable bounds to. |
967 | */ |
968 | Optional<IterVar> thread_binding; |
969 | /*! |
970 | * \brief Additional annotations about the loop. |
971 | * |
972 | * These annotations can be used as auxiliary hint |
973 | * to future transformations. An annotation should |
974 | * not change the control flow semantics of the loop |
975 | * and can be ignored in most passes. |
976 | */ |
977 | Map<String, ObjectRef> annotations; |
978 | |
979 | void VisitAttrs(AttrVisitor* v) { |
980 | v->Visit("loop_var" , &loop_var); |
981 | v->Visit("min" , &min); |
982 | v->Visit("extent" , &extent); |
983 | v->Visit("kind" , &kind); |
984 | v->Visit("body" , &body); |
985 | v->Visit("thread_binding" , &thread_binding); |
986 | v->Visit("annotations" , &annotations); |
987 | v->Visit("span" , &span); |
988 | } |
989 | |
990 | bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { |
991 | return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && |
992 | equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && |
993 | equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations); |
994 | } |
995 | |
996 | void SHashReduce(SHashReducer hash_reduce) const { |
997 | hash_reduce.DefHash(loop_var); |
998 | hash_reduce(min); |
999 | hash_reduce(extent); |
1000 | hash_reduce(kind); |
1001 | hash_reduce(body); |
1002 | hash_reduce(thread_binding); |
1003 | hash_reduce(annotations); |
1004 | } |
1005 | |
1006 | static constexpr const char* _type_key = "tir.For" ; |
1007 | TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); |
1008 | }; |
1009 | |
1010 | /*! |
1011 | * \brief Managed reference to ForNode. |
1012 | * \sa ForNode |
1013 | */ |
1014 | class For : public Stmt { |
1015 | public: |
1016 | TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, |
1017 | Optional<IterVar> thread_binding = NullOpt, |
1018 | Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span()); |
1019 | |
1020 | TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); |
1021 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); |
1022 | }; |
1023 | |
1024 | /*! |
1025 | * \brief A While loop |
1026 | * |
1027 | * \code |
1028 | * |
1029 | * while (condition) |
1030 | * body |
1031 | * |
1032 | * \endcode |
1033 | */ |
1034 | class WhileNode : public StmtNode { |
1035 | public: |
1036 | /*! \brief The termination condition. */ |
1037 | PrimExpr condition; |
1038 | /*! \brief The body of the while loop. */ |
1039 | Stmt body; |
1040 | |
1041 | void VisitAttrs(AttrVisitor* v) { |
1042 | v->Visit("condition" , &condition); |
1043 | v->Visit("body" , &body); |
1044 | v->Visit("span" , &span); |
1045 | } |
1046 | |
1047 | bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { |
1048 | return equal(condition, other->condition) && equal(body, other->body); |
1049 | } |
1050 | |
1051 | void SHashReduce(SHashReducer hash_reduce) const { |
1052 | hash_reduce(condition); |
1053 | hash_reduce(body); |
1054 | } |
1055 | |
1056 | static constexpr const char* _type_key = "tir.While" ; |
1057 | TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode); |
1058 | }; |
1059 | |
1060 | /*! |
1061 | * \brief Managed reference to WhileNode. |
1062 | * \sa WhileNode |
1063 | */ |
1064 | class While : public Stmt { |
1065 | public: |
1066 | TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span()); |
1067 | |
1068 | TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode); |
1069 | TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode); |
1070 | }; |
1071 | |
1072 | /*! |
1073 | * \brief A prefetch hint for a buffer |
1074 | */ |
1075 | class PrefetchNode : public StmtNode { |
1076 | public: |
1077 | /*! \brief The function to be prefetched. */ |
1078 | Buffer buffer; |
1079 | /*! \brief Bounds to be prefetched. */ |
1080 | Array<Range> bounds; |
1081 | |
1082 | void VisitAttrs(AttrVisitor* v) { |
1083 | v->Visit("buffer" , &buffer); |
1084 | v->Visit("bounds" , &bounds); |
1085 | v->Visit("span" , &span); |
1086 | } |
1087 | |
1088 | bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { |
1089 | return equal(buffer, other->buffer) && equal(bounds, other->bounds); |
1090 | } |
1091 | |
1092 | void SHashReduce(SHashReducer hash_reduce) const { |
1093 | hash_reduce(buffer); |
1094 | hash_reduce(bounds); |
1095 | } |
1096 | |
1097 | PrefetchNode() = default; |
1098 | PrefetchNode(Buffer buffer, Array<Range> bounds, Span span = Span()) |
1099 | : StmtNode(span), buffer(buffer), bounds(bounds) {} |
1100 | |
1101 | static constexpr const char* _type_key = "tir.Prefetch" ; |
1102 | TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); |
1103 | }; |
1104 | |
1105 | /*! |
1106 | * \brief Managed reference to PrefetchNode. |
1107 | * \sa PrefetchNode |
1108 | */ |
1109 | class Prefetch : public Stmt { |
1110 | public: |
1111 | TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span()); |
1112 | |
1113 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); |
1114 | TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode); |
1115 | }; |
1116 | |
1117 | /*! |
1118 | * \brief Representing the region of multi-dimensional buffer access. |
1119 | */ |
1120 | class BufferRegionNode : public Object { |
1121 | public: |
1122 | /*! \brief The buffer of the buffer region. */ |
1123 | Buffer buffer; |
1124 | /*! \brief The region array of the buffer region. */ |
1125 | Array<Range> region; |
1126 | |
1127 | void VisitAttrs(AttrVisitor* v) { |
1128 | v->Visit("buffer" , &buffer); |
1129 | v->Visit("region" , ®ion); |
1130 | } |
1131 | |
1132 | bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const { |
1133 | return equal(buffer, other->buffer) && equal(region, other->region); |
1134 | } |
1135 | |
1136 | void SHashReduce(SHashReducer hash_reduce) const { |
1137 | hash_reduce(buffer); |
1138 | hash_reduce(region); |
1139 | } |
1140 | |
1141 | static constexpr const char* _type_key = "tir.BufferRegion" ; |
1142 | static constexpr const bool _type_has_method_sequal_reduce = true; |
1143 | static constexpr const bool _type_has_method_shash_reduce = true; |
1144 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object); |
1145 | }; |
1146 | |
1147 | /*! |
1148 | * \brief Managed reference to BufferRegionNode. |
1149 | * \sa BufferRegionNode |
1150 | */ |
1151 | class BufferRegion : public ObjectRef { |
1152 | public: |
1153 | TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region); |
1154 | |
1155 | /*! |
1156 | * \brief Create a BufferRegion which is full region of the given buffer. |
1157 | * \param buffer The buffer to generate full BufferRegion. |
1158 | * \return The BufferRegion which covers all region of the given buffer |
1159 | */ |
1160 | TVM_DLL static BufferRegion FullRegion(Buffer buffer); |
1161 | |
1162 | /*! |
1163 | * \brief Create a BufferRegion which is a single point of the given buffer. |
1164 | * \param buffer The buffer to generate single point BufferRegion. |
1165 | * \param indices The access point indices of the buffer |
1166 | * \return The BufferRegion which is the single point of the given buffer. |
1167 | */ |
1168 | TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices); |
1169 | |
1170 | TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode); |
1171 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode); |
1172 | }; |
1173 | |
1174 | /*! |
1175 | * \brief Match introduces a constraint that the source buffer region can be remapped to the data |
1176 | * layout specified by the buffer field. The constraint can be checked in later part of lowering (or |
1177 | * optionally during runtime). |
1178 | * |
1179 | * MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in |
1180 | * low-level hardware primitives in the IR and defer the check after the sequence of |
1181 | * transformations. |
1182 | */ |
1183 | class MatchBufferRegionNode : public Object { |
1184 | public: |
1185 | /*! \brief The target buffer. */ |
1186 | Buffer buffer; |
1187 | /*! \brief The source buffer region. */ |
1188 | BufferRegion source; |
1189 | |
1190 | void VisitAttrs(AttrVisitor* v) { |
1191 | v->Visit("buffer" , &buffer); |
1192 | v->Visit("source" , &source); |
1193 | } |
1194 | |
1195 | bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const { |
1196 | return equal(buffer, other->buffer) && equal(source, other->source); |
1197 | } |
1198 | |
1199 | void SHashReduce(SHashReducer hash_reduce) const { |
1200 | hash_reduce(buffer); |
1201 | hash_reduce(source); |
1202 | } |
1203 | |
1204 | static constexpr const char* _type_key = "tir.MatchBufferRegion" ; |
1205 | static constexpr const bool _type_has_method_sequal_reduce = true; |
1206 | static constexpr const bool _type_has_method_shash_reduce = true; |
1207 | TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object); |
1208 | }; |
1209 | |
1210 | /*! |
1211 | * \brief Managed reference to MatchBufferRegionNode. |
1212 | * \sa MatchBufferRegionNode |
1213 | */ |
1214 | class MatchBufferRegion : public ObjectRef { |
1215 | public: |
1216 | TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source); |
1217 | |
1218 | TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); |
1219 | TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode); |
1220 | }; |
1221 | |
1222 | /*! |
1223 | * \brief A block is a basic schedule unit in TIR. |
1224 | * \note Block's body is parameterized by iter vars. |
1225 | * \code |
1226 | * |
1227 | * with T.block(name): |
1228 | * v0 = T.axis.S(domain, value0) |
1229 | * v1 = T.axis.R(domain, value1) |
1230 | * ... |
1231 | * T.reads([buffer0[start:end, ...], ...]) |
1232 | * T.writes([buffer1[start:end, ...], ...]) |
1233 | * T.where(predicate) |
1234 | * buffer2 = T.alloc_buffer(shape, dtype) |
1235 | * buffer3 = T.match_buffer(source_buffer[start:end, ...]) |
1236 | * T.attr({attr_key: attr_value, ...}) |
1237 | * with T.init(): |
1238 | * // init body |
1239 | * // body |
1240 | * |
1241 | * \endcode |
1242 | */ |
1243 | class BlockNode : public StmtNode { |
1244 | public: |
1245 | /*! \brief The variables of the block. */ |
1246 | Array<IterVar> iter_vars; |
1247 | /*! \brief The read buffer regions of the block. */ |
1248 | Array<BufferRegion> reads; |
1249 | /*! \brief The write buffer regions of the block. */ |
1250 | Array<BufferRegion> writes; |
1251 | /*! \brief The name_hint of the block. */ |
1252 | String name_hint; |
1253 | /*! \brief The body of the block. */ |
1254 | Stmt body; |
1255 | /*! |
1256 | * \brief The init statement is executed during the first iteration of reduction loops in a |
1257 | * reduction block. The optional init field allows us to represent initialization and |
1258 | * reduction update in a single block and transform them collectively. |
1259 | * We also provide primitives to decompose the init into a separate block during scheduling. |
1260 | * Init field is `NullOpt` if there is no reduction iter_vars |
1261 | */ |
1262 | Optional<Stmt> init; |
1263 | /*! \brief The buffer allocated in the block. */ |
1264 | Array<Buffer> alloc_buffers; |
1265 | /*! \brief The match buffer regions. */ |
1266 | Array<MatchBufferRegion> match_buffers; |
1267 | /*! \brief The annotation of the block. */ |
1268 | Map<String, ObjectRef> annotations; |
1269 | |
1270 | void VisitAttrs(AttrVisitor* v) { |
1271 | v->Visit("iter_vars" , &iter_vars); |
1272 | v->Visit("reads" , &reads); |
1273 | v->Visit("writes" , &writes); |
1274 | v->Visit("name_hint" , &name_hint); |
1275 | v->Visit("body" , &body); |
1276 | v->Visit("init" , &init); |
1277 | v->Visit("alloc_buffers" , &alloc_buffers); |
1278 | v->Visit("match_buffers" , &match_buffers); |
1279 | v->Visit("annotations" , &annotations); |
1280 | } |
1281 | |
1282 | bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const { |
1283 | // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars |
1284 | return equal.DefEqual(iter_vars, other->iter_vars) && |
1285 | equal(alloc_buffers, other->alloc_buffers) && |
1286 | equal(match_buffers, other->match_buffers) && equal(reads, other->reads) && |
1287 | equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) && |
1288 | equal(annotations, other->annotations); |
1289 | } |
1290 | |
1291 | void SHashReduce(SHashReducer hash_reduce) const { |
1292 | hash_reduce.DefHash(iter_vars); |
1293 | hash_reduce(alloc_buffers); |
1294 | hash_reduce(match_buffers); |
1295 | hash_reduce(reads); |
1296 | hash_reduce(writes); |
1297 | hash_reduce(body); |
1298 | hash_reduce(init); |
1299 | hash_reduce(annotations); |
1300 | } |
1301 | |
1302 | static constexpr const char* _type_key = "tir.Block" ; |
1303 | TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode); |
1304 | }; |
1305 | |
1306 | /*! |
1307 | * \brief Managed reference to BlockNode. |
1308 | * \sa BlockNode |
1309 | */ |
1310 | class Block : public Stmt { |
1311 | public: |
1312 | TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, |
1313 | Array<BufferRegion> writes, String name_hint, Stmt body, |
1314 | Optional<Stmt> init = NullOpt, |
1315 | Array<Buffer> alloc_buffers = Array<Buffer>(), |
1316 | Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(), |
1317 | Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), |
1318 | Span span = Span()); |
1319 | |
1320 | TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode); |
1321 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); |
1322 | }; |
1323 | |
1324 | /*! |
1325 | * \brief A block realization node represents execution of the block at the binding values. |
1326 | */ |
1327 | class BlockRealizeNode : public StmtNode { |
1328 | public: |
1329 | /*! \brief The corresponding values of the iter vars. */ |
1330 | Array<PrimExpr> iter_values; |
1331 | /*! |
1332 | * \brief The predicate of the block realization, the block will only be executed when the |
1333 | * predicate is true. |
1334 | */ |
1335 | PrimExpr predicate; |
1336 | /*! \brief The block to be realized. */ |
1337 | Block block; |
1338 | |
1339 | void VisitAttrs(AttrVisitor* v) { |
1340 | v->Visit("iter_values" , &iter_values); |
1341 | v->Visit("predicate" , &predicate); |
1342 | v->Visit("block" , &block); |
1343 | } |
1344 | |
1345 | bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const { |
1346 | return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) && |
1347 | equal(block, other->block); |
1348 | } |
1349 | |
1350 | void SHashReduce(SHashReducer hash_reduce) const { |
1351 | hash_reduce(iter_values); |
1352 | hash_reduce(predicate); |
1353 | hash_reduce(block); |
1354 | } |
1355 | |
1356 | static constexpr const char* _type_key = "tir.BlockRealize" ; |
1357 | TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode); |
1358 | }; |
1359 | |
1360 | /*! |
1361 | * \brief Managed reference to BlockRealizeNode |
1362 | * \sa BlockRealizeNode |
1363 | */ |
1364 | class BlockRealize : public Stmt { |
1365 | public: |
1366 | TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block, |
1367 | Span span = Span()); |
1368 | |
1369 | TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode); |
1370 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); |
1371 | }; |
1372 | |
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark region is processed by a co-processor */
constexpr const char* coproc_scope = "coproc_scope";
/*!
 * \brief Mark region creates coprocessor micro ops,
 *  can be reused if corresponding variable is independent.
 */
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
 * \brief Mark the scope as generated by extern primitive.
 *  Such a scope can contain an arbitrary ir program and we need to be careful
 *  when making certain assumptions about the structure of the program.
 */
constexpr const char* extern_scope = "extern_scope";
/*!
 * \brief Mark the scope as when computation start to happen
 *  This can hint some code generator to create a new function for compute.
 */
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation device for global malloc in host. */
constexpr const char* device_id = "device_id";
/*! \brief The device type. */
constexpr const char* device_type = "device_type";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Pragma: auto-unroll, max_step */
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
/*! \brief Pragma: unroll explicit */
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*! \brief Mark region is guarded by the pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import C source or file into the final code gen module */
constexpr const char* pragma_import_c = "pragma_import_c";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
/*!
 * \brief Mark of prefetch scope, value=offset,
 *  run prefetch of Tensor on the current loop scope
 */
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
 * \brief Marks the layout transforms to be used for a tensor.
 *
 * Only applies to a DataProducer, as it should be made part of the
 * PrimFunc attributes for TIR.
 */
constexpr const char* layout_transforms = "layout_transforms";
/*!
 * \brief Marks the physical axis separators
 *
 * Only applies to a DataProducer, as it should be made part of the
 * Buffer definition in a PrimFunc.  See `BufferNode::axis_separators`
 * for more details.
 */
constexpr const char* axis_separators = "axis_separators";
/*!
 * \brief Marks production of double buffer data
 */
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
 * \brief Marks region used by double buffer write
 */
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark realization for rolling buffer optimization */
constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
 * \brief Mark alignment of buffer dimension
 *  stmt.node is Tensor
 *  stmt.value is tvm_tuple(dim, align, offset)
 *  This gives hint to require stride of dim to be k * align + offset.
 */
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds.  */
constexpr const char* buffer_bound = "buffer_bound";
/*!
 * \brief Bind the buffer specification to the region of the op
 *  When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
 *  stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
 *  The scope represents that we need to bind the storage region of tensor to buffer.
 *  This will affect replacement of some variables inside the scope that
 *  corresponds to field of buffer to be the actual expressions of tensor during
 *  storage flattening phase.
 */
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_read_advance = "channel_read_advance";
/*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_write_advance = "channel_write_advance";
/*! \brief pipeline stage scope, implies always execution */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";

/*!
 * \brief Mark that it is in the device scope.
 */
constexpr const char* device_scope = "device_scope";

/*!
 * \brief Mark that the attached statement runs asynchronously.
 */
constexpr const char* async_scope = "async_scope";

/*!
 * \brief Annotations for invoking and synchronizing asynchronous operations.
 *
 * Synchronization is done in terms of "queue": It is an abstract entity associated
 * with each asynchronous unit, and it tracks invocations and completions of asynchronous
 * operations in the FIFO order.
 *
 * Similarly to PTX instructions commit_group and wait_group, these annotations express
 * synchronization by "counting":
 *
 * async_commit_queue(i): Group one or more invocations of async operations in the given scope,
 * and "commit" (or push) them to the queue i. A group of operations committed together is
 * awaited as one chunk. Groups committed to the same queue complete in the FIFO order.
 *
 * async_wait_queue(i, N): Block until only N most recent committed groups are still in-flight at
 * the queue i. N does not have to be a constant, but some backends may require a constant count.
 */
constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

/*!
 * \brief Mark that the shape of TensorCore fragment
 */
constexpr const char* fragment_shape = "fragment_shape";

/*!
 * \brief Mark that the layout of TensorCore fragment
 */
constexpr const char* fragment_layout = "fragment_layout";

/*!
 * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
 */
constexpr const char* hand_threaded = "hand_threaded";

/*!
 * \brief Mark whether the script-completer need to fill in missing access region
 *        during script parsing.
 * \note The result should be an integer mask with range [0, 4).
 *       if (mask & 1) the read region should be detected,
 *       if (mask & 2) the write region should be detected.
 */
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";

/*!
 * \brief Mark that the loop should be partitioned.
 */
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the stage of a statement in the software pipeline */
constexpr const char* software_pipeline_stage = "software_pipeline_stage";

/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief List stages in the software pipeline that should run asynchronously
 * \note All statements in the provided stages are assumed to have asynchronous
 *       semantics (e.g. CUDA async global to shared memory copy).
 */
constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";

/*! \brief Mark the buffers which are const access and can be transformed layout. */
constexpr const char* layout_free_buffers = "layout_free_buffers";

/*! \brief Mark the local stage for the shared memory access should be added. */
constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*!
 * \brief Mark that the loop should be further skip and bound to environment threads to enable
 * cooperative fetching.
 */
constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

/*! \brief The allowed range of thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
    "meta_schedule.thread_extent_low_inclusive";

/*! \brief The allowed range of thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
    "meta_schedule.thread_extent_high_inclusive";

/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

/*! \brief Mark auto-vectorize setting on the block. */
constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*! \brief Mark that a block should be further rewritten using tensorization. */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
/*!
 * \brief Mark that the init statement of a block should be further rewritten using tensorization.
 */
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
 * \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
 * warp size.
 */
constexpr const char* warp_execution = "warp_execution";

/*! \brief Mark that a block is disallowed in auto inline. */
constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";

/*!
 * \brief Check if attr_key is a pragma key extension
 * \param attr_key The attr key to be compared
 * \return true if attr_key starts with pragma_scope_prefix ("pragma_"), false otherwise
 *         (including when attr_key is shorter than the prefix).
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind anchored at position 0 can only match at the very start of the string, which
  // makes it a prefix test. Using pragma_scope_prefix keeps the check tied to the
  // declared constant instead of a repeated literal and a hand-counted length.
  return attr_key.rfind(pragma_scope_prefix, 0) == 0;
}

}  // namespace attr
1631 | /*! |
1632 | * \brief Create a type annotation expression |
1633 | * \param dtype The data type |
1634 | * \param span The location of this object in the source code. |
1635 | * \return Expr a expression with dtype. |
1636 | */ |
1637 | TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); |
1638 | |
1639 | // overload printing of for type. |
1640 | TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); |
1641 | |
1642 | // inline implementations |
1643 | inline const char* ForKind2String(ForKind t) { |
1644 | switch (t) { |
1645 | case ForKind::kSerial: |
1646 | return "serial" ; |
1647 | case ForKind::kParallel: |
1648 | return "parallel" ; |
1649 | case ForKind::kVectorized: |
1650 | return "vectorized" ; |
1651 | case ForKind::kUnrolled: |
1652 | return "unroll" ; |
1653 | case ForKind::kThreadBinding: |
1654 | return "thread_binding" ; |
1655 | } |
1656 | LOG(FATAL) << "Unknown ForKind" << t; |
1657 | } |
1658 | |
1659 | } // namespace tir |
1660 | } // namespace tvm |
1661 | #endif // TVM_TIR_STMT_H_ |
1662 | |