1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/transform_step.h
22 * \brief Transformation steps. These steps are used to manipulate `LoopState`.
23 * They are similar to the schedule primitives in te::Stage.
24 *
25 * \note How to add a new transform step:
26 * Take fuse step for example:
 * 1. Define class `FuseStepNode`, `FuseStep` in `transform_step.h`, and implement its first
 *    constructor `FuseStep::FuseStep()` in `transform_step.cc`.
 * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`.
 *    - In these two functions you need to lower this step with tvm's te schedule API.
31 * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
32 * - In these two functions you need to incrementally update all data structures in State with
33 * CopyOnWrite style.
34 * 4. Add your step to `StepApplyToState`, `StepApplyToSchedule`, and `StepPrintAsPythonAPI`.
35 * 5. Log record serialization support:
 *    - Add `FuseStepNode::WriteToRecord`, which takes a mutable JSONWriter pointer as input and
 *      outputs the record to it.
 *    - Add another constructor that takes a mutable JSONReader as input; it reads a step record
 *      from the reader and creates the step.
40 * - Add the step implementation to `StepReadFromRecord`.
 * 6. Add its corresponding Python API to `loop_state.py` with necessary unit tests. The tests should
 *    at least cover two parts: the functional test and the record serialization test.
43 */
44
45#ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
46#define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
47
48#include <dmlc/common.h>
49#include <dmlc/json.h>
50#include <tvm/node/node.h>
51#include <tvm/te/schedule.h>
52
53#include <vector>
54
55namespace tvm {
56namespace auto_scheduler {
57
58typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
59
60/*!
61 * \brief Update the current stage IterVar information to StageToAxesMap.
62 * \param stage The stage to be updated.
63 * \param stage_to_axes The map to be updated.
64 */
65void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
66
/*! \brief The kind of an iterator. */
enum class IteratorKind : int {
  kSpatial = 0,    ///< Spatial iterator.
  kReduction = 1,  ///< Reduction iterator.
  kMixed = 2,      ///< Iterator produced by fusing spatial and reduction iterators.
  kSpecial = 3     ///< Special iterator (e.g. the virtual root iterator).
};
78
/*! \brief The type of an iterator's annotation. */
enum class IteratorAnnotation : int {
  /*! \brief This iterator has no annotation. */
  kNone = 0,
  /*! \brief This iterator has been unrolled. */
  kUnroll = 1,
  /*! \brief This iterator has been vectorized. */
  kVectorize = 2,
  /*! \brief This iterator has been parallelized. */
  kParallel = 3,
  /*! \brief This iterator has been bound to vthread. */
  kVThread = 4,
  /*! \brief This iterator has been bound to blockIdx.x. */
  kBlockX = 5,
  /*! \brief This iterator has been bound to threadIdx.x. */
  kThreadX = 6,
  /*! \brief This iterator has been bound to blockIdx.y. */
  kBlockY = 7,
  /*! \brief This iterator has been bound to threadIdx.y. */
  kThreadY = 8,
  /*! \brief This iterator has been bound to blockIdx.z. */
  kBlockZ = 9,
  /*! \brief This iterator has been bound to threadIdx.z. */
  kThreadZ = 10,
  /*! \brief This iterator has been mapped with a tensorize intrinsic. */
  kTensorize = 11
};
106
// Forward declaration: IteratorNode stores Iterator values before the reference class is defined.
class Iterator;

/*!
 * \brief Printable names for IteratorAnnotation values.
 * \note NOTE(review): presumably indexed by static_cast<int>(IteratorAnnotation); the definition
 * lives in the implementation file -- confirm there.
 */
extern const char* IteratorAnnotationString[];
111
112/*!
113 * \brief An iterator of a for-loop
114 * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
115 */
116class IteratorNode : public Object {
117 public:
118 /*! \brief The name of this iterator. */
119 String name;
120 /*! \brief The range of this iterator. */
121 Range range;
122 /*! \brief The iterator type of this iterator. */
123 IteratorKind iter_kind;
124 /*! \brief The annotation type of this iterator. */
125 IteratorAnnotation annotation;
126 /*! The original iterators before fusion. */
127 std::vector<Iterator> orig_iters;
128
129 void VisitAttrs(tvm::AttrVisitor* v) {
130 v->Visit("name", &name);
131 v->Visit("range", &range);
132 v->Visit("iter_kind", &iter_kind);
133 v->Visit("annotation", &annotation);
134 }
135
136 static constexpr const char* _type_key = "auto_scheduler.Iterator";
137 TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
138};
139
140/*!
141 * \brief Managed reference to IteratorNode.
142 * \sa IteratorNode
143 */
144class Iterator : public ObjectRef {
145 public:
146 /*!
147 * \brief The constructor.
148 * \param name The name of this iterator.
149 * \param range The range of this iterator.
150 * \param iter_kind The iterator type of this iterator.
151 * \param annotation The annotation type of this iterator.
152 * \param orig_iters The original iterators before fusion
153 */
154 Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
155 const std::vector<Iterator>* orig_iters = nullptr);
156
157 TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
158};
159
160/*!
161 * \brief The base class of transformation steps. Each step has its corresponding tvm.te
162 * schedule primitives.
163 */
164class StepNode : public Object {
165 public:
166 /*! \brief The index of the stage. */
167 int stage_id;
168
169 /*!
170 * \brief Serialize the current step record to JSONWriter.
171 * \param writer The output JSONWriter.
172 */
173 virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0;
174
175 static constexpr const char* _type_key = "auto_scheduler.Step";
176 TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
177};
178
179/*!
180 * \brief Managed reference to StepNode.
181 * \sa StepNode
182 */
183class Step : public ObjectRef {
184 public:
185 /*!
186 * \brief CopyOnWrite function for Step.
187 * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
188 * steps.
189 * \return A base StepNode pointer, need to cast to its real StepNode type before doing any
190 * modifications.
191 * \code
192 *
193 * SplitStep ref;
194 * StepNode* mutable_ref = ref.CopyOnWrite();
195 * dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...;
196 *
197 * \endcode
198 */
199 StepNode* CopyOnWrite();
200
201 TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
202};
203
204// Forward declaration
205class State;
206class ComputeDAG;
207
208/*!
209 * \brief Read a step record from JSONReader and create the corresponding step.
210 * \param reader The input JSONReader.
211 */
212Step StepReadFromRecord(dmlc::JSONReader* reader);
213
214/*!
215 * \brief Apply a general step to a State with runtime dynamic dispatching.
216 * \param step The step to be applied to State.
217 * \param state A mutable pointer to state, which will be updated.
218 * \param dag The original ComputeDAG of this state.
219 */
220void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
221
222/*!
223 * \brief Apply a general step to tvm.schedule with runtime dynamic dispatching.
224 * \param step The step to be applied to tvm.schedule.
225 * \param stages The list of current stages
226 * \param stage_to_axes A map that maps stage ot all its iterators.
227 * \param schedule A mutable point to the current schedule
228 * \param transform_steps An array of all history transform steps.
229 */
230void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
231 te::Schedule* schedule, const Array<Step>& transform_steps);
232
233/*!
234 * \brief Print a general step as equivalent python schedule API with runtime dynamic dispatching.
235 * \param step The step to be printed as python API.
236 * \param stages The list of current stages
237 * \param stage_to_axes A map that maps stage ot all its iterators.
238 * \param schedule A mutable point to the current schedule
239 * \param transform_steps An array of all history transform steps.
240 * \return Python schedule code.
241 */
242String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
243 StageToAxesMap* stage_to_axes, te::Schedule* schedule,
244 const Array<Step>& transform_steps);
245
246/********** Steps working on single stage **********/
247
248/*!
249 * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
250 * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind)
251 */
252class AnnotationStepNode : public StepNode {
253 public:
254 /*! \brief The index of the iterator to add annotation. */
255 int iter_id;
256 /*! \brief The annotation type of this step. */
257 IteratorAnnotation annotation;
258
259 void WriteToRecord(dmlc::JSONWriter* writer) const final;
260
261 /*!
262 * \brief Apply the current step to State.
263 * \param state A mutable pointer to state, which will be updated.
264 * \return The iterator result after annotate.
265 */
266 Iterator ApplyToState(State* state) const;
267
268 /*!
269 * \brief Apply the current step to tvm.schedule.
270 * \param stages The list of current stages
271 * \param stage_to_axes A map that maps stage ot all its iterators.
272 */
273 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
274
275 /*!
276 * \brief Print the current step as equivalent python schedule API.
277 * \param stages The list of current stages
278 * \param stage_to_axes A map that maps stage ot all its iterators.
279 * \return Python schedule code.
280 */
281 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
282
283 static constexpr const char* record_prefix_str = "AN";
284
285 static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
286 TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode);
287};
288
289/*!
290 * \brief Managed reference to AnnotationStepNode.
291 * \sa AnnotationStepNode
292 */
293class AnnotationStep : public Step {
294 public:
295 /*!
296 * \brief The constructor.
297 * \param stage_id The index of the stage to add annotation.
298 * \param iter_id The index of the iterator to add annotation.
299 * \param ann The annotation type of this step.
300 */
301 AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
302
303 /*!
304 * \brief The constructor used to read a step record from JSONReader and create the
305 * corresponding step.
306 * \param reader The input JSONReader.
307 */
308 explicit AnnotationStep(dmlc::JSONReader* reader);
309
310 TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode);
311};
312
313/*! \brief Fuse step that corresponds to te::Stage::fuse */
314class FuseStepNode : public StepNode {
315 public:
316 /*! \brief The ids of iterators to fuse. */
317 Array<Integer> fused_ids;
318
319 void WriteToRecord(dmlc::JSONWriter* writer) const final;
320
321 /*!
322 * \brief Apply the current step to State.
323 * \param state A mutable pointer to state, which will be updated.
324 * \return The iterator result after fuse.
325 * \note If the iterators to be fused have stages attached at them(by compute_at), the fused
326 * result will become the new attach point.
327 */
328 Iterator ApplyToState(State* state) const;
329
330 /*!
331 * \brief Apply the current step to tvm.schedule.
332 * \param stages The list of current stages
333 * \param stage_to_axes A map that maps stage ot all its iterators.
334 * \return The iterator result after fuse.
335 */
336 tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
337
338 /*!
339 * \brief Print the current step as equivalent python schedule API.
340 * \param stages The list of current stages
341 * \param stage_to_axes A map that maps stage ot all its iterators.
342 * \return Python schedule code.
343 */
344 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
345
346 static constexpr const char* record_prefix_str = "FU";
347
348 static constexpr const char* _type_key = "auto_scheduler.FuseStep";
349 TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode);
350};
351
352/*!
353 * \brief Managed reference to FuseStepNode.
354 * \sa FuseStepNode
355 */
356class FuseStep : public Step {
357 public:
358 /*!
359 * \brief The constructor.
360 * \param stage_id The index of the stage to be fused.
361 * \param fused_ids The index of the iterators to be fused.
362 */
363 FuseStep(int stage_id, const Array<Integer>& fused_ids);
364
365 /*!
366 * \brief The constructor used to read a step record from JSONReader and create the
367 * corresponding step.
368 * \param reader The input JSONReader.
369 */
370 explicit FuseStep(dmlc::JSONReader* reader);
371
372 TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
373};
374
375/*! \brief Pragma step that corresponds to te::Stage::pragma */
376class PragmaStepNode : public StepNode {
377 public:
378 /*! \brief The index of the iterator to add pragma. */
379 int iter_id;
380 /*! \brief The pragma string. */
381 String pragma_type;
382
383 void WriteToRecord(dmlc::JSONWriter* writer) const final;
384
385 /*!
386 * \brief Apply the current step to State.
387 * \param state A mutable pointer to state, which will be updated.
388 */
389 void ApplyToState(State* state) const;
390
391 /*!
392 * \brief Apply the current step to tvm.schedule.
393 * \param stages The list of current stages
394 * \param stage_to_axes A map that maps stage ot all its iterators.
395 */
396 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
397
398 /*!
399 * \brief Print the current step as equivalent python schedule API.
400 * \param stages The list of current stages
401 * \param stage_to_axes A map that maps stage ot all its iterators.
402 * \return Python schedule code.
403 */
404 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
405
406 static constexpr const char* record_prefix_str = "PR";
407
408 static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
409 TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
410};
411
412/*!
413 * \brief Managed reference to PragmaStepNode.
414 * \sa PragmaStepNode
415 */
416class PragmaStep : public Step {
417 public:
418 /*!
419 * \brief The constructor.
420 * \param stage_id The index of the stage to be fused.
421 * \param iter_id The index of the iterator to add pragma.
422 * \param pragma_type The pragma string.
423 */
424 PragmaStep(int stage_id, int iter_id, String pragma_type);
425
426 /*!
427 * \brief The constructor used to read a step record from JSONReader and create the
428 * corresponding step.
429 * \param reader The input JSONReader.
430 */
431 explicit PragmaStep(dmlc::JSONReader* reader);
432
433 TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
434};
435
436/*! \brief Reorder step that corresponds to te::Stage::reorder */
437class ReorderStepNode : public StepNode {
438 public:
439 /*!
440 * \brief The iterator ids after reorder.
441 * This array should specify the order of all iterators.
442 */
443 Array<Integer> after_ids;
444
445 void WriteToRecord(dmlc::JSONWriter* writer) const final;
446
447 /*!
448 * \brief Apply the current step to State.
449 * \param state A mutable pointer to state, which will be updated.
450 */
451 void ApplyToState(State* state) const;
452
453 /*!
454 * \brief Apply the current step to tvm.schedule.
455 * \param stages The list of current stages
456 * \param stage_to_axes A map that maps stage ot all its iterators.
457 */
458 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
459
460 /*!
461 * \brief Print the current step as equivalent python schedule API.
462 * \param stages The list of current stages
463 * \param stage_to_axes A map that maps stage ot all its iterators.
464 * \return Python schedule code.
465 */
466 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
467
468 static constexpr const char* record_prefix_str = "RE";
469
470 static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
471 TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode);
472};
473
474/*!
475 * \brief Managed reference to ReorderStepNode.
476 * \sa ReorderStepNode
477 */
478class ReorderStep : public Step {
479 public:
480 /*!
481 * \brief The constructor.
482 * \param stage_id The index of the stage to be reordered.
483 * \param after_ids The expected indexes of the iterators after reorder.
484 */
485 ReorderStep(int stage_id, const Array<Integer>& after_ids);
486
487 /*!
488 * \brief The constructor used to read a step record from JSONReader and create the
489 * corresponding step.
490 * \param reader The input JSONReader.
491 */
492 explicit ReorderStep(dmlc::JSONReader* reader);
493
494 TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
495};
496
497/*!
498 * \brief Split step that corresponds to te::Stage::split with additional
499 * support of multiple-level of factors
500 */
501class SplitStepNode : public StepNode {
502 public:
503 /*! \brief The id of the iter to split. */
504 int iter_id;
505 /*! \brief The extent length of the axis to split. */
506 Optional<PrimExpr> extent;
507 /*! \brief The split factors. */
508 Array<Optional<Integer>> lengths;
509 /*!
510 * \brief If true, the `lengths` denote the lengths of iterators
511 * from inner level to outer level
512 */
513 bool inner_to_outer;
514
515 void WriteToRecord(dmlc::JSONWriter* writer) const final;
516
517 /*!
518 * \brief Apply the current step to State.
519 * \param state A mutable pointer to state, which will be updated.
520 * \return The iterator results after split.
521 * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
522 * most iterator of split results will become the new attach point.
523 */
524 Array<Iterator> ApplyToState(State* state) const;
525
526 /*!
527 * \brief Apply the current step to tvm.schedule.
528 * \param stages The list of current stages
529 * \param stage_to_axes A map that maps stage ot all its iterators.
530 * \return The iterator results after split.
531 */
532 Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
533 StageToAxesMap* stage_to_axes) const;
534
535 /*!
536 * \brief Print the current step as equivalent python schedule API.
537 * \param stages The list of current stages
538 * \param stage_to_axes A map that maps stage ot all its iterators.
539 * \return Python schedule code.
540 */
541 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
542
543 static constexpr const char* record_prefix_str = "SP";
544
545 static constexpr const char* _type_key = "auto_scheduler.SplitStep";
546 TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode);
547};
548
549/*!
550 * \brief Managed reference to SplitStepNode.
551 * \sa SplitStepNode
552 */
553class SplitStep : public Step {
554 public:
555 /*!
556 * \brief The constructor.
557 * \param stage_id The index of the stage to be split.
558 * \param iter_id The index of the iterator to be split.
559 * \param extent The extent length of the axis to split.
560 * \param lengths The multiple split factors. Can be None to be filled by search policy.
561 * \param inner_to_outer The split direction.
562 */
563 SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
564 const Array<Optional<Integer>>& lengths, bool inner_to_outer);
565
566 /*!
567 * \brief The constructor used to read a step record from JSONReader and create the
568 * corresponding step.
569 * \param reader The input JSONReader.
570 */
571 explicit SplitStep(dmlc::JSONReader* reader);
572
573 TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
574};
575
576/*! \brief Similar to SplitStepNode, but uses split factors from another step
577 * (i.e. Follow another split step) */
578class FollowSplitStepNode : public StepNode {
579 public:
580 /*! \brief The id of the iter to be split. */
581 int iter_id;
582 /*! \brief The index of the split step to be followed in the history. */
583 int src_step_id;
584 /*! \brief The number of split level. */
585 int n_split;
586
587 void WriteToRecord(dmlc::JSONWriter* writer) const final;
588
589 /*!
590 * \brief Extract split lengths.
591 * \param transform_steps An array of history transform steps.
592 * \return The multiple split factors.
593 */
594 Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
595
596 /*!
597 * \brief Apply the current step to State.
598 * \param state A mutable pointer to state, which will be updated.
599 * \return The iterator results after split.
600 */
601 Array<Iterator> ApplyToState(State* state) const;
602
603 /*!
604 * \brief Apply the current step to tvm.schedule.
605 * \param stages The list of current stages
606 * \param stage_to_axes A map that maps stage ot all its iterators.
607 * \param transform_steps An array of history transform steps.
608 * \return The iterator results after split.
609 */
610 Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
611 const Array<Step>& transform_steps) const;
612
613 /*!
614 * \brief Print the current step as equivalent python schedule API.
615 * \param stages The list of current stages
616 * \param stage_to_axes A map that maps stage ot all its iterators.
617 * \param transform_steps An array of history transform steps.
618 * \return Python schedule code.
619 */
620 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
621 const Array<Step>& transform_steps) const;
622
623 static constexpr const char* record_prefix_str = "FSP";
624
625 static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
626 TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode);
627};
628
629/*!
630 * \brief Managed reference to FollowSplitStepNode.
631 * \sa FollowSplitStepNode
632 */
633class FollowSplitStep : public Step {
634 public:
635 /*!
636 * \brief The constructor.
637 * \param stage_id The index of the stage to be split.
638 * \param iter_id The index of the iterator to be split.
639 * \param src_step_id The index of the split step to be followed in the history.
640 * \param n_split The number of split level.
641 */
642 FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
643
644 /*!
645 * \brief The constructor used to read a step record from JSONReader and create the
646 * corresponding step.
647 * \param reader The input JSONReader.
648 */
649 explicit FollowSplitStep(dmlc::JSONReader* reader);
650
651 TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
652};
653
654/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
655 * \note This can be used for the split in cooperative fetching.
656 */
657class FollowFusedSplitStepNode : public StepNode {
658 public:
659 /*! \brief The id of the iter to split. */
660 int iter_id;
661 /*! \brief The indices of the split steps to be followed in the history. */
662 Array<Integer> src_step_ids;
663 /*! \brief Use the length in this split level. */
664 int level;
665 /*! \brief If this is true, use factor. Otherwise, use nparts. */
666 bool factor_or_nparts;
667
668 void WriteToRecord(dmlc::JSONWriter* writer) const final;
669
670 /*!
671 * \brief Extract split length.
672 * \param transform_steps An array of history transform steps.
673 * \return Split factor.
674 */
675 Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
676
677 /*!
678 * \brief Apply the current step to State.
679 * \param state A mutable pointer to state, which will be updated.
680 * \return The iterator results after split.
681 */
682 Array<Iterator> ApplyToState(State* state) const;
683
684 /*!
685 * \brief Apply the current step to tvm.schedule.
686 * \param stages The list of current stages
687 * \param stage_to_axes A map that maps stage ot all its iterators.
688 * \param transform_steps An array of history transform steps.
689 * \return The iterator results after split.
690 */
691 Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
692 const Array<Step>& transform_steps) const;
693
694 /*!
695 * \brief Print the current step as equivalent python schedule API.
696 * \param stages The list of current stages
697 * \param stage_to_axes A map that maps stage ot all its iterators.
698 * \param transform_steps An array of history transform steps.
699 * \return Python schedule code.
700 */
701 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
702 const Array<Step>& transform_steps) const;
703
704 static constexpr const char* record_prefix_str = "FFSP";
705
706 static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
707 TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode);
708};
709
710/*!
711 * \brief Managed reference to FollowFusedSplitStepNode.
712 * \sa FollowFusedSplitStepNode
713 */
714class FollowFusedSplitStep : public Step {
715 public:
716 /*!
717 * \brief The constructor.
718 * \param stage_id The index of the stage to be split.
719 * \param iter_id The index of the iterator to be split.
720 * \param src_step_ids An array of index for split step to be followed in the history.
721 * \param level Use the length in this split level.
722 * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
723 */
724 FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
725 bool factor_or_nparts);
726
727 /*!
728 * \brief The constructor used to read a step record from JSONReader and create the
729 * corresponding step.
730 * \param reader The input JSONReader.
731 */
732 explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
733
734 TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
735};
736
737/*! \brief Storage align step that corresponds to te::Stage::storage_align */
738class StorageAlignStepNode : public StepNode {
739 public:
740 /*! \brief The iterator to be aligned. */
741 int iter_id;
742 /*! \brief The factor in alignment specification. */
743 int factor;
744 /*! \brief The offset in the alignment specification. */
745 int offset;
746
747 void WriteToRecord(dmlc::JSONWriter* writer) const final;
748
749 /*!
750 * \brief Apply the current step to State.
751 * \param state A mutable pointer to State, which will be updated.
752 */
753 void ApplyToState(State* state) const;
754
755 /*!
756 * \brief Apply the current step to tvm.schedule.
757 * \param stages The list of current stages
758 * \param stage_to_axes A map that maps stage ot all its iterators.
759 */
760 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
761
762 /*!
763 * \brief Print the current step as equivalent python schedule API.
764 * \param stages The list of current stages
765 * \param stage_to_axes A map that maps stage ot all its iterators.
766 * \return Python schedule code.
767 */
768 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
769
770 static constexpr const char* record_prefix_str = "SA";
771
772 static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
773 TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
774};
775
776/*!
777 * \brief Managed reference to StorageAlignStepNode.
778 * \sa StorageAlignStepNode
779 */
780class StorageAlignStep : public Step {
781 public:
782 /*!
783 * \brief The constructor.
784 * \param stage_id The index of the stage to be aligned.
785 * \param iter_id The index of the iterator to be aligned.
786 * \param factor The factor in alignment specification.
787 * \param offset The offset in the alignment specification.
788 */
789 StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
790
791 /*!
792 * \brief The constructor used to read a step record from JSONReader and create the
793 * corresponding step.
794 * \param reader The input JSONReader.
795 */
796 explicit StorageAlignStep(dmlc::JSONReader* reader);
797
798 TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
799};
800
801/********** Steps working on multiple stages **********/
802
803/*! \brief Compute at step that corresponds to te::Stage::compute_at */
804class ComputeAtStepNode : public StepNode {
805 public:
806 /*! \brief The index of stage that this step will compute at to. */
807 int target_stage_id;
808 /*! \brief The index of iterator in target stage that this step will compute at to. */
809 int target_iter_id;
810
811 void WriteToRecord(dmlc::JSONWriter* writer) const final;
812
813 /*!
814 * \brief Apply the current step to State.
815 * \param state A mutable pointer to state, which will be updated.
816 * \note After compute_at, we need careful dependency analysis to compute the accurate bound
817 * information. However, it is relatively expensive and complicated, so we just fill "None" as
818 * bound for the newly created iterators.
819 * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
820 */
821 void ApplyToState(State* state) const;
822
823 /*!
824 * \brief Apply the current step to tvm.schedule.
825 * \param stages The list of current stages
826 * \param stage_to_axes A map that maps stage ot all its iterators.
827 */
828 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
829
830 /*!
831 * \brief Print the current step as equivalent python schedule API.
832 * \param stages The list of current stages
833 * \param stage_to_axes A map that maps stage ot all its iterators.
834 * \return Python schedule code.
835 */
836 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
837
838 static constexpr const char* record_prefix_str = "CA";
839
840 static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
841 TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode);
842};
843
844/*!
845 * \brief Managed reference to ComputeAtStepNode.
846 * \sa ComputeAtStepNode
847 */
848class ComputeAtStep : public Step {
849 public:
850 /*!
851 * \brief The constructor.
852 * \param stage_id The index of the source stage.
853 * \param target_stage_id The index of stage that this step will compute at to.
854 * \param target_iter_id The index of iterator in target stage that this step will compute at to.
855 */
856 ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
857
858 /*!
859 * \brief The constructor used to read a step record from JSONReader and create the
860 * corresponding step.
861 * \param reader The input JSONReader.
862 */
863 explicit ComputeAtStep(dmlc::JSONReader* reader);
864
865 TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
866};
867
868/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
869class ComputeInlineStepNode : public StepNode {
870 public:
871 void WriteToRecord(dmlc::JSONWriter* writer) const final;
872
873 /*!
874 * \brief Apply the current step to State.
875 * \param state A mutable pointer to state, which will be updated.
876 */
877 void ApplyToState(State* state) const;
878
879 /*!
880 * \brief Apply the current step to tvm.schedule.
881 * \param stages The list of current stages
882 * \param stage_to_axes A map that maps stage ot all its iterators.
883 * \return The iterator result after fuse.
884 */
885 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
886
887 /*!
888 * \brief Print the current step as equivalent python schedule API.
889 * \param stages The list of current stages
890 * \param stage_to_axes A map that maps stage ot all its iterators.
891 * \return Python schedule code.
892 */
893 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
894
895 static constexpr const char* record_prefix_str = "CI";
896
897 static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
898 TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode);
899};
900
/*!
 * \brief Managed reference to ComputeInlineStepNode.
 * \sa ComputeInlineStepNode
 */
class ComputeInlineStep : public Step {
 public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage to be marked as compute inline.
   */
  explicit ComputeInlineStep(int stage_id);

  /*!
   * \brief The constructor used to read a step record from JSONReader and create the
   * corresponding step.
   * \param reader The input JSONReader.
   */
  explicit ComputeInlineStep(dmlc::JSONReader* reader);

  TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
};
922
923/*! \brief Compute root step that corresponds to te::Stage::compute_root */
924class ComputeRootStepNode : public StepNode {
925 public:
926 void WriteToRecord(dmlc::JSONWriter* writer) const final;
927
928 /*!
929 * \brief Apply the current step to State.
930 * \param state A mutable pointer to state, which will be updated.
931 * \note After compute_root, we need careful dependency analysis to compute the accurate bound
932 * information. However, it is relatively expensive and complicated, so we just fill "None" as
933 * bound for the newly created iterators.
934 * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
935 */
936 void ApplyToState(State* state) const;
937
938 /*!
939 * \brief Apply the current step to tvm.schedule.
940 * \param stages The list of current stages
941 * \param stage_to_axes A map that maps stage ot all its iterators.
942 * \return The iterator result after fuse.
943 */
944 void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
945
946 /*!
947 * \brief Print the current step as equivalent python schedule API.
948 * \param stages The list of current stages
949 * \param stage_to_axes A map that maps stage ot all its iterators.
950 * \return Python schedule code.
951 */
952 String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
953
954 static constexpr const char* record_prefix_str = "CR";
955
956 static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
957 TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode);
958};
959
/*!
 * \brief Managed reference to ComputeRootStepNode.
 * \sa ComputeRootStepNode
 */
class ComputeRootStep : public Step {
 public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage to be marked as compute at root.
   */
  explicit ComputeRootStep(int stage_id);

  /*!
   * \brief The constructor used to read a step record from JSONReader and create the
   * corresponding step.
   * \param reader The input JSONReader.
   */
  explicit ComputeRootStep(dmlc::JSONReader* reader);

  TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};
981
982/********** Steps adding new stages **********/
983
/*!
 * \brief Cache read step that corresponds to te::Schedule::cache_read.
 * \note The cache read step adds an extra stage to the original ComputeDAG; an up-to-date
 * ComputeDAG will be stored in State's `current_compute_dag`.
 */
class CacheReadStepNode : public StepNode {
 public:
  /*! \brief The scope name of the newly added read stage. (e.g., local, shared, global) */
  String scope_name;
  /*! \brief The indices of the reader stages. */
  Array<Integer> reader_stage_ids;

  /*!
   * \brief Output this step as a log record.
   * \param writer A mutable pointer to the JSONWriter that the record is written to.
   */
  void WriteToRecord(dmlc::JSONWriter* writer) const final;

  /*!
   * \brief Apply the current step to State.
   * \param state A mutable pointer to state, which will be updated.
   * \param dag The original ComputeDAG of this state.
   * \return The index of the newly added stage.
   */
  int ApplyToState(State* state, const ComputeDAG& dag) const;

  /*!
   * \brief Apply the current step to tvm.schedule.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return The output Tensor of the newly added stage.
   */
  te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                             te::Schedule* schedule) const;

  /*!
   * \brief Print the current step as equivalent python schedule API.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return Python schedule code.
   */
  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                          te::Schedule* schedule) const;

  static constexpr const char* record_prefix_str = "CHR";

  static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode);
};
1031
/*!
 * \brief Managed reference to CacheReadStepNode.
 * \sa CacheReadStepNode
 */
class CacheReadStep : public Step {
 public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage to be cache read.
   * \param scope_name The scope name of the newly added stage.
   * \param reader_stage_ids The indices of the reader stages.
   */
  CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);

  /*!
   * \brief The constructor used to read a step record from JSONReader and create the
   * corresponding step.
   * \param reader The input JSONReader.
   */
  explicit CacheReadStep(dmlc::JSONReader* reader);

  TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
};
1055
/*!
 * \brief Cache write step that corresponds to te::Schedule::cache_write.
 * \note The cache write step adds an extra stage to the original ComputeDAG; an up-to-date
 * ComputeDAG will be stored in State's `current_compute_dag`.
 * This step will cache write all output tensors of the target stage.
 */
class CacheWriteStepNode : public StepNode {
 public:
  /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */
  String scope_name;

  /*!
   * \brief Output this step as a log record.
   * \param writer A mutable pointer to the JSONWriter that the record is written to.
   */
  void WriteToRecord(dmlc::JSONWriter* writer) const final;

  /*!
   * \brief Apply the current step to State.
   * \param state A mutable pointer to state, which will be updated.
   * \param dag The original ComputeDAG of this state.
   * \return The index of the newly added stage.
   */
  int ApplyToState(State* state, const ComputeDAG& dag) const;

  /*!
   * \brief Apply the current step to tvm.schedule.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return The output Tensors of the newly added stage.
   */
  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                                    te::Schedule* schedule) const;

  /*!
   * \brief Print the current step as equivalent python schedule API.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return Python schedule code.
   */
  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                          te::Schedule* schedule) const;

  static constexpr const char* record_prefix_str = "CHW";

  static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode);
};
1102
/*!
 * \brief Managed reference to CacheWriteStepNode.
 * \sa CacheWriteStepNode
 */
class CacheWriteStep : public Step {
 public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage to be cache written.
   * \param scope_name The scope name of the newly added stage.
   */
  CacheWriteStep(int stage_id, String scope_name);

  /*!
   * \brief The constructor used to read a step record from JSONReader and create the
   * corresponding step.
   * \param reader The input JSONReader.
   */
  explicit CacheWriteStep(dmlc::JSONReader* reader);

  TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
};
1125
/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
class RfactorStepNode : public StepNode {
 public:
  /*! \brief The index of the iterator to be factored. */
  int iter_id;
  /*! \brief The position where the new iterator is placed. */
  int factor_iter_id;

  /*!
   * \brief Output this step as a log record.
   * \param writer A mutable pointer to the JSONWriter that the record is written to.
   */
  void WriteToRecord(dmlc::JSONWriter* writer) const final;

  /*!
   * \brief Apply the current step to State.
   * \param state A mutable pointer to state, which will be updated.
   * \param dag The original ComputeDAG of this state.
   * \return The index of the newly added stage.
   */
  int ApplyToState(State* state, const ComputeDAG& dag) const;

  /*!
   * \brief Apply the current step to tvm.schedule.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return The output Tensors of the newly added stage.
   */
  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                                    te::Schedule* schedule) const;

  /*!
   * \brief Print the current step as equivalent python schedule API.
   * \param stages The list of current stages
   * \param stage_to_axes A map that maps stage to all its iterators.
   * \param schedule A mutable pointer to a te::Schedule.
   * \return Python schedule code.
   */
  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                          te::Schedule* schedule) const;

  static constexpr const char* record_prefix_str = "RF";

  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
};
1169
/*!
 * \brief Managed reference to RfactorStepNode.
 * \sa RfactorStepNode
 */
class RfactorStep : public Step {
 public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage to be factored.
   * \param iter_id The index of the iterator to be factored.
   * \param factor_iter_id The position where the new iterator is placed.
   */
  RfactorStep(int stage_id, int iter_id, int factor_iter_id);

  /*!
   * \brief The constructor used to read a step record from JSONReader and create the
   * corresponding step.
   * \param reader The input JSONReader.
   */
  explicit RfactorStep(dmlc::JSONReader* reader);

  TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
};
1193
1194} // namespace auto_scheduler
1195} // namespace tvm
1196
1197#endif // TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
1198