1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/transform_step.h |
22 | * \brief Transformation steps. These steps are used to manipulate `LoopState`. |
23 | * They are similar to the schedule primitives in te::Stage. |
24 | * |
25 | * \note How to add a new transform step: |
 * Take the fuse step as an example:
 * 1. Define classes `FuseStepNode` and `FuseStep` in `transform_step.h`, and implement its first
 *    constructor `FuseStep::FuseStep()` in `transform_step.cc`.
 * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`.
 *    - In these two functions you need to lower this step with TVM's te schedule API.
 * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
 *    - In these two functions you need to incrementally update all data structures in State in
 *      CopyOnWrite style.
34 | * 4. Add your step to `StepApplyToState`, `StepApplyToSchedule`, and `StepPrintAsPythonAPI`. |
35 | * 5. Log record serialization support: |
 *    - Add `FuseStepNode::WriteToRecord`, which takes a mutable JSONWriter pointer as input and
 *      outputs the record to it.
 *    - Add another constructor that takes a mutable JSONReader as input; it reads a step record
 *      from the reader and creates the step.
40 | * - Add the step implementation to `StepReadFromRecord`. |
 * 6. Add its corresponding Python API to `loop_state.py` with necessary unit tests. The tests should
 *    at least cover two parts: the functional test and the record serialization test.
43 | */ |
44 | |
45 | #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ |
46 | #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ |
47 | |
48 | #include <dmlc/common.h> |
49 | #include <dmlc/json.h> |
50 | #include <tvm/node/node.h> |
51 | #include <tvm/te/schedule.h> |
52 | |
53 | #include <vector> |
54 | |
55 | namespace tvm { |
56 | namespace auto_scheduler { |
57 | |
58 | typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap; |
59 | |
60 | /*! |
61 | * \brief Update the current stage IterVar information to StageToAxesMap. |
62 | * \param stage The stage to be updated. |
63 | * \param stage_to_axes The map to be updated. |
64 | */ |
65 | void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes); |
66 | |
/*! \brief The kind of loop an iterator represents. */
enum class IteratorKind : int {
  kSpatial = 0,    ///< Spatial iterator.
  kReduction = 1,  ///< Reduction iterator.
  kMixed = 2,      ///< Iterator produced by fusing spatial and reduction iterators.
  kSpecial = 3     ///< Special iterator (e.g. the virtual root iterator).
};
78 | |
/*! \brief The type of an iterator's annotation. */
enum class IteratorAnnotation : int {
  /*! \brief This iterator has no annotation. */
  kNone = 0,
  /*! \brief This iterator has been unrolled. */
  kUnroll = 1,
  /*! \brief This iterator has been vectorized. */
  kVectorize = 2,
  /*! \brief This iterator has been parallelized. */
  kParallel = 3,
  /*! \brief This iterator has been bound to vthread. */
  kVThread = 4,
  /*! \brief This iterator has been bound to blockIdx.x. */
  kBlockX = 5,
  /*! \brief This iterator has been bound to threadIdx.x. */
  kThreadX = 6,
  /*! \brief This iterator has been bound to blockIdx.y. */
  kBlockY = 7,
  /*! \brief This iterator has been bound to threadIdx.y. */
  kThreadY = 8,
  /*! \brief This iterator has been bound to blockIdx.z. */
  kBlockZ = 9,
  /*! \brief This iterator has been bound to threadIdx.z. */
  kThreadZ = 10,
  /*! \brief This iterator has been mapped with a tensorize intrinsic. */
  kTensorize = 11
};
106 | |
/*!
 * \brief String representations of IteratorAnnotation values.
 * \note NOTE(review): presumably indexed by static_cast<int>(IteratorAnnotation); the definition is
 *       not in this header — confirm against it.
 */
extern const char* IteratorAnnotationString[];

// Forward declaration: IteratorNode stores a std::vector<Iterator> before Iterator is defined.
class Iterator;
111 | |
112 | /*! |
113 | * \brief An iterator of a for-loop |
114 | * Similar to tvm::IterVar in `include/tvm/tir/expr.h` |
115 | */ |
116 | class IteratorNode : public Object { |
117 | public: |
118 | /*! \brief The name of this iterator. */ |
119 | String name; |
120 | /*! \brief The range of this iterator. */ |
121 | Range range; |
122 | /*! \brief The iterator type of this iterator. */ |
123 | IteratorKind iter_kind; |
124 | /*! \brief The annotation type of this iterator. */ |
125 | IteratorAnnotation annotation; |
126 | /*! The original iterators before fusion. */ |
127 | std::vector<Iterator> orig_iters; |
128 | |
129 | void VisitAttrs(tvm::AttrVisitor* v) { |
130 | v->Visit("name" , &name); |
131 | v->Visit("range" , &range); |
132 | v->Visit("iter_kind" , &iter_kind); |
133 | v->Visit("annotation" , &annotation); |
134 | } |
135 | |
136 | static constexpr const char* _type_key = "auto_scheduler.Iterator" ; |
137 | TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); |
138 | }; |
139 | |
140 | /*! |
141 | * \brief Managed reference to IteratorNode. |
142 | * \sa IteratorNode |
143 | */ |
144 | class Iterator : public ObjectRef { |
145 | public: |
146 | /*! |
147 | * \brief The constructor. |
148 | * \param name The name of this iterator. |
149 | * \param range The range of this iterator. |
150 | * \param iter_kind The iterator type of this iterator. |
151 | * \param annotation The annotation type of this iterator. |
152 | * \param orig_iters The original iterators before fusion |
153 | */ |
154 | Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation, |
155 | const std::vector<Iterator>* orig_iters = nullptr); |
156 | |
157 | TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); |
158 | }; |
159 | |
160 | /*! |
161 | * \brief The base class of transformation steps. Each step has its corresponding tvm.te |
162 | * schedule primitives. |
163 | */ |
164 | class StepNode : public Object { |
165 | public: |
166 | /*! \brief The index of the stage. */ |
167 | int stage_id; |
168 | |
169 | /*! |
170 | * \brief Serialize the current step record to JSONWriter. |
171 | * \param writer The output JSONWriter. |
172 | */ |
173 | virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0; |
174 | |
175 | static constexpr const char* _type_key = "auto_scheduler.Step" ; |
176 | TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); |
177 | }; |
178 | |
179 | /*! |
180 | * \brief Managed reference to StepNode. |
181 | * \sa StepNode |
182 | */ |
183 | class Step : public ObjectRef { |
184 | public: |
185 | /*! |
186 | * \brief CopyOnWrite function for Step. |
187 | * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different |
188 | * steps. |
189 | * \return A base StepNode pointer, need to cast to its real StepNode type before doing any |
190 | * modifications. |
191 | * \code |
192 | * |
193 | * SplitStep ref; |
194 | * StepNode* mutable_ref = ref.CopyOnWrite(); |
195 | * dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...; |
196 | * |
197 | * \endcode |
198 | */ |
199 | StepNode* CopyOnWrite(); |
200 | |
201 | TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); |
202 | }; |
203 | |
204 | // Forward declaration |
205 | class State; |
206 | class ComputeDAG; |
207 | |
208 | /*! |
209 | * \brief Read a step record from JSONReader and create the corresponding step. |
210 | * \param reader The input JSONReader. |
211 | */ |
212 | Step StepReadFromRecord(dmlc::JSONReader* reader); |
213 | |
214 | /*! |
215 | * \brief Apply a general step to a State with runtime dynamic dispatching. |
216 | * \param step The step to be applied to State. |
217 | * \param state A mutable pointer to state, which will be updated. |
218 | * \param dag The original ComputeDAG of this state. |
219 | */ |
220 | void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); |
221 | |
222 | /*! |
223 | * \brief Apply a general step to tvm.schedule with runtime dynamic dispatching. |
224 | * \param step The step to be applied to tvm.schedule. |
225 | * \param stages The list of current stages |
226 | * \param stage_to_axes A map that maps stage ot all its iterators. |
227 | * \param schedule A mutable point to the current schedule |
228 | * \param transform_steps An array of all history transform steps. |
229 | */ |
230 | void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
231 | te::Schedule* schedule, const Array<Step>& transform_steps); |
232 | |
233 | /*! |
234 | * \brief Print a general step as equivalent python schedule API with runtime dynamic dispatching. |
235 | * \param step The step to be printed as python API. |
236 | * \param stages The list of current stages |
237 | * \param stage_to_axes A map that maps stage ot all its iterators. |
238 | * \param schedule A mutable point to the current schedule |
239 | * \param transform_steps An array of all history transform steps. |
240 | * \return Python schedule code. |
241 | */ |
242 | String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages, |
243 | StageToAxesMap* stage_to_axes, te::Schedule* schedule, |
244 | const Array<Step>& transform_steps); |
245 | |
246 | /********** Steps working on single stage **********/ |
247 | |
248 | /*! |
249 | * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. |
250 | * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) |
251 | */ |
252 | class AnnotationStepNode : public StepNode { |
253 | public: |
254 | /*! \brief The index of the iterator to add annotation. */ |
255 | int iter_id; |
256 | /*! \brief The annotation type of this step. */ |
257 | IteratorAnnotation annotation; |
258 | |
259 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
260 | |
261 | /*! |
262 | * \brief Apply the current step to State. |
263 | * \param state A mutable pointer to state, which will be updated. |
264 | * \return The iterator result after annotate. |
265 | */ |
266 | Iterator ApplyToState(State* state) const; |
267 | |
268 | /*! |
269 | * \brief Apply the current step to tvm.schedule. |
270 | * \param stages The list of current stages |
271 | * \param stage_to_axes A map that maps stage ot all its iterators. |
272 | */ |
273 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
274 | |
275 | /*! |
276 | * \brief Print the current step as equivalent python schedule API. |
277 | * \param stages The list of current stages |
278 | * \param stage_to_axes A map that maps stage ot all its iterators. |
279 | * \return Python schedule code. |
280 | */ |
281 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
282 | |
283 | static constexpr const char* record_prefix_str = "AN" ; |
284 | |
285 | static constexpr const char* _type_key = "auto_scheduler.AnnotationStep" ; |
286 | TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode); |
287 | }; |
288 | |
289 | /*! |
290 | * \brief Managed reference to AnnotationStepNode. |
291 | * \sa AnnotationStepNode |
292 | */ |
293 | class AnnotationStep : public Step { |
294 | public: |
295 | /*! |
296 | * \brief The constructor. |
297 | * \param stage_id The index of the stage to add annotation. |
298 | * \param iter_id The index of the iterator to add annotation. |
299 | * \param ann The annotation type of this step. |
300 | */ |
301 | AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); |
302 | |
303 | /*! |
304 | * \brief The constructor used to read a step record from JSONReader and create the |
305 | * corresponding step. |
306 | * \param reader The input JSONReader. |
307 | */ |
308 | explicit AnnotationStep(dmlc::JSONReader* reader); |
309 | |
310 | TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); |
311 | }; |
312 | |
313 | /*! \brief Fuse step that corresponds to te::Stage::fuse */ |
314 | class FuseStepNode : public StepNode { |
315 | public: |
316 | /*! \brief The ids of iterators to fuse. */ |
317 | Array<Integer> fused_ids; |
318 | |
319 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
320 | |
321 | /*! |
322 | * \brief Apply the current step to State. |
323 | * \param state A mutable pointer to state, which will be updated. |
324 | * \return The iterator result after fuse. |
325 | * \note If the iterators to be fused have stages attached at them(by compute_at), the fused |
326 | * result will become the new attach point. |
327 | */ |
328 | Iterator ApplyToState(State* state) const; |
329 | |
330 | /*! |
331 | * \brief Apply the current step to tvm.schedule. |
332 | * \param stages The list of current stages |
333 | * \param stage_to_axes A map that maps stage ot all its iterators. |
334 | * \return The iterator result after fuse. |
335 | */ |
336 | tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
337 | |
338 | /*! |
339 | * \brief Print the current step as equivalent python schedule API. |
340 | * \param stages The list of current stages |
341 | * \param stage_to_axes A map that maps stage ot all its iterators. |
342 | * \return Python schedule code. |
343 | */ |
344 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
345 | |
346 | static constexpr const char* record_prefix_str = "FU" ; |
347 | |
348 | static constexpr const char* _type_key = "auto_scheduler.FuseStep" ; |
349 | TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode); |
350 | }; |
351 | |
352 | /*! |
353 | * \brief Managed reference to FuseStepNode. |
354 | * \sa FuseStepNode |
355 | */ |
356 | class FuseStep : public Step { |
357 | public: |
358 | /*! |
359 | * \brief The constructor. |
360 | * \param stage_id The index of the stage to be fused. |
361 | * \param fused_ids The index of the iterators to be fused. |
362 | */ |
363 | FuseStep(int stage_id, const Array<Integer>& fused_ids); |
364 | |
365 | /*! |
366 | * \brief The constructor used to read a step record from JSONReader and create the |
367 | * corresponding step. |
368 | * \param reader The input JSONReader. |
369 | */ |
370 | explicit FuseStep(dmlc::JSONReader* reader); |
371 | |
372 | TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); |
373 | }; |
374 | |
375 | /*! \brief Pragma step that corresponds to te::Stage::pragma */ |
376 | class PragmaStepNode : public StepNode { |
377 | public: |
378 | /*! \brief The index of the iterator to add pragma. */ |
379 | int iter_id; |
380 | /*! \brief The pragma string. */ |
381 | String pragma_type; |
382 | |
383 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
384 | |
385 | /*! |
386 | * \brief Apply the current step to State. |
387 | * \param state A mutable pointer to state, which will be updated. |
388 | */ |
389 | void ApplyToState(State* state) const; |
390 | |
391 | /*! |
392 | * \brief Apply the current step to tvm.schedule. |
393 | * \param stages The list of current stages |
394 | * \param stage_to_axes A map that maps stage ot all its iterators. |
395 | */ |
396 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
397 | |
398 | /*! |
399 | * \brief Print the current step as equivalent python schedule API. |
400 | * \param stages The list of current stages |
401 | * \param stage_to_axes A map that maps stage ot all its iterators. |
402 | * \return Python schedule code. |
403 | */ |
404 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
405 | |
406 | static constexpr const char* record_prefix_str = "PR" ; |
407 | |
408 | static constexpr const char* _type_key = "auto_scheduler.PragmaStep" ; |
409 | TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode); |
410 | }; |
411 | |
412 | /*! |
413 | * \brief Managed reference to PragmaStepNode. |
414 | * \sa PragmaStepNode |
415 | */ |
416 | class PragmaStep : public Step { |
417 | public: |
418 | /*! |
419 | * \brief The constructor. |
420 | * \param stage_id The index of the stage to be fused. |
421 | * \param iter_id The index of the iterator to add pragma. |
422 | * \param pragma_type The pragma string. |
423 | */ |
424 | PragmaStep(int stage_id, int iter_id, String pragma_type); |
425 | |
426 | /*! |
427 | * \brief The constructor used to read a step record from JSONReader and create the |
428 | * corresponding step. |
429 | * \param reader The input JSONReader. |
430 | */ |
431 | explicit PragmaStep(dmlc::JSONReader* reader); |
432 | |
433 | TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); |
434 | }; |
435 | |
436 | /*! \brief Reorder step that corresponds to te::Stage::reorder */ |
437 | class ReorderStepNode : public StepNode { |
438 | public: |
439 | /*! |
440 | * \brief The iterator ids after reorder. |
441 | * This array should specify the order of all iterators. |
442 | */ |
443 | Array<Integer> after_ids; |
444 | |
445 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
446 | |
447 | /*! |
448 | * \brief Apply the current step to State. |
449 | * \param state A mutable pointer to state, which will be updated. |
450 | */ |
451 | void ApplyToState(State* state) const; |
452 | |
453 | /*! |
454 | * \brief Apply the current step to tvm.schedule. |
455 | * \param stages The list of current stages |
456 | * \param stage_to_axes A map that maps stage ot all its iterators. |
457 | */ |
458 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
459 | |
460 | /*! |
461 | * \brief Print the current step as equivalent python schedule API. |
462 | * \param stages The list of current stages |
463 | * \param stage_to_axes A map that maps stage ot all its iterators. |
464 | * \return Python schedule code. |
465 | */ |
466 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
467 | |
468 | static constexpr const char* record_prefix_str = "RE" ; |
469 | |
470 | static constexpr const char* _type_key = "auto_scheduler.ReorderStep" ; |
471 | TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode); |
472 | }; |
473 | |
474 | /*! |
475 | * \brief Managed reference to ReorderStepNode. |
476 | * \sa ReorderStepNode |
477 | */ |
478 | class ReorderStep : public Step { |
479 | public: |
480 | /*! |
481 | * \brief The constructor. |
482 | * \param stage_id The index of the stage to be reordered. |
483 | * \param after_ids The expected indexes of the iterators after reorder. |
484 | */ |
485 | ReorderStep(int stage_id, const Array<Integer>& after_ids); |
486 | |
487 | /*! |
488 | * \brief The constructor used to read a step record from JSONReader and create the |
489 | * corresponding step. |
490 | * \param reader The input JSONReader. |
491 | */ |
492 | explicit ReorderStep(dmlc::JSONReader* reader); |
493 | |
494 | TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); |
495 | }; |
496 | |
497 | /*! |
498 | * \brief Split step that corresponds to te::Stage::split with additional |
499 | * support of multiple-level of factors |
500 | */ |
501 | class SplitStepNode : public StepNode { |
502 | public: |
503 | /*! \brief The id of the iter to split. */ |
504 | int iter_id; |
505 | /*! \brief The extent length of the axis to split. */ |
506 | Optional<PrimExpr> extent; |
507 | /*! \brief The split factors. */ |
508 | Array<Optional<Integer>> lengths; |
509 | /*! |
510 | * \brief If true, the `lengths` denote the lengths of iterators |
511 | * from inner level to outer level |
512 | */ |
513 | bool inner_to_outer; |
514 | |
515 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
516 | |
517 | /*! |
518 | * \brief Apply the current step to State. |
519 | * \param state A mutable pointer to state, which will be updated. |
520 | * \return The iterator results after split. |
521 | * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner |
522 | * most iterator of split results will become the new attach point. |
523 | */ |
524 | Array<Iterator> ApplyToState(State* state) const; |
525 | |
526 | /*! |
527 | * \brief Apply the current step to tvm.schedule. |
528 | * \param stages The list of current stages |
529 | * \param stage_to_axes A map that maps stage ot all its iterators. |
530 | * \return The iterator results after split. |
531 | */ |
532 | Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, |
533 | StageToAxesMap* stage_to_axes) const; |
534 | |
535 | /*! |
536 | * \brief Print the current step as equivalent python schedule API. |
537 | * \param stages The list of current stages |
538 | * \param stage_to_axes A map that maps stage ot all its iterators. |
539 | * \return Python schedule code. |
540 | */ |
541 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
542 | |
543 | static constexpr const char* record_prefix_str = "SP" ; |
544 | |
545 | static constexpr const char* _type_key = "auto_scheduler.SplitStep" ; |
546 | TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode); |
547 | }; |
548 | |
549 | /*! |
550 | * \brief Managed reference to SplitStepNode. |
551 | * \sa SplitStepNode |
552 | */ |
553 | class SplitStep : public Step { |
554 | public: |
555 | /*! |
556 | * \brief The constructor. |
557 | * \param stage_id The index of the stage to be split. |
558 | * \param iter_id The index of the iterator to be split. |
559 | * \param extent The extent length of the axis to split. |
560 | * \param lengths The multiple split factors. Can be None to be filled by search policy. |
561 | * \param inner_to_outer The split direction. |
562 | */ |
563 | SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent, |
564 | const Array<Optional<Integer>>& lengths, bool inner_to_outer); |
565 | |
566 | /*! |
567 | * \brief The constructor used to read a step record from JSONReader and create the |
568 | * corresponding step. |
569 | * \param reader The input JSONReader. |
570 | */ |
571 | explicit SplitStep(dmlc::JSONReader* reader); |
572 | |
573 | TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); |
574 | }; |
575 | |
576 | /*! \brief Similar to SplitStepNode, but uses split factors from another step |
577 | * (i.e. Follow another split step) */ |
578 | class FollowSplitStepNode : public StepNode { |
579 | public: |
580 | /*! \brief The id of the iter to be split. */ |
581 | int iter_id; |
582 | /*! \brief The index of the split step to be followed in the history. */ |
583 | int src_step_id; |
584 | /*! \brief The number of split level. */ |
585 | int n_split; |
586 | |
587 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
588 | |
589 | /*! |
590 | * \brief Extract split lengths. |
591 | * \param transform_steps An array of history transform steps. |
592 | * \return The multiple split factors. |
593 | */ |
594 | Array<Optional<Integer>> (const Array<Step>& transform_steps) const; |
595 | |
596 | /*! |
597 | * \brief Apply the current step to State. |
598 | * \param state A mutable pointer to state, which will be updated. |
599 | * \return The iterator results after split. |
600 | */ |
601 | Array<Iterator> ApplyToState(State* state) const; |
602 | |
603 | /*! |
604 | * \brief Apply the current step to tvm.schedule. |
605 | * \param stages The list of current stages |
606 | * \param stage_to_axes A map that maps stage ot all its iterators. |
607 | * \param transform_steps An array of history transform steps. |
608 | * \return The iterator results after split. |
609 | */ |
610 | Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
611 | const Array<Step>& transform_steps) const; |
612 | |
613 | /*! |
614 | * \brief Print the current step as equivalent python schedule API. |
615 | * \param stages The list of current stages |
616 | * \param stage_to_axes A map that maps stage ot all its iterators. |
617 | * \param transform_steps An array of history transform steps. |
618 | * \return Python schedule code. |
619 | */ |
620 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
621 | const Array<Step>& transform_steps) const; |
622 | |
623 | static constexpr const char* record_prefix_str = "FSP" ; |
624 | |
625 | static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep" ; |
626 | TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode); |
627 | }; |
628 | |
629 | /*! |
630 | * \brief Managed reference to FollowSplitStepNode. |
631 | * \sa FollowSplitStepNode |
632 | */ |
633 | class FollowSplitStep : public Step { |
634 | public: |
635 | /*! |
636 | * \brief The constructor. |
637 | * \param stage_id The index of the stage to be split. |
638 | * \param iter_id The index of the iterator to be split. |
639 | * \param src_step_id The index of the split step to be followed in the history. |
640 | * \param n_split The number of split level. |
641 | */ |
642 | FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); |
643 | |
644 | /*! |
645 | * \brief The constructor used to read a step record from JSONReader and create the |
646 | * corresponding step. |
647 | * \param reader The input JSONReader. |
648 | */ |
649 | explicit FollowSplitStep(dmlc::JSONReader* reader); |
650 | |
651 | TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); |
652 | }; |
653 | |
654 | /*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps. |
655 | * \note This can be used for the split in cooperative fetching. |
656 | */ |
657 | class FollowFusedSplitStepNode : public StepNode { |
658 | public: |
659 | /*! \brief The id of the iter to split. */ |
660 | int iter_id; |
661 | /*! \brief The indices of the split steps to be followed in the history. */ |
662 | Array<Integer> src_step_ids; |
663 | /*! \brief Use the length in this split level. */ |
664 | int level; |
665 | /*! \brief If this is true, use factor. Otherwise, use nparts. */ |
666 | bool factor_or_nparts; |
667 | |
668 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
669 | |
670 | /*! |
671 | * \brief Extract split length. |
672 | * \param transform_steps An array of history transform steps. |
673 | * \return Split factor. |
674 | */ |
675 | Optional<Integer> (const Array<Step>& transform_steps) const; |
676 | |
677 | /*! |
678 | * \brief Apply the current step to State. |
679 | * \param state A mutable pointer to state, which will be updated. |
680 | * \return The iterator results after split. |
681 | */ |
682 | Array<Iterator> ApplyToState(State* state) const; |
683 | |
684 | /*! |
685 | * \brief Apply the current step to tvm.schedule. |
686 | * \param stages The list of current stages |
687 | * \param stage_to_axes A map that maps stage ot all its iterators. |
688 | * \param transform_steps An array of history transform steps. |
689 | * \return The iterator results after split. |
690 | */ |
691 | Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
692 | const Array<Step>& transform_steps) const; |
693 | |
694 | /*! |
695 | * \brief Print the current step as equivalent python schedule API. |
696 | * \param stages The list of current stages |
697 | * \param stage_to_axes A map that maps stage ot all its iterators. |
698 | * \param transform_steps An array of history transform steps. |
699 | * \return Python schedule code. |
700 | */ |
701 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
702 | const Array<Step>& transform_steps) const; |
703 | |
704 | static constexpr const char* record_prefix_str = "FFSP" ; |
705 | |
706 | static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep" ; |
707 | TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode); |
708 | }; |
709 | |
710 | /*! |
711 | * \brief Managed reference to FollowFusedSplitStepNode. |
712 | * \sa FollowFusedSplitStepNode |
713 | */ |
714 | class FollowFusedSplitStep : public Step { |
715 | public: |
716 | /*! |
717 | * \brief The constructor. |
718 | * \param stage_id The index of the stage to be split. |
719 | * \param iter_id The index of the iterator to be split. |
720 | * \param src_step_ids An array of index for split step to be followed in the history. |
721 | * \param level Use the length in this split level. |
722 | * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts. |
723 | */ |
724 | FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level, |
725 | bool factor_or_nparts); |
726 | |
727 | /*! |
728 | * \brief The constructor used to read a step record from JSONReader and create the |
729 | * corresponding step. |
730 | * \param reader The input JSONReader. |
731 | */ |
732 | explicit FollowFusedSplitStep(dmlc::JSONReader* reader); |
733 | |
734 | TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); |
735 | }; |
736 | |
737 | /*! \brief Storage align step that corresponds to te::Stage::storage_align */ |
738 | class StorageAlignStepNode : public StepNode { |
739 | public: |
740 | /*! \brief The iterator to be aligned. */ |
741 | int iter_id; |
742 | /*! \brief The factor in alignment specification. */ |
743 | int factor; |
744 | /*! \brief The offset in the alignment specification. */ |
745 | int offset; |
746 | |
747 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
748 | |
749 | /*! |
750 | * \brief Apply the current step to State. |
751 | * \param state A mutable pointer to State, which will be updated. |
752 | */ |
753 | void ApplyToState(State* state) const; |
754 | |
755 | /*! |
756 | * \brief Apply the current step to tvm.schedule. |
757 | * \param stages The list of current stages |
758 | * \param stage_to_axes A map that maps stage ot all its iterators. |
759 | */ |
760 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
761 | |
762 | /*! |
763 | * \brief Print the current step as equivalent python schedule API. |
764 | * \param stages The list of current stages |
765 | * \param stage_to_axes A map that maps stage ot all its iterators. |
766 | * \return Python schedule code. |
767 | */ |
768 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
769 | |
770 | static constexpr const char* record_prefix_str = "SA" ; |
771 | |
772 | static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep" ; |
773 | TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode); |
774 | }; |
775 | |
776 | /*! |
777 | * \brief Managed reference to StorageAlignStepNode. |
778 | * \sa StorageAlignStepNode |
779 | */ |
780 | class StorageAlignStep : public Step { |
781 | public: |
782 | /*! |
783 | * \brief The constructor. |
784 | * \param stage_id The index of the stage to be aligned. |
785 | * \param iter_id The index of the iterator to be aligned. |
786 | * \param factor The factor in alignment specification. |
787 | * \param offset The offset in the alignment specification. |
788 | */ |
789 | StorageAlignStep(int stage_id, int iter_id, int factor, int offset); |
790 | |
791 | /*! |
792 | * \brief The constructor used to read a step record from JSONReader and create the |
793 | * corresponding step. |
794 | * \param reader The input JSONReader. |
795 | */ |
796 | explicit StorageAlignStep(dmlc::JSONReader* reader); |
797 | |
798 | TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); |
799 | }; |
800 | |
801 | /********** Steps working on multiple stages **********/ |
802 | |
803 | /*! \brief Compute at step that corresponds to te::Stage::compute_at */ |
804 | class ComputeAtStepNode : public StepNode { |
805 | public: |
806 | /*! \brief The index of stage that this step will compute at to. */ |
807 | int target_stage_id; |
808 | /*! \brief The index of iterator in target stage that this step will compute at to. */ |
809 | int target_iter_id; |
810 | |
811 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
812 | |
813 | /*! |
814 | * \brief Apply the current step to State. |
815 | * \param state A mutable pointer to state, which will be updated. |
816 | * \note After compute_at, we need careful dependency analysis to compute the accurate bound |
817 | * information. However, it is relatively expensive and complicated, so we just fill "None" as |
818 | * bound for the newly created iterators. |
819 | * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. |
820 | */ |
821 | void ApplyToState(State* state) const; |
822 | |
823 | /*! |
824 | * \brief Apply the current step to tvm.schedule. |
825 | * \param stages The list of current stages |
826 | * \param stage_to_axes A map that maps stage ot all its iterators. |
827 | */ |
828 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
829 | |
830 | /*! |
831 | * \brief Print the current step as equivalent python schedule API. |
832 | * \param stages The list of current stages |
833 | * \param stage_to_axes A map that maps stage ot all its iterators. |
834 | * \return Python schedule code. |
835 | */ |
836 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
837 | |
838 | static constexpr const char* record_prefix_str = "CA" ; |
839 | |
840 | static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep" ; |
841 | TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode); |
842 | }; |
843 | |
844 | /*! |
845 | * \brief Managed reference to ComputeAtStepNode. |
846 | * \sa ComputeAtStepNode |
847 | */ |
848 | class ComputeAtStep : public Step { |
849 | public: |
850 | /*! |
851 | * \brief The constructor. |
852 | * \param stage_id The index of the source stage. |
853 | * \param target_stage_id The index of stage that this step will compute at to. |
854 | * \param target_iter_id The index of iterator in target stage that this step will compute at to. |
855 | */ |
856 | ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); |
857 | |
858 | /*! |
859 | * \brief The constructor used to read a step record from JSONReader and create the |
860 | * corresponding step. |
861 | * \param reader The input JSONReader. |
862 | */ |
863 | explicit ComputeAtStep(dmlc::JSONReader* reader); |
864 | |
865 | TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); |
866 | }; |
867 | |
868 | /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ |
869 | class ComputeInlineStepNode : public StepNode { |
870 | public: |
871 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
872 | |
873 | /*! |
874 | * \brief Apply the current step to State. |
875 | * \param state A mutable pointer to state, which will be updated. |
876 | */ |
877 | void ApplyToState(State* state) const; |
878 | |
879 | /*! |
880 | * \brief Apply the current step to tvm.schedule. |
881 | * \param stages The list of current stages |
882 | * \param stage_to_axes A map that maps stage ot all its iterators. |
883 | * \return The iterator result after fuse. |
884 | */ |
885 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
886 | |
887 | /*! |
888 | * \brief Print the current step as equivalent python schedule API. |
889 | * \param stages The list of current stages |
890 | * \param stage_to_axes A map that maps stage ot all its iterators. |
891 | * \return Python schedule code. |
892 | */ |
893 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
894 | |
895 | static constexpr const char* record_prefix_str = "CI" ; |
896 | |
897 | static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep" ; |
898 | TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode); |
899 | }; |
900 | |
901 | /*! |
902 | * \brief Managed reference to ComputeInlineStepNode. |
903 | * \sa ComputeInlineStepNode |
904 | */ |
905 | class ComputeInlineStep : public Step { |
906 | public: |
907 | /*! |
908 | * \brief The constructor. |
909 | * \param stage_id The index of the stage to be marked compute inlined. |
910 | */ |
911 | explicit ComputeInlineStep(int stage_id); |
912 | |
913 | /*! |
914 | * \brief The constructor used to read a step record from JSONReader and create the |
915 | * corresponding step. |
916 | * \param reader The input JSONReader. |
917 | */ |
918 | explicit ComputeInlineStep(dmlc::JSONReader* reader); |
919 | |
920 | TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); |
921 | }; |
922 | |
923 | /*! \brief Compute root step that corresponds to te::Stage::compute_root */ |
924 | class ComputeRootStepNode : public StepNode { |
925 | public: |
926 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
927 | |
928 | /*! |
929 | * \brief Apply the current step to State. |
930 | * \param state A mutable pointer to state, which will be updated. |
931 | * \note After compute_root, we need careful dependency analysis to compute the accurate bound |
932 | * information. However, it is relatively expensive and complicated, so we just fill "None" as |
933 | * bound for the newly created iterators. |
934 | * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. |
935 | */ |
936 | void ApplyToState(State* state) const; |
937 | |
938 | /*! |
939 | * \brief Apply the current step to tvm.schedule. |
940 | * \param stages The list of current stages |
941 | * \param stage_to_axes A map that maps stage ot all its iterators. |
942 | * \return The iterator result after fuse. |
943 | */ |
944 | void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
945 | |
946 | /*! |
947 | * \brief Print the current step as equivalent python schedule API. |
948 | * \param stages The list of current stages |
949 | * \param stage_to_axes A map that maps stage ot all its iterators. |
950 | * \return Python schedule code. |
951 | */ |
952 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const; |
953 | |
954 | static constexpr const char* record_prefix_str = "CR" ; |
955 | |
956 | static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep" ; |
957 | TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode); |
958 | }; |
959 | |
960 | /*! |
961 | * \brief Managed reference to ComputeRootStepNode. |
962 | * \sa ComputeRootStepNode |
963 | */ |
964 | class ComputeRootStep : public Step { |
965 | public: |
966 | /*! |
967 | * \brief The constructor. |
968 | * \param stage_id The index of the stage to be marked compute at root. |
969 | */ |
970 | explicit ComputeRootStep(int stage_id); |
971 | |
972 | /*! |
973 | * \brief The constructor used to read a step record from JSONReader and create the |
974 | * corresponding step. |
975 | * \param reader The input JSONReader. |
976 | */ |
977 | explicit ComputeRootStep(dmlc::JSONReader* reader); |
978 | |
979 | TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); |
980 | }; |
981 | |
982 | /********** Steps adding new stages **********/ |
983 | |
984 | /*! |
985 | * \brief Cache read step that corresponds to te::Schedule::cache_read. |
986 | * \note Cache read step adds an extra stage to the original ComputeDAG, |
987 | * an up-to-date ComputeDAG will be stored in State's `current_compute_dag`. |
988 | */ |
989 | class CacheReadStepNode : public StepNode { |
990 | public: |
991 | /*! \brief The scope name of the newly added read stage. (e.g., local, shared, global) */ |
992 | String scope_name; |
993 | /*! \brief The indices of read stages. */ |
994 | Array<Integer> reader_stage_ids; |
995 | |
996 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
997 | |
998 | /*! |
999 | * \brief Apply the current step to State. |
1000 | * \param state A mutable pointer to state, which will be updated. |
1001 | * \param dag The original ComputeDAG of this state. |
1002 | * \return The index of the new added stage. |
1003 | */ |
1004 | int ApplyToState(State* state, const ComputeDAG& dag) const; |
1005 | |
1006 | /*! |
1007 | * \brief Apply the current step to tvm.schedule. |
1008 | * \param stages The list of current stages |
1009 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1010 | * \param schedule A mutable pointer to a te::Schedule. |
1011 | * \return The output Tensor of the new added stage. |
1012 | */ |
1013 | te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1014 | te::Schedule* schedule) const; |
1015 | |
1016 | /*! |
1017 | * \brief Print the current step as equivalent python schedule API. |
1018 | * \param stages The list of current stages |
1019 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1020 | * \param schedule A mutable pointer to a te::Schedule. |
1021 | * \return Python schedule code. |
1022 | */ |
1023 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1024 | te::Schedule* schedule) const; |
1025 | |
1026 | static constexpr const char* record_prefix_str = "CHR" ; |
1027 | |
1028 | static constexpr const char* _type_key = "auto_scheduler.CacheReadStep" ; |
1029 | TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode); |
1030 | }; |
1031 | |
1032 | /*! |
1033 | * \brief Managed reference to CacheReadStepNode. |
1034 | * \sa CacheReadStepNode |
1035 | */ |
1036 | class CacheReadStep : public Step { |
1037 | public: |
1038 | /*! |
1039 | * \brief The constructor. |
1040 | * \param stage_id The index of the stage to be cache_read. |
1041 | * \param scope_name The scope name of the newly added stage. |
1042 | * \param reader_stage_ids The indices of reader stages. |
1043 | */ |
1044 | CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids); |
1045 | |
1046 | /*! |
1047 | * \brief The constructor used to read a step record from JSONReader and create the |
1048 | * corresponding step. |
1049 | * \param reader The input JSONReader. |
1050 | */ |
1051 | explicit CacheReadStep(dmlc::JSONReader* reader); |
1052 | |
1053 | TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); |
1054 | }; |
1055 | |
1056 | /*! |
1057 | * \brief Cache write step that corresponds to te::Schedule::cache_write. |
1058 | * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date |
1059 | * ComputeDAG is stored in State's `current_compute_dag`. |
1060 | * This step will cache write all output tensors of the target stage. |
1061 | */ |
1062 | class CacheWriteStepNode : public StepNode { |
1063 | public: |
1064 | /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */ |
1065 | String scope_name; |
1066 | |
1067 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
1068 | |
1069 | /*! |
1070 | * \brief Apply the current step to State. |
1071 | * \param state A mutable pointer to state, which will be updated. |
1072 | * \param dag The original ComputeDAG of this state. |
1073 | * \return The index of the new added stage. |
1074 | */ |
1075 | int ApplyToState(State* state, const ComputeDAG& dag) const; |
1076 | |
1077 | /*! |
1078 | * \brief Apply the current step to tvm.schedule. |
1079 | * \param stages The list of current stages |
1080 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1081 | * \param schedule A mutable pointer to a te::Schedule. |
1082 | * \return The output Tensors of the new added stage. |
1083 | */ |
1084 | Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1085 | te::Schedule* schedule) const; |
1086 | |
1087 | /*! |
1088 | * \brief Print the current step as equivalent python schedule API. |
1089 | * \param stages The list of current stages |
1090 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1091 | * \param schedule A mutable pointer to a te::Schedule. |
1092 | * \return Python schedule code. |
1093 | */ |
1094 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1095 | te::Schedule* schedule) const; |
1096 | |
1097 | static constexpr const char* record_prefix_str = "CHW" ; |
1098 | |
1099 | static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep" ; |
1100 | TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode); |
1101 | }; |
1102 | |
1103 | /*! |
1104 | * \brief Managed reference to CacheWriteStepNode. |
1105 | * \sa CacheWriteStepNode |
1106 | */ |
1107 | class CacheWriteStep : public Step { |
1108 | public: |
1109 | /*! |
1110 | * \brief The constructor. |
1111 | * \param stage_id The index of the stage to be cache_write. |
1112 | * \param scope_name The scope name of the newly added stage. |
1113 | */ |
1114 | CacheWriteStep(int stage_id, String scope_name); |
1115 | |
1116 | /*! |
1117 | * \brief The constructor used to read a step record from JSONReader and create the |
1118 | * corresponding step. |
1119 | * \param reader The input JSONReader. |
1120 | */ |
1121 | explicit CacheWriteStep(dmlc::JSONReader* reader); |
1122 | |
1123 | TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); |
1124 | }; |
1125 | |
1126 | /*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ |
1127 | class RfactorStepNode : public StepNode { |
1128 | public: |
1129 | /*! \brief The index of the iterator to be factored. */ |
1130 | int iter_id; |
1131 | /*! \brief The position where the new iterator is placed. */ |
1132 | int factor_iter_id; |
1133 | |
1134 | void WriteToRecord(dmlc::JSONWriter* writer) const final; |
1135 | |
1136 | /*! |
1137 | * \brief Apply the current step to State. |
1138 | * \param state A mutable pointer to State, which will be updated. |
1139 | * \param dag The original ComputeDAG of this state. |
1140 | * \return The index of the new added stage. |
1141 | */ |
1142 | int ApplyToState(State* state, const ComputeDAG& dag) const; |
1143 | |
1144 | /*! |
1145 | * \brief Apply the current step to tvm.schedule. |
1146 | * \param stages The list of current stages |
1147 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1148 | * \param schedule A mutable pointer to a te::Schedule. |
1149 | * \return The output Tensors of the new added stage. |
1150 | */ |
1151 | Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1152 | te::Schedule* schedule) const; |
1153 | |
1154 | /*! |
1155 | * \brief Print the current step as equivalent python schedule API. |
1156 | * \param stages The list of current stages |
1157 | * \param stage_to_axes A map that maps stage ot all its iterators. |
1158 | * \param schedule A mutable pointer to a te::Schedule. |
1159 | * \return Python schedule code. |
1160 | */ |
1161 | String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1162 | te::Schedule* schedule) const; |
1163 | |
1164 | static constexpr const char* record_prefix_str = "RF" ; |
1165 | |
1166 | static constexpr const char* _type_key = "auto_scheduler.RfactorStep" ; |
1167 | TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode); |
1168 | }; |
1169 | |
1170 | /*! |
1171 | * \brief Managed reference to RfactorStepNode. |
1172 | * \sa RfactorStepNode |
1173 | */ |
1174 | class RfactorStep : public Step { |
1175 | public: |
1176 | /*! |
1177 | * \brief The constructor. |
1178 | * \param stage_id The index of the stage to be factored. |
1179 | * \param iter_id The index of the iterator to be factored. |
1180 | * \param factor_iter_id The position where the new iterator is placed. |
1181 | */ |
1182 | RfactorStep(int stage_id, int iter_id, int factor_iter_id); |
1183 | |
1184 | /*! |
1185 | * \brief The constructor used to read a step record from JSONReader and create the |
1186 | * corresponding step. |
1187 | * \param reader The input JSONReader. |
1188 | */ |
1189 | explicit RfactorStep(dmlc::JSONReader* reader); |
1190 | |
1191 | TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); |
1192 | }; |
1193 | |
1194 | } // namespace auto_scheduler |
1195 | } // namespace tvm |
1196 | |
1197 | #endif // TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ |
1198 | |