1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/loop_state.h |
22 | * \brief The definition of the "state" in the search. |
23 | * |
24 | * Each LoopState corresponds to a schedule for its ComputeDAG. |
25 | * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to |
26 | * construct the loop structure. |
 * The loop structure keeps a preview of what the schedule will finally look like after lowering
 * the current state (e.g. the number of iterators, the extent of each iterator, the compute_at
 * locations, ...).
30 | * During the schedule search process, the loop structure can provide search policy with necessary |
31 | * information on how to manipulate the current state. |
32 | * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM |
33 | * schedule primitives. The steps are also used for the serialization of a state. |
34 | * |
35 | * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. |
 * We do not reuse the existing TVM IR directly but build a new structure on top of it because:
37 | * 1. We want fast incremental change to the loop structures. The search policy needs to get the |
38 | * immediate loop structures update rather than after TVM lowering; |
39 | * 2. We want serializable transform history for replay, backtracking, and mutation; |
40 | * 3. We may create some macro schedule primitives that represent the combination of several |
41 | * TVM schedule primitives. |
42 | * |
43 | * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives. |
44 | * Since we share a lot of common objects during search, the transformation is implemented in |
45 | * copy on write style. All objects are immutable, which is similar to TVM IR. |
46 | */ |
47 | |
48 | #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_ |
49 | #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_ |
50 | |
51 | #include <dmlc/common.h> |
52 | #include <tvm/auto_scheduler/transform_step.h> |
53 | |
54 | #include <functional> |
55 | #include <unordered_map> |
56 | #include <utility> |
57 | #include <vector> |
58 | |
59 | namespace tvm { |
60 | namespace auto_scheduler { |
61 | |
62 | using namespace tvm::tir; |
63 | |
64 | class ComputeDAG; |
65 | |
/*! \brief The kind of a stage: either a placeholder or a compute stage. */
enum class StageKind : int {
  /*! \brief The stage is a placeholder. */
  kPlaceholder = 0,
  /*! \brief The stage is a compute stage. */
  kCompute = 1
};
73 | |
/*! \brief The kind of location a stage is computed at. */
enum class ComputeAtKind : int {
  /*! \brief The stage is computed at root. */
  kRoot = 0,
  /*! \brief The stage is computed inlined. */
  kInlined = 1,
  /*! \brief The stage is computed at some iterator. */
  kIter = 2,
};
83 | |
/*!
 * \brief Stage-level attributes.
 * Both fields carry a default member initializer of 0, so a default-initialized
 * StageAttributes never holds indeterminate values.
 */
struct StageAttributes {
  /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
  int auto_unroll_max_step = 0;
  /*! \brief The storage offset for the schedule primitive `storage_align`. */
  int storage_offset = 0;
};
91 | |
92 | /*! |
93 | * \brief A op stage in the compute declaration. |
94 | * Similar to te::Stage in `include/tvm/te/schedule.h`. |
95 | */ |
96 | class StageNode : public Object { |
97 | public: |
98 | /*! \brief The operator of this stage */ |
99 | te::Operation op; |
100 | /*! \brief The iterators in this stage. */ |
101 | Array<Iterator> iters; |
102 | /*! \brief The type of this stage. */ |
103 | StageKind op_type; |
104 | /*! \brief The compute location of this stage. */ |
105 | ComputeAtKind compute_at; |
106 | /*! \brief Other stage-level attributes. */ |
107 | StageAttributes attrs; |
108 | |
109 | void VisitAttrs(tvm::AttrVisitor* v) { |
110 | v->Visit("op" , &op); |
111 | v->Visit("iters" , &iters); |
112 | v->Visit("op_type" , &op_type); |
113 | v->Visit("compute_at" , &compute_at); |
114 | } |
115 | |
116 | static constexpr const char* _type_key = "auto_scheduler.Stage" ; |
117 | TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); |
118 | }; |
119 | |
120 | /*! |
121 | * \brief Managed reference to StageNode. |
122 | * \sa StageNode |
123 | */ |
124 | class Stage : public ObjectRef { |
125 | public: |
126 | /*! |
127 | * \brief The constructor. |
128 | * \param op A `te::Operation`. |
129 | */ |
130 | explicit Stage(te::Operation op); |
131 | /*! |
132 | * \brief The constructor. |
133 | * \param op The source operation |
134 | * \param op_type The stage type of this op. |
135 | * \param iters The iterators of this op. |
136 | * \param compute_at The compute at type of this op. |
137 | * \param attrs Other stage-level attributes. |
138 | */ |
139 | Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at, |
140 | StageAttributes attrs); |
141 | |
142 | TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); |
143 | TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); |
144 | }; |
145 | |
/*! \brief A stage is identified by its stage_id. */
using StageKey = int;
/*! \brief An iterator is identified by the pair (stage_id, iter_id). */
using IterKey = std::pair<StageKey, int>;
150 | |
151 | /*! |
152 | * \brief stores the compute_at relation between stages |
153 | * This stores a bi-directional mapping from stages and iter: |
154 | * 1. Stage to its attached iterator |
155 | * 2. Iterator to the stage attached to it |
156 | * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages |
157 | * to query the relations |
158 | */ |
159 | class AttachMapNode : public Object { |
160 | public: |
161 | struct IterKeyHash { |
162 | std::size_t operator()(const IterKey& k) const { |
163 | return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second)); |
164 | } |
165 | }; |
166 | |
167 | /*! \brief A Map to store the mapping of stage to its attached iterator. */ |
168 | std::unordered_map<StageKey, IterKey> stage_to_attach_iter; |
169 | /*! \brief A Map to store the mapping of iterator to the stages attached to it. */ |
170 | std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages; |
171 | |
172 | static constexpr const char* _type_key = "auto_scheduler.AttachMap" ; |
173 | TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); |
174 | }; |
175 | |
176 | /*! |
177 | * \brief Managed reference to AttachMapNode. |
178 | * \sa AttachMapNode |
179 | */ |
180 | class AttachMap : public ObjectRef { |
181 | public: |
182 | /*! |
183 | * \brief Process the stage/iterator mapping after compute at. |
184 | * \param stage_id The index of the source stage of computed at. |
185 | * \param target_stage_id The index of stage that this step will compute at to. |
186 | * \param target_iter_id The index of target iterator in the target stage. |
187 | */ |
188 | void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); |
189 | |
190 | /*! |
191 | * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`. |
192 | * \param stage_id The index of the stage to be deleted. |
193 | */ |
194 | void DeleteStage(int stage_id); |
195 | |
196 | /*! |
197 | * \brief Find the relations of original iterators in AttachMap, and update them with the new |
198 | * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. |
199 | * \param original_iters The original IterKey. |
200 | * \param new_iters The new IterKey for replacing the old ones. |
201 | */ |
202 | void UpdateIters(const std::vector<IterKey>& original_iters, |
203 | const std::vector<IterKey>& new_iters); |
204 | |
205 | /*! |
206 | * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset |
207 | * to stage indexes that are larger than the start_id. Used for steps that insert new stages to |
208 | * ComputeDAG (e.g., CacheRead/CacheWrite step). |
209 | * \param start_id The index threshold. This function only adds offset for stages |
210 | * with indices larger then this threshold. |
211 | * \param offset The index offset to be added to the stage index. |
212 | * \return The updated AttachMap after applying stage index offset. |
213 | */ |
214 | AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const; |
215 | |
216 | TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); |
217 | TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); |
218 | |
219 | private: |
220 | /*! |
221 | * \brief Delete the entry of a specific stage. This will remove the items related to this |
222 | * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map. |
223 | * \param pnode A mutable pointer to AttachMapNode. |
224 | * \param stage_id The index of stage that will be removed from the map. |
225 | */ |
226 | static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); |
227 | }; |
228 | |
229 | /*! |
230 | * \brief A state in the search process. |
231 | * It consists of the current loop structure and a list of transformation steps used to construct |
232 | * it. |
233 | * Each State corresponds to a specific schedule for its ComputeDAG. |
234 | */ |
235 | class StateNode : public Object { |
236 | public: |
237 | /*! \brief Current stages and loop structures. */ |
238 | Array<Stage> stages; |
239 | /*! \brief History transformation steps. */ |
240 | Array<Step> transform_steps; |
241 | /*! |
242 | * \brief The attach relations of stages and iterators. This is used to track the compute at |
243 | * operation. |
244 | */ |
245 | AttachMap attach_map; |
246 | /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, |
247 | * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask. |
248 | * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps |
249 | * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG. |
250 | */ |
251 | Optional<ObjectRef> current_compute_dag; |
252 | /*! |
253 | * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all |
254 | * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. |
255 | */ |
256 | bool concrete; |
257 | |
258 | void VisitAttrs(tvm::AttrVisitor* v) { |
259 | v->Visit("stages" , &stages); |
260 | v->Visit("transform_steps" , &transform_steps); |
261 | v->Visit("concrete" , &concrete); |
262 | } |
263 | |
264 | static constexpr const char* _type_key = "auto_scheduler.State" ; |
265 | TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); |
266 | }; |
267 | |
268 | /*! |
269 | * \brief Managed reference to StateNode. |
270 | * \sa StateNode |
271 | */ |
272 | class State : public ObjectRef { |
273 | public: |
274 | /*! |
275 | * \brief The constructor. |
276 | * \param ops `te::Operation`s for a compute declaration. |
277 | */ |
278 | explicit State(const Array<te::Operation>& ops); |
279 | |
280 | /*! |
281 | * \brief Pretty-print the state to a human readable string. |
282 | * \param delete_trivial_loop True for skipping the trivial loops. |
283 | * (undefined or extent == 1, default set to True) |
284 | * \return The human readable string. |
285 | */ |
286 | String ToStr(bool delete_trivial_loop = true) const; |
287 | |
288 | /********** Step APIs working on a single stage **********/ |
289 | /*! |
290 | * \brief The schedule primitive corresponding to `te::Stage::bind`. |
291 | * \param stage_id The index of the stage to be binded. |
292 | * \param it The iterator to be binded. |
293 | * \param thread_type The thread type. |
294 | * \return The new iterator after binding. |
295 | */ |
296 | TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); |
297 | /*! |
298 | * \brief The schedule primitive corresponding to `te::Stage::parallel`. |
299 | * \param stage_id The index of the stage to be paralleled. |
300 | * \param it The iterator to be paralleled. |
301 | * \return The new iterator after parallel. |
302 | */ |
303 | TVM_DLL Iterator parallel(int stage_id, const Iterator& it); |
304 | /*! |
305 | * \brief The schedule primitive corresponding to `te::Stage::unroll`. |
306 | * \param stage_id The index of the stage to be unrolled. |
307 | * \param it The iterator to be unrolled. |
308 | * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be |
309 | * skipped. |
310 | * \return The new iterator after unroll. |
311 | */ |
312 | TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); |
313 | /*! |
314 | * \brief The schedule primitive corresponding to `te::Stage::vectorize`. |
315 | * \param stage_id The index of the stage to be vectorized. |
316 | * \param it The iterator to be vectorized. |
317 | * \return The new iterator after vectorization. |
318 | */ |
319 | TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); |
320 | /*! |
321 | * \brief The schedule primitive corresponding to `te::Stage::fuse`. |
322 | * \param stage_id The index of the stage to be fused. |
323 | * \param iters The iterators to be fused. |
324 | * \return The iterator result after fuse. |
325 | * \note If the iterators to be fused have stages attached at them(by compute_at), the fused |
326 | * result will become the new attach point. |
327 | */ |
328 | TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters); |
329 | /*! |
330 | * \brief The schedule primitive corresponding to `te.Stage.pragma`. |
331 | * \param stage_id The index of the stage to add pragma. |
332 | * \param it The iterator to add pragma. |
333 | * \param pragma_type The pragma string. |
334 | */ |
335 | TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type); |
336 | /*! |
337 | * \brief The schedule primitive corresponding to `te::Stage::reorder`. |
338 | * \param stage_id The index of the stage to be reordered. |
339 | * \param order The expected iterator order. |
340 | */ |
341 | TVM_DLL void reorder(int stage_id, const Array<Iterator>& order); |
342 | /*! |
343 | * \brief The schedule primitive corresponding to `te::Stage::split`. |
344 | * \param stage_id The index of the stage to be split. |
345 | * \param it The iterator to be split. |
346 | * \param lengths The multiple split factors. Can be None to be filled by search policy. |
347 | * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner. |
348 | * \return The new iterator after splitting. |
349 | * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner |
350 | * most iterator of split results will become the new attach point. |
351 | */ |
352 | TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it, |
353 | const Array<Optional<Integer>>& lengths, |
354 | bool inner_to_outer = true); |
355 | /*! |
356 | * \brief The schedule primitive similar to split, but uses split factors from previous steps. |
357 | * \param stage_id The index of the stage to be split. |
358 | * \param it The iterator to be split. |
359 | * \param src_step_id The index of the split step to be followed in the history. |
360 | * \param n_split The number of split level. |
361 | * \return The split new Iterators. |
362 | */ |
363 | TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, |
364 | int n_split); |
365 | /*! |
366 | * \brief The schedule primitive similar to split, but uses split factors from |
367 | * fused previous steps. |
368 | * \param stage_id The index of the stage to be split. |
369 | * \param it The iterator to be split. |
370 | * \param src_step_ids The indices of the split steps to be followed in the history. |
371 | * \param level Use the length in this split level. |
372 | * \param factor_or_nparts True to use `factor` for split from inner to outer, |
373 | False to use `nparts` for split from outer to inner. |
374 | * \return The split new Iterators. |
375 | */ |
376 | TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it, |
377 | const Array<Integer>& src_step_ids, int level, |
378 | bool factor_or_nparts); |
379 | /*! |
380 | * \brief The schedule primitive corresponding to `te.Stage.storage_align`. |
381 | * \param stage_id The index of the stage to be aligned. |
382 | * \param it The iterator to be aligned. |
383 | * \param factor The factor in alignment specification. |
384 | * \param offset The offset in the alignment specification. |
385 | */ |
386 | TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset); |
387 | |
388 | /********** Step APIs working on multiple stages **********/ |
389 | /*! |
390 | * \brief The schedule primitive corresponding to `te::Stage::compute_at`. |
391 | * \param stage_id The index of the source stage of computed at. |
392 | * \param target_stage_id The index of stage that this step will compute at to. |
393 | * \param target_iter The indiex of the target iterator in the target stage. |
394 | * \note After compute_at, we need careful dependency analysis to compute the accurate bound |
395 | * information. However, it is relatively expensive and complicated, so we just fill "None" as |
396 | * bound for the newly created iterators. |
397 | * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. |
398 | */ |
399 | TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); |
400 | /*! |
401 | * \brief The schedule primitive corresponding to `te::Stage::compute_inline`. |
402 | * \param stage_id The index of the stage to be marked compute inlined. |
403 | */ |
404 | TVM_DLL void compute_inline(int stage_id); |
405 | /*! |
406 | * \brief The schedule primitive corresponding to `te::Stage::compute_root`. |
407 | * \param stage_id The index of the stage to be marked compute at root. |
408 | * \note After compute_root, we need careful dependency analysis to compute the accurate bound |
409 | * information. However, it is relatively expensive and complicated, so we just fill "None" as |
410 | * bound for the newly created iterators. |
411 | * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. |
412 | */ |
413 | TVM_DLL void compute_root(int stage_id); |
414 | |
415 | /********** Step APIs adding new stages **********/ |
416 | /*! |
417 | * \brief The schedule primitive corresponding to `te::Schedule::cache_read`. |
418 | * \param stage_id The index of the stage to be cache_read. |
419 | * \param scope_name The scope name of the newly added stage. |
420 | * \param reader_stage_ids The indices of reader stages. |
421 | * \param dag The original ComputeDAG of this state. |
422 | * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the |
423 | * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. |
424 | */ |
425 | TVM_DLL int cache_read(int stage_id, const String& scope_name, |
426 | const Array<Integer>& reader_stage_ids, const ComputeDAG& dag); |
427 | /*! |
428 | * \brief The schedule primitive corresponding to `te::Schedule::cache_write`. |
429 | * \param stage_id The index of the stage to be cache_write. |
430 | * \param scope_name The scope name of the newly added stage. |
431 | * \param dag The original ComputeDAG of this state. |
432 | * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the |
433 | * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. |
434 | * This step will cache write all output tensors of the target stage. |
435 | */ |
436 | TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); |
437 | /*! |
438 | * \brief The schedule primitive corresponding to `te::Schedule::rfactor`. |
439 | * \param stage_id The index of the iterator to be factored. |
440 | * \param it The iterator to be factored. |
441 | * \param factor_iter_id The position where the new iterator is placed. |
442 | * \param dag The original ComputeDAG of this state. |
443 | * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the |
444 | * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. |
445 | */ |
446 | TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag); |
447 | |
448 | TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); |
449 | TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); |
450 | }; |
451 | |
452 | } // namespace auto_scheduler |
453 | } // namespace tvm |
454 | |
455 | // Hash and equal function for State |
456 | namespace std { |
457 | |
458 | /*! |
459 | * \brief The equal_to function for auto_scheduler::State. |
460 | * This function checks the equality by looking at the lowered string format of states. |
461 | * If two states with different transform history have the same lowered string format, |
462 | * they will be considered being equal. |
463 | */ |
464 | template <> |
465 | struct equal_to<::tvm::auto_scheduler::State> { |
466 | bool operator()(const ::tvm::auto_scheduler::State& lhs, |
467 | const ::tvm::auto_scheduler::State& rhs) const { |
468 | return lhs.ToStr() == rhs.ToStr(); |
469 | } |
470 | }; |
471 | |
472 | /*! \brief The hash function for auto_scheduler::State. */ |
473 | template <> |
474 | struct hash<::tvm::auto_scheduler::State> { |
475 | std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { |
476 | return tvm::runtime::ObjectHash()(state.ToStr()); |
477 | } |
478 | }; |
479 | |
480 | } // namespace std |
481 | |
482 | #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_ |
483 | |