1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/te/schedule.h |
22 | * \brief Define a schedule. |
23 | */ |
24 | // Acknowledgement: Many schedule primitives originate from Halide and Loopy. |
25 | #ifndef TVM_TE_SCHEDULE_H_ |
26 | #define TVM_TE_SCHEDULE_H_ |
27 | |
28 | #include <tvm/support/with.h> |
29 | #include <tvm/te/tensor.h> |
30 | #include <tvm/te/tensor_intrin.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/index_map.h> |
33 | |
34 | #include <string> |
35 | #include <unordered_map> |
36 | |
37 | namespace tvm { |
38 | namespace te { |
39 | // Node container for Stage |
40 | class StageNode; |
41 | // Node container for Schedule |
42 | class ScheduleNode; |
43 | // Node container for IterVarRelation |
44 | class IterVarRelationNode; |
45 | // Attribute of itervar. |
46 | class IterVarAttrNode; |
47 | |
/*! \brief the attachment type, describing where a stage's computation is placed */
enum AttachType : int {
  /*! \brief Attached at the root of its group (default placement). */
  kGroupRoot = 1,
  /*! \brief Computation is to be inlined into its consumer. */
  kInline = 2,
  /*! \brief Computation has already been inlined. */
  kInlinedAlready = 3,
  /*! \brief Attached at a specific iteration scope of another stage. */
  kScope = 4,
  /*! \brief Attached as the update step of a scan. */
  kScanUpdate = 5,
};
56 | |
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public ObjectRef {
 public:
  /*! \brief Default constructor; creates a null (undefined) stage reference. */
  Stage() {}
  /*! \brief Wrap an existing node object in a Stage reference. */
  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief create a new schedule for op.
   * \param op The operator in the schedule
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*)
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();  // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the IterVar to thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to be bound.
   * \return reference to self.
   */
  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
  /*!
   * \brief Set the predicate to determine whether a store to the array should be performed.
   *  Use this when there are multiple threads performing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change behavior of program.
   *    Only do when we are certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
  /*!
   * \brief Specify environment threads that launched around the group's scope.
   *  This can only be used in group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  TVM_DLL Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent by factor, generating outer and inner iteration domains.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                       IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Split the iteration with given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
                                 IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Fuse the inner outer domain to the target
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused
   * \param p_target The result target domain.
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Fuse all the axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array,
   *       in that case, a singleton IterVar is created and
   *       inserted to the outermost loop.
   *       The fuse of empty array is used to support zero-dimension tensors.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variable.
   * \return reference to self.
   */
  TVM_DLL Stage& reorder(const Array<IterVar>& order);  // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner]
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
   * \param x_factor The stride factor on x axis
   * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
   * \return reference to self.
   */
  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*)
                      PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  TVM_DLL Stage& vectorize(IterVar var);  // NOLINT(*)
  /*!
   * \brief Replace computation of the current stage by tensor intrinsic f.
   * \param var The axis marking the beginning of tensorization.
   *  Every operation inside the axis (including the axis itself) is tensorized.
   * \param f The Tensor compute intrinsics.
   * \return reference to self.
   */
  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);  // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  TVM_DLL Stage& unroll(IterVar var);  // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  TVM_DLL Stage& parallel(IterVar var);  // NOLINT(*)
  /*!
   * \brief Annotate the iteration with pragma
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   * \param pragma_value The pragma value
   *
   * \return reference to self.
   */
  TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
                        const PrimExpr& pragma_value = PrimExpr());  // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain the tensor to be prefetched
   * \param var the iteration point at which to apply prefetching
   * \param offset the number of iterations to be fetched in advance
   * \return reference to self
   */
  TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset);  // NOLINT(*)
  /*!
   * \brief Set alignment requirement for specific dimension.
   *
   *  Such that stride[axis] == k * factor + offset for some k.
   *
   * \param axis The dimension to be specified for alignment.
   * \param factor The factor multiple of alignment
   * \param offset The required offset factor.
   * \return reference to self
   */
  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset);  // NOLINT(*)
  /*!
   * \brief Compute current stage with double buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& double_buffer();  // NOLINT(*)
  /*!
   * \brief Compute current stage with rolling buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& rolling_buffer();  // NOLINT(*)
  /*!
   * \brief Defines a layout transformation to be applied to the buffer.
   *
   * The map from initial_index to final_index must be an
   * invertible affine transformation.
   *
   * \param initial_indices An array of variables to represent a
   * value's location in the tensor, using the pre-transformation
   * layout.  These variables are used as binding occurrences to
   * represent the initial indices when applying the initial->final
   * mapping, and should not occur elsewhere in the
   * Schedule.  (i.e. Pass in newly constructed variables, not the
   * initial IterVar::var)
   *
   * \param final_indices An array of expressions, giving the
   * value's location in the tensor, using the post-transformation layout.
   * Expressions should be in terms of the variables given in
   * initial_indices.
   *
   * \param out_iter_vars An optional output location for the updated
   * loop iteration variables.
   *
   * \return reference to self
   */
  TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
                                  const Array<PrimExpr>& final_indices,
                                  Array<IterVar>* out_iter_vars = nullptr);
  /*! \brief Defines separators between groups of axes.
   *
   *  Used to define `BufferNode::axis_separators`, which has
   *  additional details.
   *
   *  \param axis_separators A list of axis separators.
   */
  TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators);
  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get attachment spec of current stage.
   *  If the stage compute at Group root, this function
   *  will traverse the group function to get the
   *  final spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};
