1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ |
21 | #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ |
22 | |
23 | #include <tvm/ir/expr.h> |
24 | #include <tvm/node/reflection.h> |
25 | #include <tvm/runtime/container/array.h> |
26 | #include <tvm/runtime/container/map.h> |
27 | #include <tvm/runtime/container/optional.h> |
28 | #include <tvm/runtime/container/string.h> |
29 | #include <tvm/runtime/object.h> |
30 | #include <tvm/runtime/packed_func.h> |
31 | #include <tvm/tir/schedule/schedule.h> |
32 | |
33 | namespace tvm { |
34 | namespace meta_schedule { |
35 | |
36 | class TuneContext; |
37 | class ScheduleRule; |
38 | |
39 | /*! \brief Rules to modify a block in a schedule. */ |
40 | class ScheduleRuleNode : public runtime::Object { |
41 | public: |
42 | /*! \brief Virtual destructor. */ |
43 | virtual ~ScheduleRuleNode() = default; |
44 | |
45 | void VisitAttrs(tvm::AttrVisitor* v) {} |
46 | |
47 | /*! |
48 | * \brief Initialize the design space generator with tuning context. |
49 | * \param context The tuning context for initialization. |
50 | * \note This method is supposed to be called only once before every other method. |
51 | */ |
52 | virtual void InitializeWithTuneContext(const TuneContext& context) = 0; |
53 | |
54 | /*! |
55 | * \brief Apply a schedule rule to the specific block in the given schedule. |
56 | * \param sch The schedule to be modified. |
57 | * \param block The specific block to apply the schedule rule. |
58 | * \return The list of schedules generated by applying the schedule rule. |
59 | */ |
60 | virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch, |
61 | const tir::BlockRV& block) = 0; |
62 | |
63 | /*! |
64 | * \brief Deep clone the schedule rule. |
65 | * \return The cloned schedule rule. |
66 | */ |
67 | virtual ScheduleRule Clone() const = 0; |
68 | |
69 | static constexpr const char* _type_key = "meta_schedule.ScheduleRule" ; |
70 | TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); |
71 | }; |
72 | |
73 | /*! |
74 | * \brief Managed reference to ScheduleRuleNode |
75 | * \sa ScheduleRuleNode |
76 | */ |
class ScheduleRule : public runtime::ObjectRef {
 public:
  /*!
   * \brief The function type of `InitializeWithTuneContext` method.
   * \param context The tuning context for initialization.
   */
  using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
  /*!
   * \brief The function type of `Apply` method.
   * \param sch The schedule to be modified.
   * \param block The specific block to apply the schedule rule.
   * \return The list of schedules generated by applying the schedule rule.
   */
  using FApply =
      runtime::TypedPackedFunc<Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
  /*!
   * \brief Get the schedule rule as string with name.
   * \return The string of the schedule rule.
   */
  using FAsString = runtime::TypedPackedFunc<String()>;
  /*!
   * \brief The function type of `Clone` method.
   * \return The cloned schedule rule.
   */
  using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
  /*!
   * \brief Create a rule that applies customized rules registered using block attribute
   * `schedule_rule`. The rule will be dispatched according to target keys.
   * \return The created schedule rule.
   */
  TVM_DLL static ScheduleRule ApplyCustomRule();
  /*! \brief Check if the rule is `ApplyCustomRule` */
  TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
  /*!
   * \brief Create an auto-inline rule that inlines spatial blocks if they satisfy some conditions
   * \param into_producer Whether to allow inlining a block into its producer
   * \param into_consumer Whether to allow inlining a block into its consumer
   * \param inline_const_tensor Always inline constant tensors
   * \param disallow_if_then_else Always disallow if-then-else-like constructs
   * \param require_injective Always require the read-to-write mapping to be injective
   * \param require_ordered Always require the read-to-write mapping to be ordered
   * \param disallow_op The operators that are disallowed in auto inline
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule AutoInline(bool into_producer,          //
                                         bool into_consumer,          //
                                         bool inline_const_tensor,    //
                                         bool disallow_if_then_else,  //
                                         bool require_injective,      //
                                         bool require_ordered,        //
                                         Optional<Array<String>> disallow_op);

  /*!
   * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of
   * ReverseComputeInline during AutoInline, since they are also counted as a producer block
   * unless they are inlined first. So it is recommended to run InlineConstantScalars before
   * AutoInline.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule InlineConstantScalars();

  /*!
   * \brief Create a mega rule: multi-level tiling with data reuse
   * \param structure The tiling structure. Recommended:
   * - 'SSRSRS' on CPU
   * - 'SSSRRSRS' on GPU
   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
   * - NullOpt on CPU
   * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
   * NullOpt means disable vectorization
   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
   * \param filter_fn A function that can be passed to overwrite the default condition for applying
   * MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool.
   * This is useful if there is a need to apply MultiLevelTiling to an operation / block which is
   * ignored by default. This function should return True for a block that should be tiled.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule MultiLevelTiling(String structure,                            //
                                               Optional<Array<String>> tile_binds,          //
                                               Optional<Integer> max_innermost_factor,      //
                                               Optional<Array<Integer>> vector_load_lens,   //
                                               Optional<Map<String, ObjectRef>> reuse_read, //
                                               Optional<Map<String, ObjectRef>> reuse_write,
                                               Optional<runtime::PackedFunc> filter_fn = NullOpt);

  /*!
   * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
   * \param intrin_name The name of a tensor intrinsic, must be registered via
   * TensorIntrin.register(...) beforehand
   * \param structure The tiling structure. Recommended:
   * - 'SSRSRS' on CPU
   * - 'SSSRRSRS' on GPU
   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
   * - NullOpt on CPU
   * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
   * NullOpt means disable vectorization
   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
      String intrin_name, String structure, Optional<Array<String>> tile_binds,
      Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

  /*!
   * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
   * tensor core intrinsics
   * \param intrin_groups A list of groups of tensor core intrinsics. The map should contain the
   * keys "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
   * initialization, loading operand A, loading operand B, tensor core computation, storing the
   * result. The value of the map should be names of tensor intrinsics, must be registered via
   * TensorIntrin.register(...) beforehand
   * \param structure The tiling structure. Recommended:
   * - 'SSSRRSRS' on GPU
   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
   * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
   * NullOpt means disable vectorization
   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
   * \param use_software_pipeline Whether use the software pipeline.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
      Array<Map<String, String>> intrin_groups, String structure,
      Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
      Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
      Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

  /*!
   * \brief Extension of MultiLevelTiling for backends with wide vectors.
   * The loop over the innermost spatial axis of the output buffer is always vectorized with the
   * maximum vector length.
   * \param structure The tiling structure. 'SSRSRS' is recommended.
   * \param vector_length_in_bits The length of a vector register in bits.
   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
      String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

  /*!
   * \brief Create a rule: add-rfactor to some blocks if needed
   * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
   * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
   * parallelism.
   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core,  //
                                         Optional<Integer> max_innermost_factor);
  /*!
   * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
   * correspondingly when needed
   * \param thread_extents Candidates of thread axis extent (values are required to be positive).
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
  /*!
   * \brief A rule that randomly select a compute-at location for a free block
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule RandomComputeLocation();
  /*!
   * \brief Mark parallelize, vectorize and unroll to the root block. The mark will be applied to
   * each block in a follow-up post processor
   * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
   * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
   * parallelism.
   * \param max_vectorize_extent The maximum extent to be vectorized.
   * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
   * \param unroll_max_steps The options of the maximum number of unroll steps to be done.
   * Use an empty array to disable unroll.
   * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core,            //
                                                         int max_vectorize_extent,        //
                                                         Array<Integer> unroll_max_steps, //
                                                         bool unroll_explicit);
  /*!
   * \brief Auto bind loops around the block to BlockIdx and ThreadIdx
   * \param max_threadblocks The maximum number of threadblock on GPU
   * \param thread_extents Candidates of thread axis extent.
   * \param max_threads_per_block The maximum number of threads per block, if it is known
   * when this schedule rule is created.
   * \return The schedule rule created
   */
  TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
                                       int max_threads_per_block = -1);
  /*!
   * \brief Create a schedule rule with customized methods on the python-side.
   * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
   * \param f_apply The packed function of `Apply`.
   * \param f_clone The packed function of `Clone`.
   * \param f_as_string The packed function of `AsString`.
   * \return The schedule rule created.
   */
  TVM_DLL static ScheduleRule PyScheduleRule(
      FInitializeWithTuneContext f_initialize_with_tune_context,  //
      FApply f_apply,                                             //
      FClone f_clone,                                             //
      FAsString f_as_string);

  /*! \brief Create default schedule rules for LLVM */
  TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
  /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */
  TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
  /*! \brief Create default schedule rules for CUDA */
  TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
  /*! \brief Create default postprocessors for CUDA with TensorCore */
  TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
  /*! \brief Create default schedule rules for Hexagon */
  TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
  /*! \brief Create default schedule rules for Micro */
  TVM_DLL static Array<ScheduleRule, void> DefaultMicro();

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};
306 | |
307 | /*! \brief The schedule rule with customized methods on the python-side. */ |
308 | class PyScheduleRuleNode : public ScheduleRuleNode { |
309 | public: |
310 | using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext; |
311 | using FApply = ScheduleRule::FApply; |
312 | using FClone = ScheduleRule::FClone; |
313 | using FAsString = ScheduleRule::FAsString; |
314 | |
315 | /*! \brief The packed function to the `InitializeWithTuneContext` function. */ |
316 | FInitializeWithTuneContext f_initialize_with_tune_context; |
317 | /*! \brief The packed function to the `Apply` function. */ |
318 | FApply f_apply; |
319 | /*! \brief The packed function to the `AsString` function. */ |
320 | FAsString f_as_string; |
321 | /*! \brief The packed function to the `Clone` function. */ |
322 | FClone f_clone; |
323 | |
324 | void VisitAttrs(tvm::AttrVisitor* v) { |
325 | // `f_initialize_with_tune_context` is not visited |
326 | // `f_apply` is not visited |
327 | // `f_as_string` is not visited |
328 | // `f_clone` is not visited |
329 | } |
330 | |
331 | void InitializeWithTuneContext(const TuneContext& context) final; |
332 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; |
333 | ScheduleRule Clone() const final; |
334 | |
335 | static constexpr const char* _type_key = "meta_schedule.PyScheduleRule" ; |
336 | TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); |
337 | }; |
338 | |
339 | } // namespace meta_schedule |
340 | } // namespace tvm |
341 | |
342 | #endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_ |
343 | |