1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/te/schedule.h
22 * \brief Define a schedule.
23 */
24// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25#ifndef TVM_TE_SCHEDULE_H_
26#define TVM_TE_SCHEDULE_H_
27
28#include <tvm/support/with.h>
29#include <tvm/te/tensor.h>
30#include <tvm/te/tensor_intrin.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/index_map.h>
33
34#include <string>
35#include <unordered_map>
36
37namespace tvm {
38namespace te {
// Forward declarations of the node containers that back the managed
// references declared below: Stage -> StageNode, Schedule -> ScheduleNode,
// IterVarRelation -> IterVarRelationNode, IterVarAttr -> IterVarAttrNode.
class IterVarAttrNode;
class IterVarRelationNode;
class ScheduleNode;
class StageNode;
47
/*!
 * \brief the attachment type of a stage.
 *
 * Enumerators are numbered consecutively starting at 1. The value is
 * exposed through reflection as StageNode's "attach_type" attribute.
 */
enum AttachType : int {
  kGroupRoot = 1,   // default for a freshly created StageNode
  kInline,          // presumably set by Stage::compute_inline -- TODO confirm
  kInlinedAlready,  // NOTE(review): semantics not visible in this header
  kScope,           // presumably set by Stage::compute_at -- TODO confirm
  kScanUpdate       // NOTE(review): semantics not visible in this header
};
56
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public ObjectRef {
 public:
  Stage() {}
  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief create a new stage that holds the schedule of op.
   * \param op The operator in the schedule
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*)
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();  // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the IterVar to thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to be bound.
   * \return reference to self.
   */
  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
  /*!
   * \brief Set the predicate to determine whether a store to the array should be performed.
   *  Use this when there are multiple threads performing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change behavior of program.
   *    Only use it when it is certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
  /*!
   * \brief Specify environment threads that are launched around the group's scope.
   *  This can only be used in group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  TVM_DLL Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent by factor, generating an outer and an inner domain.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                       IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Split the iteration with given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
                                 IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Fuse the outer and inner domains into the target domain.
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused
   * \param p_target The result target domain.
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Fuse all the axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array,
   *       in that case, a singleton IterVar is created and
   *       inserted to the outermost loop.
   *       The fuse of empty array is used to support zero-dimension tensors.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variable.
   * \return reference to self.
   */
  TVM_DLL Stage& reorder(const Array<IterVar>& order);  // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner]
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
   * \param x_factor The stride factor on x axis
   * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
   * \return reference to self.
   */
  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*)
                      PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  TVM_DLL Stage& vectorize(IterVar var);  // NOLINT(*)
  /*!
   * \brief Replace computation of the current stage by tensor intrinsic f.
   * \param var The axis that marks the beginning of tensorization.
   *  Every operation inside the axis (including the axis itself) is tensorized.
   * \param f The Tensor compute intrinsics.
   * \return reference to self.
   */
  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);  // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  TVM_DLL Stage& unroll(IterVar var);  // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  TVM_DLL Stage& parallel(IterVar var);  // NOLINT(*)
  /*!
   * \brief Annotate the iteration with pragma
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   * \param pragma_value The pragma value
   *
   * \return reference to self.
   */
  TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
                        const PrimExpr& pragma_value = PrimExpr());  // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain the tensor to be prefetched
   * \param var the iteration point at which to apply prefetching
   * \param offset the number of iterations to be fetched in advance
   * \return reference to self
   */
  TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset);  // NOLINT(*)
  /*!
   * \brief Set alignment requirement for specific dimension.
   *
   *  Such that stride[axis] == k * factor + offset for some k.
   *
   * \param axis The dimension to be specified for alignment.
   * \param factor The factor multiple of alignment
   * \param offset The required offset factor.
   * \return reference to self
   */
  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset);  // NOLINT(*)
  /*!
   * \brief Compute current stage with double buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& double_buffer();  // NOLINT(*)
  /*!
   * \brief Compute current stage with rolling buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& rolling_buffer();  // NOLINT(*)
  /*!
   * \brief Defines a layout transformation to be applied to the buffer.
   *
   * The map from initial_index to final_index must be an
   * invertible affine transformation.
   *
   * \param initial_indices An array of variables to represent a
   * value's location in the tensor, using the pre-transformation
   * layout.  These variables are used as binding occurrences to
   * represent the initial indices when applying the initial->final
   * mapping, and should not occur elsewhere in the
   * Schedule. (i.e. Pass in newly constructed variables, not the
   * initial IterVar::var)
   *
   * \param final_indices An array of expressions, giving the
   * value's location in the tensor, using the post-transformation layout.
   * Expressions should be in terms of the variables given in
   * initial_indices.
   *
   * \param out_iter_vars An optional output location for the updated
   * loop iteration variables.
   *
   * \return reference to self
   */
  TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
                                  const Array<PrimExpr>& final_indices,
                                  Array<IterVar>* out_iter_vars = nullptr);
  /*! \brief Defines separators between groups of axes.
   *
   * Used to define `BufferNode::axis_separators`, which has
   * additional details.
   *
   * \param axis_separators A list of axis separators.
   */
  TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators);
  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get attachment spec of current stage.
   *  If the stage is computed at group root, this function
   *  will traverse the group to get the
   *  final spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};