311 | |
312 | /*! |
313 | * \brief Global schedule container |
314 | * For operations and all the operations they depend on. |
315 | * The schedule per Operation is named as stage. |
316 | */ |
317 | class Schedule : public ObjectRef { |
318 | public: |
319 | Schedule() {} |
320 | explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {} |
321 | /*! |
322 | * \brief Create a schedule for array of ops(and their dependencies). |
323 | * \param ops The ops to be scheduled. |
324 | * \return sch The created Schedule. |
325 | */ |
326 | TVM_DLL explicit Schedule(Array<Operation> ops); |
327 | /*! |
328 | * \brief Get a copy of current schedule. |
329 | * \return The copied schedule. |
330 | */ |
331 | Schedule copy() const; |
332 | /*! |
333 | * \brief Get the stage corresponds to the op |
334 | * \param op The operation. |
335 | */ |
336 | TVM_DLL Stage operator[](const Operation& op); |
337 | /*! |
338 | * \brief Short hand for getting the stage of tensor's operation. |
339 | * \param tensor The tensor |
340 | * \return The stage corresponding to the tensor's op |
341 | */ |
342 | TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } |
343 | /*! |
344 | * \brief Create a new stage group for all intermediate |
345 | * operations between inputs and outputs. |
346 | * |
347 | * \param outputs The output boundary of the group. |
348 | * \param inputs The input boundary of the group. |
349 | * \param include_inputs Whether include inputs if they are reachable from outputs. |
350 | * \return The new grouped stage. |
351 | */ |
352 | TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs, |
353 | bool include_inputs = false); |
354 | /*! |
355 | * \brief create a cache read of original tensor for readers. |
356 | * This will mutate the body of the readers. |
357 | * A new stage will be created for the tensor. |
358 | * \param tensor The tensor cached. |
359 | * \param scope The scope of the cache. |
360 | * \param readers The readers to redirect to the tensor. |
361 | * \return The created tensor. |
362 | */ |
363 | TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope, |
364 | const Array<Operation>& readers); |
365 | /*! |
366 | * \brief Create a cache write tensor for producing tensor. |
367 | * The tensor will take over body of original tensor op. |
368 | * |
369 | * This function can be used to do data layout transformation. |
370 | * If there is a split/fuse/reorder on the data parallel axis of tensor |
371 | * before cache_write is called. The intermediate cache stores |
372 | * the data in the layout as the iteration order of leave axis. |
373 | * The data will be transformed back to the original layout in the original tensor. |
374 | * User can further call compute_inline to inline the original layout and keep |
375 | * the data stored in the transformed layout. |
376 | * |
377 | * \param tensor The tensors to be produced. |
378 | * \param scope The scope of the storage. |
379 | * \return The created tensor. |
380 | */ |
381 | TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope); |
382 | /*! |
383 | * \brief Create a cache write tensor for producing tensor. |
384 | * The tensor will take over body of original tensor op. |
385 | * |
386 | * This function can be used to do data layout transformation. |
387 | * If there is a split/fuse/reorder on the data parallel axis of tensor |
388 | * before cache_write is called. The intermediate cache stores |
389 | * the data in the layout as the iteration order of leave axis. |
390 | * The data will be transformed back to the original layout in the original tensor. |
391 | * User can further call compute_inline to inline the original layout and keep |
392 | * the data stored in the transformed layout. |
393 | * |
394 | * \param tensor The tensor to be produced. |
395 | * \param scope The scope of the storage. |
396 | * \return The created tensor. |
397 | */ |
398 | TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope); |
399 | /*! |
400 | * \brief Factor a reduction axis in tensor's schedule to be an explicit axis. |
401 | * This will create a new stage that generated the new tensor with axis |
402 | * as the first dimension. The tensor's body will be rewritten as a reduction |
403 | * over the factored tensor. |
404 | * |
405 | * P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17 |
406 | * |
407 | * \param tensor The tensor to be factored. |
408 | * \param axis The reduction axis in tensor's schedule to be factored. |
409 | * \param factor_axis The position where the new axis is placed. |
410 | * \return The created factored tensors. |
411 | */ |
412 | TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0); |
413 | /*! |
414 | * \brief Normalize the schedule. |
415 | * This is needed before bound inference. |
416 | * Insert necessary RebaseNode to make sure all leaf_iter_vars |
417 | * are in form [0, extent) |
418 | * |
419 | * \return A normalized schedule, can be same as current one. |
420 | */ |
421 | Schedule normalize(); |
422 | |
423 | /*! |
424 | * \brief Normalize the schedule for feature extraction in auto-scheduler. |
425 | * This is similar to `Schedule::normalize`, but we do aggressive simplification |
426 | * to the TE compute with const_matrix=True for faster compilation and feature extraction. |
427 | * The resulted schedule may be wrong, but it is good enough for feature extraction |
428 | * purposes. |
429 | * |
430 | * \return A normalized schedule, can be same as current one. |
431 | */ |
432 | Schedule (); |
433 | |
434 | /*! |
435 | * \brief access the internal node container |
436 | * \return the pointer to the internal node container |
437 | */ |
438 | inline const ScheduleNode* operator->() const; |
439 | /*! |
440 | * \brief access the internal node container |
441 | * \return the pointer to the internal node container |
442 | */ |
443 | inline ScheduleNode* operator->(); |
444 | // declare container type |
445 | using ContainerType = ScheduleNode; |
446 | }; |
447 | |
/*!
 * \brief The schedule relation between IterVars,
 *  can be Split, Fuse.
 */
