1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/loop_state.h
22 * \brief The definition of the "state" in the search.
23 *
24 * Each LoopState corresponds to a schedule for its ComputeDAG.
25 * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
26 * construct the loop structure.
 * The loop structure keeps a preview of what the schedule will finally look like after lowering the
 * current state (e.g. the number of iterators, the extent of each iterator, the compute_at
 * locations, ...).
30 * During the schedule search process, the loop structure can provide search policy with necessary
31 * information on how to manipulate the current state.
32 * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
33 * schedule primitives. The steps are also used for the serialization of a state.
34 *
 * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
 * Instead of using the existing TVM IR directly, we build a new structure on top of it because:
 * 1. We want fast incremental changes to the loop structures. The search policy needs to see
 *    loop structure updates immediately rather than only after TVM lowering;
 * 2. We want a serializable transform history for replay, backtracking, and mutation;
 * 3. We may create some macro schedule primitives that represent the combination of several
 *    TVM schedule primitives.
42 *
 * When the search is finished, we lower the state to TVM IR using TVM's schedule primitives.
 * Since many common objects are shared during search, transformations are implemented in a
 * copy-on-write style: all objects are immutable, similar to TVM IR.
46 */
47
48#ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
49#define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
50
51#include <dmlc/common.h>
52#include <tvm/auto_scheduler/transform_step.h>
53
54#include <functional>
55#include <unordered_map>
56#include <utility>
57#include <vector>
58
59namespace tvm {
60namespace auto_scheduler {
61
// Brings every tvm::tir name into tvm::auto_scheduler for each file that includes this header.
// NOTE(review): a using-directive in a public header is usually avoided; it is kept because code
// in this namespace may rely on it — confirm before removing.
using namespace tvm::tir;

// Forward declaration only: State's cache_read/cache_write/rfactor below take ComputeDAG by
// const reference, so the complete type is not required in this header.
class ComputeDAG;
65
/*!
 * \brief The kind of a stage in the compute declaration.
 */
enum class StageKind : int {
  kPlaceholder = 0,  ///< A placeholder stage.
  kCompute = 1,      ///< A compute stage.
};
73
/*!
 * \brief Where a stage is computed (its compute location).
 */
enum class ComputeAtKind : int {
  kRoot = 0,     ///< Computed at root.
  kInlined = 1,  ///< Computed inline.
  kIter = 2      ///< Computed at some iterator of a target stage.
};
83
/*!
 * \brief Stage-level attributes.
 * \note Both members default to 0, so a default-constructed StageAttributes never holds
 * indeterminate values. The struct remains an aggregate, so `StageAttributes{a, b}` still works.
 */
struct StageAttributes {
  /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
  int auto_unroll_max_step = 0;
  /*! \brief The storage offset for the schedule primitive `storage_align`. */
  int storage_offset = 0;
};
91
/*!
 * \brief An op stage in the compute declaration.
 * Similar to te::Stage in `include/tvm/te/schedule.h`.
 */
class StageNode : public Object {
 public:
  /*! \brief The operator of this stage. */
  te::Operation op;
  /*! \brief The iterators in this stage. */
  Array<Iterator> iters;
  /*! \brief The type of this stage. */
  StageKind op_type;
  /*! \brief The compute location of this stage. */
  ComputeAtKind compute_at;
  /*! \brief Other stage-level attributes. */
  StageAttributes attrs;

  /*!
   * \brief Visit the fields of this node with the given AttrVisitor.
   * \param v The visitor.
   * \note Only `op`, `iters`, `op_type` and `compute_at` are visited; `attrs` is not.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("iters", &iters);
    v->Visit("op_type", &op_type);
    v->Visit("compute_at", &compute_at);
  }

  static constexpr const char* _type_key = "auto_scheduler.Stage";
  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
};
119
/*!
 * \brief Managed reference to StageNode.
 * \sa StageNode
 */
class Stage : public ObjectRef {
 public:
  /*!
   * \brief Construct a stage from a single operation.
   * \param op A `te::Operation`.
   * \note NOTE(review): the remaining StageNode fields are presumably derived from `op` in the
   * implementation — confirm in loop_state.cc.
   */
  explicit Stage(te::Operation op);
  /*!
   * \brief Construct a stage with every field given explicitly.
   * \param op The source operation.
   * \param op_type The stage type of this op.
   * \param iters The iterators of this op.
   * \param compute_at The compute location of this op.
   * \param attrs Other stage-level attributes.
   */
  Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at,
        StageAttributes attrs);

  TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
};
145
/*! \brief A stage is identified by its stage_id. */
using StageKey = int;
/*! \brief An iterator is identified by the pair (stage_id, iter_id). */
using IterKey = std::pair<StageKey, int>;
150
151/*!
152 * \brief stores the compute_at relation between stages
153 * This stores a bi-directional mapping from stages and iter:
154 * 1. Stage to its attached iterator
155 * 2. Iterator to the stage attached to it
156 * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages
157 * to query the relations
158 */
159class AttachMapNode : public Object {
160 public:
161 struct IterKeyHash {
162 std::size_t operator()(const IterKey& k) const {
163 return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
164 }
165 };
166
167 /*! \brief A Map to store the mapping of stage to its attached iterator. */
168 std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
169 /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
170 std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
171
172 static constexpr const char* _type_key = "auto_scheduler.AttachMap";
173 TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
174};
175
/*!
 * \brief Managed reference to AttachMapNode.
 * \sa AttachMapNode
 */
class AttachMap : public ObjectRef {
 public:
  /*!
   * \brief Update the stage/iterator mapping after a compute_at.
   * \param stage_id The index of the source stage of compute_at.
   * \param target_stage_id The index of the stage that the source stage is computed at.
   * \param target_iter_id The index of the target iterator in the target stage.
   */
  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);

  /*!
   * \brief Delete the entries of a specific stage. This is a public wrapper of `DeleteStageEntry`.
   * \param stage_id The index of the stage to be deleted.
   */
  void DeleteStage(int stage_id);

  /*!
   * \brief Find the relations of the original iterators in the AttachMap and replace them with the
   * new iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` are updated.
   * \param original_iters The original IterKeys.
   * \param new_iters The new IterKeys that replace the original ones.
   */
  void UpdateIters(const std::vector<IterKey>& original_iters,
                   const std::vector<IterKey>& new_iters);

  /*!
   * \brief Traverse the `stage_to_attach_iter` and `iter_to_attached_stages` maps and add an
   * offset to every stage index larger than `start_id`. Used by steps that insert new stages into
   * the ComputeDAG (e.g., CacheRead/CacheWrite steps).
   * \param start_id The index threshold. Only stages with indices larger than this threshold are
   * offset.
   * \param offset The offset added to the stage index.
   * \return The updated AttachMap after applying the stage index offset.
   */
  AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;

  TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);

 private:
  /*!
   * \brief Delete the entries of a specific stage. This removes the items related to this
   * stage from both the `stage_to_attach_iter` and `iter_to_attached_stages` maps.
   * \param pnode A mutable pointer to AttachMapNode.
   * \param stage_id The index of the stage to be removed from the maps.
   */
  static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
};
228
/*!
 * \brief A state in the search process.
 * It consists of the current loop structure and a list of transformation steps used to construct
 * it.
 * Each State corresponds to a specific schedule for its ComputeDAG.
 */