311
/*!
 * \brief Global schedule container for operations and all the
 *  operations they depend on.
 *  The schedule of each Operation is called a stage.
 */
class Schedule : public ObjectRef {
 public:
  Schedule() {}
  explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief Create a schedule for array of ops (and their dependencies).
   * \param ops The ops to be scheduled.
   */
  TVM_DLL explicit Schedule(Array<Operation> ops);
  /*!
   * \brief Get a copy of current schedule.
   * \return The copied schedule.
   */
  Schedule copy() const;
  /*!
   * \brief Get the stage that corresponds to the op
   * \param op The operation.
   * \return The stage corresponding to op.
   */
  TVM_DLL Stage operator[](const Operation& op);
  /*!
   * \brief Short hand for getting the stage of tensor's operation.
   * \param tensor The tensor
   * \return The stage corresponding to the tensor's op
   */
  TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
  /*!
   * \brief Create a new stage group for all intermediate
   *  operations between inputs and outputs.
   *
   * \param outputs The output boundary of the group.
   * \param inputs The input boundary of the group.
   * \param include_inputs Whether include inputs if they are reachable from outputs.
   * \return The new grouped stage.
   */
  TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
                             bool include_inputs = false);
  /*!
   * \brief create a cache read of original tensor for readers.
   *  This will mutate the body of the readers.
   *  A new stage will be created for the tensor.
   * \param tensor The tensor cached.
   * \param scope The scope of the cache.
   * \param readers The readers to redirect to the tensor.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
                            const Array<Operation>& readers);
  /*!
   * \brief Create a cache write tensor for producing tensor.
   *  The tensor will take over body of original tensor op.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  User can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensors to be produced.
   * \param scope The scope of the storage.
   * \return The created tensors.
   */
  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
  /*!
   * \brief Create a cache write tensor for producing tensor.
   *  The tensor will take over body of original tensor op.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  User can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensor to be produced.
   * \param scope The scope of the storage.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
  /*!
   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   * This will create a new stage that generates the new tensor with axis
   * as the first dimension. The tensor's body will be rewritten as a reduction
   * over the factored tensor.
   *
   *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in tensor's schedule to be factored.
   * \param factor_axis The position where the new axis is placed.
   * \return The created factored tensors.
   */
  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
   *  Insert necessary RebaseNode to make sure all leaf_iter_vars
   *  are in form [0, extent)
   *
   * \return A normalized schedule, can be same as current one.
   */
  Schedule normalize();

  /*!
   * \brief Normalize the schedule for feature extraction in auto-scheduler.
   * This is similar to `Schedule::normalize`, but we do aggressive simplification
   * to the TE compute with const_matrix=True for faster compilation and feature extraction.
   * The resulting schedule may be wrong, but it is good enough for feature extraction
   * purposes.
   *
   * \return A normalized schedule, can be same as current one.
   */
  Schedule normalize_for_feature_extraction();

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const ScheduleNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline ScheduleNode* operator->();
  // declare container type
  using ContainerType = ScheduleNode;
};
447
448/*!
449 * \brief The schedule relation between IterVars
450 * can be Split, Fuse.
451 */
452class IterVarRelation : public ObjectRef {
453 public:
454 IterVarRelation() {}
455 explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
456 /*!
457 * \brief access the internal node container
458 * \return the pointer to the internal node container
459 */
460 inline const IterVarRelationNode* operator->() const;
461};
462
463/*!
464 * \brief Additional scheduable attributes about IterVar.
465 */
466class IterVarAttr : public ObjectRef {
467 public:
468 IterVarAttr() {}
469 explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
470 /*!
471 * \brief access the internal node container
472 * \return the pointer to the internal node container
473 */
474 inline const IterVarAttrNode* operator->() const;
475};
476
477/*!
478 * \brief represents a stage.
479 *
480 * relations form a Directed acylic hypergraph in bipartite manner.
481 * With each node is represented by a IterVar,
482 * and each hyper-edge is represented by a IterVarRelation.
483 * The relations connects the IterVars in the graph.
484 *
485 * Besides typical stage that corresponds to operations.
486 * There is also group stage, which groups stages together.
487 * Each stage's group(given by group) represent an constraint,
488 * the stage can only be attached to stages within the group.
489 *
490 * The group stage node can be attached to IterVars as in normal stage.
491 */
492class StageNode : public Object {
493 public:
494 /*!
495 * \brief The operation of stage, can be different from original op.
496 * If it is null, then this stage is a group stage.
497 */
498 Operation op;
499 /*!
500 * \brief The original operator.
501 * The op field can change during schedule to alternate the dataflow,
502 * while origin_op remains fixed.
503 */
504 Operation origin_op;
505 /*! \brief All the nodes in the iter var
506 *
507 * Each element of all_iter_vars represents an iteration variable
508 * that may appear within this stage's computation. Any element
509 * of `all_iter_vars` that is in `leaf_iter_vars` represents a
510 * variable that is directly defined and usable within the stage's
511 * computation. All other elements of `all_iter_vars` represent
512 * variables whose value must be computed from the variables in
513 * `leaf_iter_vars`. (e.g. Support index k has been split by
514 * ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in
515 * `leaf_iter_vars`, while k will not, and must be computed as
516 * `4*ko + ki`.
517 */
518 Array<IterVar> all_iter_vars;
519 /*! \brief The current active leaf iter vars in the stage.
520 *
521 * Each element of leaf_iter_vars will either be replaced with the
522 * bound index (e.g. threadIdx.x), or will be expanded into a loop
523 * over the variable's extent. `leaf_iter_vars` is a subset of
524 * `all_iter_vars`.
525 */
526 Array<IterVar> leaf_iter_vars;
527 /*!
528 * \brief Specify threads to be launched at the stage.
529 * This is only valid for composite ops such as Scan.
530 * \note Experimental primitive: used for thread persistence.
531 */
532 Array<IterVar> env_threads;
533 /*!
534 * \brief The predicate under which store can happen
535 * Use this when there can be duplicated threads doing the same store.
536 * \note Experimental primitive: used by cross thread-reduction.
537 */
538 PrimExpr store_predicate;
539 /*! \brief The relation bwteen of IterVars */
540 Array<IterVarRelation> relations;
541 /*! \brief additional attributes about iter var. */
542 Map<IterVar, IterVarAttr> iter_var_attrs;
543 /*! \brief The attachment type of the schedule */
544 AttachType attach_type{kGroupRoot};
545 /*! \brief The attach point of this schedule. */
546 IterVar attach_ivar;
547 /*! \brief The stage this node attaches to */
548 Stage attach_stage;
549 /*! \brief The thread storage scope level of the stage */
550 std::string scope;
551 /*! \brief Whether this is an output stage */
552 bool is_output{false};
553 /*! \brief Whether apply double buffer optimization to this stage */
554 bool double_buffer{false};
555 /*! \brief Whether apply rolling buffer optimization to this stage */
556 bool rolling_buffer{false};
557 /*! \brief Layout transformations to be applied onto the stage's tensors. */
558 Array<IndexMap> layout_transforms;
559 /*! \brief List of axes after which to divide physical axes.
560 *
561 * Used to populate `BufferNode::axis_separators`, which has
562 * additional details.
563 */
564 Array<IntImm> axis_separators;
565 /*!
566 * \brief The parent group of the current stage.
567 * The stage cannot be assigned to stages outside the group.
568 */
569 Stage group;
570 /*! \brief Number of direct child stages, only used for group stage.*/
571 int num_child_stages{0};
572
573 void VisitAttrs(AttrVisitor* v) {
574 v->Visit("op", &op);
575 v->Visit("origin_op", &origin_op);
576 v->Visit("all_iter_vars", &all_iter_vars);
577 v->Visit("leaf_iter_vars", &leaf_iter_vars);
578 v->Visit("env_threads", &env_threads);
579 v->Visit("relations", &relations);
580 v->Visit("iter_var_attrs", &iter_var_attrs);
581 v->Visit("attach_type", &attach_type);
582 v->Visit("attach_ivar", &attach_ivar);
583 v->Visit("attach_stage", &attach_stage);
584 v->Visit("scope", &scope);
585 v->Visit("is_output", &is_output);
586 v->Visit("double_buffer", &double_buffer);
587 v->Visit("layout_transforms", &layout_transforms);
588 v->Visit("axis_separators", &axis_separators);
589 v->Visit("group", &group);
590 v->Visit("num_child_stages", &num_child_stages);
591 }
592
593 static constexpr const char* _type_key = "Stage";
594 TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
595};
596
/*! \brief node container for schedule */
class ScheduleNode : public Object {
 public:
  /*! \brief The output operations in original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief list of all stages for ops.
   * The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map to map internal ops to stages.
   *  This is created on demand and can be invalidated.
   *  It is not visited in VisitAttrs, so it is not exposed through reflection.
   */
  std::unordered_map<const Object*, Stage> op2stage_cache_;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Check if the schedule contains an Operation.
   * \param op The candidate Operation.
   * \return true if the schedule has the Operation. Otherwise, false.
   */
  TVM_DLL bool Contain(const Operation& op) const;