class IterVarRelation : public ObjectRef {
 public:
  /*! \brief Default constructor; creates a null (undefined) relation reference. */
  IterVarRelation() {}
  /*! \brief Wrap an existing node object in an IterVarRelation reference. */
  explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarRelationNode* operator->() const;
};
462 | |
/*!
 * \brief Additional schedulable attributes about IterVar.
 */
class IterVarAttr : public ObjectRef {
 public:
  /*! \brief Default constructor; creates a null (undefined) attribute reference. */
  IterVarAttr() {}
  /*! \brief Wrap an existing node object in an IterVarAttr reference. */
  explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarAttrNode* operator->() const;
};
476 | |
/*!
 * \brief represents a stage.
 *
 *  relations form a Directed acyclic hypergraph in bipartite manner.
 *  With each node is represented by a IterVar,
 *  and each hyper-edge is represented by a IterVarRelation.
 *  The relations connects the IterVars in the graph.
 *
 *  Besides typical stage that corresponds to operations.
 *  There is also group stage, which groups stages together.
 *  Each stage's group(given by group) represent an constraint,
 *  the stage can only be attached to stages within the group.
 *
 *  The group stage node can be attached to IterVars as in normal stage.
 */
class StageNode : public Object {
 public:
  /*!
   * \brief The operation of stage, can be different from original op.
   *  If it is null, then this stage is a group stage.
   */
  Operation op;
  /*!
   * \brief The original operator.
   *  The op field can change during schedule to alternate the dataflow,
   *  while origin_op remains fixed.
   */
  Operation origin_op;
  /*! \brief All the nodes in the iter var
   *
   *  Each element of all_iter_vars represents an iteration variable
   *  that may appear within this stage's computation.  Any element
   *  of `all_iter_vars` that is in `leaf_iter_vars` represents a
   *  variable that is directly defined and usable within the stage's
   *  computation.  All other elements of `all_iter_vars` represent
   *  variables whose value must be computed from the variables in
   *  `leaf_iter_vars`.  (e.g. Suppose index k has been split by
   *  ``ko, ki = s.split(k, factor=4)``.  ko and ki will appear in
   *  `leaf_iter_vars`, while k will not, and must be computed as
   *  `4*ko + ki`.
   */
  Array<IterVar> all_iter_vars;
  /*! \brief The current active leaf iter vars in the stage.
   *
   *  Each element of leaf_iter_vars will either be replaced with the
   *  bound index (e.g. threadIdx.x), or will be expanded into a loop
   *  over the variable's extent.  `leaf_iter_vars` is a subset of
   *  `all_iter_vars`.
   */
  Array<IterVar> leaf_iter_vars;
  /*!
   * \brief Specify threads to be launched at the stage.
   *  This is only valid for composite ops such as Scan.
   * \note Experimental primitive: used for thread persistence.
   */
  Array<IterVar> env_threads;
  /*!
   * \brief The predicate under which store can happen
   *  Use this when there can be duplicated threads doing the same store.
   * \note Experimental primitive: used by cross thread-reduction.
   */
  PrimExpr store_predicate;
  /*! \brief The relation between IterVars */
  Array<IterVarRelation> relations;
  /*! \brief additional attributes about iter var. */
  Map<IterVar, IterVarAttr> iter_var_attrs;
  /*! \brief The attachment type of the schedule */
  AttachType attach_type{kGroupRoot};
  /*! \brief The attach point of this schedule. */
  IterVar attach_ivar;
  /*! \brief The stage this node attaches to */
  Stage attach_stage;
  /*! \brief The thread storage scope level of the stage */
  std::string scope;
  /*! \brief Whether this is an output stage */
  bool is_output{false};
  /*! \brief Whether apply double buffer optimization to this stage */
  bool double_buffer{false};
  /*! \brief Whether apply rolling buffer optimization to this stage */
  bool rolling_buffer{false};
  /*! \brief Layout transformations to be applied onto the stage's tensors. */
  Array<IndexMap> layout_transforms;
  /*! \brief List of axes after which to divide physical axes.
   *
   *  Used to populate `BufferNode::axis_separators`, which has
   *  additional details.
   */
  Array<IntImm> axis_separators;
  /*!
   * \brief The parent group of the current stage.
   *  The stage cannot be assigned to stages outside the group.
   */
  Stage group;
  /*! \brief Number of direct child stages, only used for group stage.*/
  int num_child_stages{0};