class StateNode : public Object {
 public:
  /*! \brief Current stages and loop structures. */
  Array<Stage> stages;
  /*! \brief History of transformation steps. */
  Array<Step> transform_steps;
  /*!
   * \brief The attach relations of stages and iterators. This is used to track the compute_at
   * operations.
   */
  AttachMap attach_map;
  /*!
   * \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt,
   * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask.
   * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps
   * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG.
   * \note NOTE(review): typed as Optional<ObjectRef> rather than Optional<ComputeDAG>,
   * presumably because ComputeDAG is only forward-declared in this header — confirm.
   */
  Optional<ObjectRef> current_compute_dag;
  /*!
   * \brief Whether this state is concrete. A concrete state is one in which all tile sizes are
   * filled. Only a concrete state can be applied to a TVM schedule.
   */
  bool concrete;

  /*!
   * \brief Visit the fields of this node with the given AttrVisitor.
   * \param v The visitor.
   * \note `attach_map` and `current_compute_dag` are not visited.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("stages", &stages);
    v->Visit("transform_steps", &transform_steps);
    v->Visit("concrete", &concrete);
  }

  static constexpr const char* _type_key = "auto_scheduler.State";
  TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
};
267
/*!
 * \brief Managed reference to StateNode.
 * \sa StateNode
 */
class State : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
   * \param ops `te::Operation`s for a compute declaration.
   */
  explicit State(const Array<te::Operation>& ops);

  /*!
   * \brief Pretty-print the state to a human readable string.
   * \param delete_trivial_loop True to skip trivial loops (undefined, or extent == 1).
   * Defaults to true.
   * \return The human readable string.
   */
  String ToStr(bool delete_trivial_loop = true) const;