  /*!
   * \brief Check if the schedule contains a Tensor.
   *  Equivalent to checking the tensor's producing operation.
   * \param tensor The candidate tensor.
   * \return true if the schedule has the tensor. Otherwise, false.
   */
  TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
648
649/*!
650 * \brief Create a schedule for array of ops(and their dependencies).
651 * \param ops The ops to be scheduled.
652 * \return sch The created Schedule.
653 */
654inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
655
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this IterVar binds to, can be null */
  IterVar bind_thread;
  /*! \brief List of tensors to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<PrimExpr> prefetch_offset;
  /*!
   * \brief Tensor intrinsic used in tensorization,
   *   when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*! \brief Alignment factor of buffer dimension */
  int dim_align_factor{0};
  /*! \brief Alignment offset of buffer dimension */
  int dim_align_offset{0};
  /*!
   * \brief Additional pragma keys, array of StringImm
   */
  Array<PrimExpr> pragma_keys;
  /*!
   * \brief Additional values of pragma, if any
   */
  Array<PrimExpr> pragma_values;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("dim_align_factor", &dim_align_factor);
    v->Visit("dim_align_offset", &dim_align_offset);
    v->Visit("pragma_keys", &pragma_keys);
    v->Visit("pragma_values", &pragma_values);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
};
700
/*!
 * \brief base node of relations between iteration variables.
 *  Concrete relations (SplitNode, FuseNode, RebaseNode, SingletonNode,
 *  TransformNode) derive from it.
 */