  void VisitAttrs(AttrVisitor* v) {
    // NOTE(review): store_predicate and rolling_buffer are not visited here,
    // so they do not participate in reflection — confirm this is intentional.
    v->Visit("op", &op);
    v->Visit("origin_op", &origin_op);
    v->Visit("all_iter_vars", &all_iter_vars);
    v->Visit("leaf_iter_vars", &leaf_iter_vars);
    v->Visit("env_threads", &env_threads);
    v->Visit("relations", &relations);
    v->Visit("iter_var_attrs", &iter_var_attrs);
    v->Visit("attach_type", &attach_type);
    v->Visit("attach_ivar", &attach_ivar);
    v->Visit("attach_stage", &attach_stage);
    v->Visit("scope", &scope);
    v->Visit("is_output", &is_output);
    v->Visit("double_buffer", &double_buffer);
    v->Visit("layout_transforms", &layout_transforms);
    v->Visit("axis_separators", &axis_separators);
    v->Visit("group", &group);
    v->Visit("num_child_stages", &num_child_stages);
  }

  static constexpr const char* _type_key = "Stage";
  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
};
596 | |
/*! \brief node container for schedule */
class ScheduleNode : public Object {
 public:
  /*! \brief The output operations in original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief list of all stages for ops.
   *  The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map to map internal ops to stages.
   *  This is created on demand and can be invalidated.
   */
  std::unordered_map<const Object*, Stage> op2stage_cache_;