  /********** Step APIs working on a single stage **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::bind`.
   * \param stage_id The index of the stage to be bound.
   * \param it The iterator to be bound.
   * \param thread_type The thread type.
   * \return The new iterator after binding.
   */
  TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::parallel`.
   * \param stage_id The index of the stage to be parallelized.
   * \param it The iterator to be parallelized.
   * \return The new iterator after parallelization.
   */
  TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::unroll`.
   * \param stage_id The index of the stage to be unrolled.
   * \param it The iterator to be unrolled.
   * \param max_unroll The max unroll limit. An iterator with an extent larger than this limit will
   * be skipped. NOTE(review): the default of -1 presumably means "no limit" — confirm in
   * loop_state.cc.
   * \return The new iterator after unrolling.
   */
  TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::vectorize`.
   * \param stage_id The index of the stage to be vectorized.
   * \param it The iterator to be vectorized.
   * \return The new iterator after vectorization.
   */
  TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::fuse`.
   * \param stage_id The index of the stage to be fused.
   * \param iters The iterators to be fused.
   * \return The iterator resulting from the fuse.
   * \note If the iterators to be fused have stages attached at them (by compute_at), the fused
   * result will become the new attach point.
   */
  TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
  /*!
   * \brief The schedule primitive corresponding to `te.Stage.pragma`.
   * \param stage_id The index of the stage to add the pragma to.
   * \param it The iterator to add the pragma to.
   * \param pragma_type The pragma string.
   */
  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::reorder`.
   * \param stage_id The index of the stage to be reordered.
   * \param order The expected iterator order.
   */
  TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::split`.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param lengths The multiple split factors. Can be None to be filled by search policy.
   * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
   * \return The new iterators after splitting.
   * \note If we split an iterator which has stages attached at it (by compute_at), the innermost
   * iterator of the split results will become the new attach point.
   */
  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                const Array<Optional<Integer>>& lengths,
                                bool inner_to_outer = true);
  /*!
   * \brief The schedule primitive similar to split, but uses split factors from previous steps.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param src_step_id The index of the split step to be followed in the history.
   * \param n_split The number of split levels.
   * \return The split new Iterators.
   */
  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
                                       int n_split);
  /*!
   * \brief The schedule primitive similar to split, but uses split factors from
   * fused previous steps.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param src_step_ids The indices of the split steps to be followed in the history.
   * \param level Use the length in this split level.
   * \param factor_or_nparts True to use `factor` for split from inner to outer,
   * False to use `nparts` for split from outer to inner.
   * \return The split new Iterators.
   */
  TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
                                             const Array<Integer>& src_step_ids, int level,
                                             bool factor_or_nparts);
  /*!
   * \brief The schedule primitive corresponding to `te.Stage.storage_align`.
   * \param stage_id The index of the stage to be aligned.
   * \param it The iterator to be aligned.
   * \param factor The factor in the alignment specification.
   * \param offset The offset in the alignment specification.
   */
  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);

  /********** Step APIs working on multiple stages **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_at`.
   * \param stage_id The index of the source stage of compute_at.
   * \param target_stage_id The index of the stage that the source stage is computed at.
   * \param target_iter The target iterator in the target stage.
   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
   * information. However, it is relatively expensive and complicated, so we just fill "None" as
   * bound for the newly created iterators.
   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
   */
  TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_inline`.
   * \param stage_id The index of the stage to be marked compute inlined.
   */
  TVM_DLL void compute_inline(int stage_id);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_root`.
   * \param stage_id The index of the stage to be marked compute at root.
   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
   * information. However, it is relatively expensive and complicated, so we just fill "None" as
   * bound for the newly created iterators.
   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
   */
  TVM_DLL void compute_root(int stage_id);

  /********** Step APIs adding new stages **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::cache_read`.
   * \param stage_id The index of the stage to be cache read.
   * \param scope_name The scope name of the newly added stage.
   * \param reader_stage_ids The indices of reader stages.
   * \param dag The original ComputeDAG of this state.
   * \return An int; NOTE(review): presumably the index of the newly added stage — confirm in
   * loop_state.cc.
   * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
   * target stage), and an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   */
  TVM_DLL int cache_read(int stage_id, const String& scope_name,
                         const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::cache_write`.
   * \param stage_id The index of the stage to be cache written.
   * \param scope_name The scope name of the newly added stage.
   * \param dag The original ComputeDAG of this state.
   * \return An int; NOTE(review): presumably the index of the newly added stage — confirm in
   * loop_state.cc.
   * \note Cache write step will add an extra stage to the original ComputeDAG (in front of the
   * target stage), and an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   * This step will cache write all output tensors of the target stage.
   */
  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::rfactor`.
   * \param stage_id The index of the stage to be factored.
   * \param it The iterator to be factored.
   * \param factor_iter_id The position where the new iterator is placed.
   * \param dag The original ComputeDAG of this state.
   * \return An int; NOTE(review): presumably the index of the newly added stage — confirm in
   * loop_state.cc.
   * \note Rfactor step will add an extra stage to the original ComputeDAG (in front of the
   * target stage), and an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   */
  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);

  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
451
452} // namespace auto_scheduler
453} // namespace tvm
454
455// Hash and equal function for State
456namespace std {
457
458/*!
459 * \brief The equal_to function for auto_scheduler::State.
460 * This function checks the equality by looking at the lowered string format of states.
461 * If two states with different transform history have the same lowered string format,
462 * they will be considered being equal.
463 */
464template <>
465struct equal_to<::tvm::auto_scheduler::State> {
466 bool operator()(const ::tvm::auto_scheduler::State& lhs,
467 const ::tvm::auto_scheduler::State& rhs) const {
468 return lhs.ToStr() == rhs.ToStr();
469 }
470};
471
472/*! \brief The hash function for auto_scheduler::State. */
473template <>
474struct hash<::tvm::auto_scheduler::State> {
475 std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
476 return tvm::runtime::ObjectHash()(state.ToStr());
477 }
478};
479
480} // namespace std
481
482#endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
483