class IterVarRelationNode : public Object {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
};
707
/*!
 * \brief Split the parent domain into product of
 *  outer and inner.
 */
class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  PrimExpr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  PrimExpr nparts;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("factor", &factor);
    v->Visit("nparts", &nparts);
  }

  static constexpr const char* _type_key = "Split";
  TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};
736
/*!
 * \brief Managed reference to SplitNode
 * \sa SplitNode
 */
class Split : public IterVarRelation {
 public:
  /*!
   * \brief Construct a split relation; the arguments presumably populate
   *  the SplitNode fields of the same names.
   * \param parent The parent domain.
   * \param outer The outer domain.
   * \param inner The inner domain.
   * \param factor The split factor.
   * \param nparts Number of parts; only factor or nparts can be given.
   */
  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

  TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};
747
/*!
 * \brief Fuse two domains (outer and inner) into one domain.
 */
class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target (fused) domain */
  IterVar fused;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("fused", &fused);
  }

  static constexpr const char* _type_key = "Fuse";
  TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};
769
/*!
 * \brief Managed reference to FuseNode
 * \sa FuseNode
 */
class Fuse : public IterVarRelation {
 public:
  /*!
   * \brief Construct a fuse relation; the arguments presumably populate
   *  the FuseNode fields of the same names.
   * \param outer The outer domain.
   * \param inner The inner domain.
   * \param fused The target (fused) domain.
   */
  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

  TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};
780
/*!
 * \brief Rebase the iteration to make min to be 0.
 *  This is useful to normalize the Schedule
 *  to make every leaf variable's min to be 0.
 */
class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The rebased domain */
  IterVar rebased;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("rebased", &rebased);
  }

  static constexpr const char* _type_key = "Rebase";
  TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};
801
/*!
 * \brief Managed reference to RebaseNode
 * \sa RebaseNode
 */
class Rebase : public IterVarRelation {
 public:
  /*!
   * \brief Construct a rebase relation; the arguments presumably populate
   *  the RebaseNode fields of the same names.
   * \param parent The parent domain.
   * \param rebased The rebased domain.
   */
  TVM_DLL Rebase(IterVar parent, IterVar rebased);

  TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
};
812
/*!
 * \brief Singleton iterator [0, 1)
 */
class SingletonNode : public IterVarRelationNode {
 public:
  /*! \brief The singleton iterator */
  IterVar iter;

  void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }

  static constexpr const char* _type_key = "Singleton";
  TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};
826
/*!
 * \brief Managed reference to SingletonNode
 * \sa SingletonNode
 */
class Singleton : public IterVarRelation {
 public:
  /*!
   * \brief Construct a singleton relation.
   * \param iter The singleton iterator (presumably stored as SingletonNode::iter).
   */
  TVM_DLL explicit Singleton(IterVar iter);

  TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};
837
/*!
 * \brief Transform iterator according to some arbitrary expression.
 */
class TransformNode : public IterVarRelationNode {
 public:
  /*! \brief The loop variables that were replaced by the transformation.
   *
   * Prior to applying a layout transformation, these represent the
   * loops to iterate over a tensor as it is being computed, following
   * a row-major traversal of the tensor's original shape in the
   * compute definition.
   */
  Array<IterVar> original_variables;

  /*! \brief The variables generated by the transformation.
   *
   * After applying a layout transformation, these represent the
   * loops to iterate over a tensor as it is being computed, following
   * a row-major traversal of the transformed shape of the tensor.
   */
  Array<IterVar> transformed_variables;

  /*! \brief Map from the original variables to the transformed variables.
   *
   * Used to determine iterator ranges over the transformed variables.
   */
  IndexMap forward_transformation;

  /*! \brief Map from transformed variables to the original variables
   *
   * Used to rewrite expressions containing the original loop iterators
   * in terms of the transformed loop iterators.
   */
  IndexMap inverse_transformation;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("original_variables", &original_variables);
    v->Visit("transformed_variables", &transformed_variables);
    v->Visit("forward_transformation", &forward_transformation);
    v->Visit("inverse_transformation", &inverse_transformation);
  }

  static constexpr const char* _type_key = "Transform";
  TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
};
883
/*!
 * \brief Managed reference to TransformNode
 * \sa TransformNode
 */
class Transform : public IterVarRelation {
 public:
  /*!
   * \brief Construct a transform relation; the arguments presumably populate
   *  the TransformNode fields of the same names.
   * \param original_variables The loop variables replaced by the transformation.
   * \param transformed_variables The variables generated by the transformation.
   * \param forward_transformation Map from original to transformed variables.
   * \param inverse_transformation Map from transformed to original variables.
   */
  TVM_DLL explicit Transform(Array<IterVar> original_variables,
                             Array<IterVar> transformed_variables, IndexMap forward_transformation,
                             IndexMap inverse_transformation);

  TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
};
892
/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
 public:
  /*!
   * \brief List of conditions in conjunctive normal form (CNF).
   *   Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
   *   where n, m are tvm::Var that represents a dimension in the tensor shape.
   */
  Array<PrimExpr> clauses;

  void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }

  static constexpr const char* _type_key = "SpecializedCondition";
  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
};
908
/*!
 * \brief Specialized condition to enable op specialization
 */
class SpecializedCondition : public ObjectRef {
 public:
  /*!
   * \brief construct from conditions
   * \param conditions The clauses in the specialized condition.
   */
  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)

  /*!
   * \brief Get the current specialized condition.
   * \return the current specialized condition.
   */
  TVM_DLL static SpecializedCondition Current();

  TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
  class Internal;

 private:
  // enable with syntax (With<SpecializedCondition> calls the scope methods below).
  friend class Internal;
  friend class With<SpecializedCondition>;
  /*! \brief Push a new specialized condition onto the thread local stack. */
  TVM_DLL void EnterWithScope();
  /*! \brief Pop a specialized condition off the thread local context stack. */
  TVM_DLL void ExitWithScope();
};
938
939// implementations
940inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
941inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }
942
943inline const ScheduleNode* Schedule::operator->() const {
944 return static_cast<const ScheduleNode*>(get());
945}
946inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }
947
948inline const IterVarRelationNode* IterVarRelation::operator->() const {
949 return static_cast<const IterVarRelationNode*>(get());
950}
951
952inline const IterVarAttrNode* IterVarAttr::operator->() const {
953 return static_cast<const IterVarAttrNode*>(get());
954}
955
956} // namespace te
957} // namespace tvm
958#endif // TVM_TE_SCHEDULE_H_
959