  void VisitAttrs(AttrVisitor* v) {
    // op2stage_cache_ is a transient, on-demand cache (see InitCache /
    // InvalidateCache) and is deliberately not reflected.
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Check if the schedule contains an Operation.
   * \param op The candidate Operation.
   * \return true if the schedule has the Operation. Otherwise, false.
   */
  TVM_DLL bool Contain(const Operation& op) const;

  /*!
   * \brief Check if the schedule contains a Tensor.
   * \param tensor The candidate tensor.
   * \return true if the schedule has the tensor. Otherwise, false.
   */
  TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
648 | |
649 | /*! |
650 | * \brief Create a schedule for array of ops(and their dependencies). |
651 | * \param ops The ops to be scheduled. |
652 | * \return sch The created Schedule. |
653 | */ |
654 | inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); } |
655 | |
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this iter Var binds, can be null */
  IterVar bind_thread;
  /*! \brief List of tensor to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<PrimExpr> prefetch_offset;
  /*!
   * \brief Tensor intrinsic used in tensorization,
   *  when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*! \brief Alignment factor of buffer dimension */
  int dim_align_factor{0};
  /*! \brief Alignment offset of buffer dimension */
  int dim_align_offset{0};
  /*!
   * \brief Additional pragma keys, array of StringImm
   */
  Array<PrimExpr> pragma_keys;
  /*!
   * \brief Additional values of pragma, if any
   */
  Array<PrimExpr> pragma_values;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("dim_align_factor", &dim_align_factor);
    v->Visit("dim_align_offset", &dim_align_offset);
    v->Visit("pragma_keys", &pragma_keys);
    v->Visit("pragma_values", &pragma_values);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
};
700 | |
/*! \brief base node of iteration var relations (Split, Fuse, Rebase, Singleton, ...) */
class IterVarRelationNode : public Object {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
};
707 | |
/*!
 * \brief Split the parent domain into the product of
 *  the outer and inner domains.
 */
class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  PrimExpr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  PrimExpr nparts;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("factor", &factor);
    v->Visit("nparts", &nparts);
  }

  static constexpr const char* _type_key = "Split";
  TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};
736 | |
/*!
 * \brief Managed reference to SplitNode
 * \sa SplitNode
 */
class Split : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Split relation.
   * \param parent The parent domain being split.
   * \param outer The resulting outer domain.
   * \param inner The resulting inner domain.
   * \param factor The split factor (may be undefined if nparts is given).
   * \param nparts The number of parts (may be undefined if factor is given).
   */
  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

  TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};
747 | |
/*!
 * \brief Fuse two domains into one domain.
 */
class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target domain */
  IterVar fused;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("fused", &fused);
  }

  static constexpr const char* _type_key = "Fuse";
  TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};
769 | |
/*!
 * \brief Managed reference to FuseNode
 * \sa FuseNode
 */
class Fuse : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Fuse relation.
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused.
   * \param fused The resulting fused domain.
   */
  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

  TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};
780 | |
/*!
 * \brief Rebase the iteration to make min to be 0.
 *  This is useful to normalize the Schedule
 *  to make every leaf variable's min to be 0.
 */
class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The rebased domain (same extent as parent, but starting at 0) */
  IterVar rebased;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("rebased", &rebased);
  }

  static constexpr const char* _type_key = "Rebase";
  TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};
801 | |
802 | /*! |
803 | * \brief Managed reference to RebaseNode |
804 | * \sa RebaseNode |
805 | */ |
806 | class Rebase : public IterVarRelation { |
807 | public: |
808 | TVM_DLL Rebase(IterVar parent, IterVar rebased); |
809 | |
810 | TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode); |
811 | }; |
812 | |
813 | /*! |
814 | * \brief Singleton iterator [0, 1) |
815 | */ |
816 | class SingletonNode : public IterVarRelationNode { |
817 | public: |
818 | /*! \brief The singleton iterator */ |
819 | IterVar iter; |
820 | |
821 | void VisitAttrs(AttrVisitor* v) { v->Visit("iter" , &iter); } |
822 | |
823 | static constexpr const char* _type_key = "Singleton" ; |
824 | TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); |
825 | }; |
826 | |
827 | /*! |
828 | * \brief Managed reference to SingletonNode |
829 | * \sa SingletonNode |
830 | */ |
831 | class Singleton : public IterVarRelation { |
832 | public: |
833 | TVM_DLL explicit Singleton(IterVar iter); |
834 | |
835 | TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); |
836 | }; |
837 | |
838 | /*! |
839 | * \brief Transform iterator according to some arbitrary expression. |
840 | */ |
841 | class TransformNode : public IterVarRelationNode { |
842 | public: |
843 | /*! \brief The loop variables that were replaced by the transformation. |
844 | * |
845 | * Prior to applying a layout transformation, these represent the |
846 | * loops to iterate over a tensor as it is being computed, following |
847 | * a row-major traversal of the tensor's original shape in the |
848 | * compute definition. |
849 | */ |
850 | Array<IterVar> original_variables; |
851 | |
852 | /*! \brief The variables generated by the transformation. |
853 | * |
854 | * After to applying a layout transformation, these represent the |
855 | * loops to iterate over a tensor as it is being computed, following |
856 | * a row-major traversal of the transformed shape of the tensor. |
857 | */ |
858 | Array<IterVar> transformed_variables; |
859 | |
860 | /*! \brief Map from the original variables to the transformed variables. |
861 | * |
862 | * Used to determine iterator ranges over the transformed variables. |
863 | */ |
864 | IndexMap forward_transformation; |
865 | |
866 | /*! \brief Map from transformed variables to the original variables |
867 | * |
868 | * Used to rewrite expressions containing the original loop iterators |
869 | * in terms of the transformed loop iterators. |
870 | */ |
871 | IndexMap inverse_transformation; |
872 | |
873 | void VisitAttrs(AttrVisitor* v) { |
874 | v->Visit("original_variables" , &original_variables); |
875 | v->Visit("transformed_variables" , &transformed_variables); |
876 | v->Visit("forward_transformation" , &forward_transformation); |
877 | v->Visit("inverse_transformation" , &inverse_transformation); |
878 | } |
879 | |
880 | static constexpr const char* _type_key = "Transform" ; |
881 | TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode); |
882 | }; |
883 | |
/*!
 * \brief Managed reference to TransformNode
 * \sa TransformNode
 */
class Transform : public IterVarRelation {
 public:
  /*!
   * \brief Construct a transform relation.
   * \param original_variables The loop variables replaced by the transformation.
   * \param transformed_variables The loop variables produced by the transformation.
   * \param forward_transformation Map from the original variables to the transformed variables.
   * \param inverse_transformation Map from the transformed variables back to the original variables.
   */
  TVM_DLL explicit Transform(Array<IterVar> original_variables,
                             Array<IterVar> transformed_variables, IndexMap forward_transformation,
                             IndexMap inverse_transformation);

  TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
};
892 | |
893 | /*! \brief Container for specialization conditions. */ |
894 | class SpecializedConditionNode : public Object { |
895 | public: |
896 | /*! |
897 | * \brief List of conditions in conjunctive joint form (CNF). |
898 | * Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc., |
899 | * where n, m are tvm::Var that represents a dimension in the tensor shape. |
900 | */ |
901 | Array<PrimExpr> clauses; |
902 | |
903 | void VisitAttrs(AttrVisitor* v) { v->Visit("clauses" , &clauses); } |
904 | |
905 | static constexpr const char* _type_key = "SpecializedCondition" ; |
906 | TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object); |
907 | }; |
908 | |
909 | /*! |
910 | * \brief Specialized condition to enable op specialization |
911 | */ |
912 | class SpecializedCondition : public ObjectRef { |
913 | public: |
914 | /*! |
915 | * \brief construct from conditions |
916 | * \param conditions The clauses in the specialized condition. |
917 | */ |
918 | TVM_DLL SpecializedCondition(Array<PrimExpr> conditions); // NOLINT(*) |
919 | |
920 | /*! |
921 | * \brief Get the current specialized condition. |
922 | * \return the current specialized condition. |
923 | */ |
924 | TVM_DLL static SpecializedCondition Current(); |
925 | |
926 | TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode); |
927 | class Internal; |
928 | |
929 | private: |
930 | // enable with syntax. |
931 | friend class Internal; |
932 | friend class With<SpecializedCondition>; |
933 | /*! \brief Push a new specialized condition onto the thread local stack. */ |
934 | TVM_DLL void EnterWithScope(); |
935 | /*! \brief Pop a specialized condition off the thread local context stack. */ |
936 | TVM_DLL void ExitWithScope(); |
937 | }; |
938 | |
// implementations
// Arrow-operator accessors: each ObjectRef handle exposes its underlying node
// by static_cast of the stored object pointer. The const overloads use get()
// (read-only view); the non-const overloads use get_mutable() for in-place edits.
inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }

inline const ScheduleNode* Schedule::operator->() const {
  return static_cast<const ScheduleNode*>(get());
}
inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }

// IterVarRelation and IterVarAttr only expose read-only access to their nodes.
inline const IterVarRelationNode* IterVarRelation::operator->() const {
  return static_cast<const IterVarRelationNode*>(get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
  return static_cast<const IterVarAttrNode*>(get());
}
955 | |
956 | } // namespace te |
957 | } // namespace tvm |
958 | #endif // TVM_TE_SCHEDULE_H_ |
959 | |