1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
21#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_
22
23#include <tvm/ir/expr.h>
24#include <tvm/node/reflection.h>
25#include <tvm/runtime/container/array.h>
26#include <tvm/runtime/container/map.h>
27#include <tvm/runtime/container/optional.h>
28#include <tvm/runtime/container/string.h>
29#include <tvm/runtime/object.h>
30#include <tvm/runtime/packed_func.h>
31#include <tvm/tir/schedule/schedule.h>
32
33namespace tvm {
34namespace meta_schedule {
35
36class TuneContext;
37class ScheduleRule;
38
/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
 public:
  /*! \brief Virtual destructor. */
  virtual ~ScheduleRuleNode() = default;

  /*! \brief Visit attributes for reflection. The base class has no visitable attributes. */
  void VisitAttrs(tvm::AttrVisitor* v) {}

  /*!
   * \brief Initialize the schedule rule with a tuning context.
   * \param context The tuning context for initialization.
   * \note This method is supposed to be called only once before every other method.
   */
  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

  /*!
   * \brief Apply a schedule rule to the specific block in the given schedule.
   * \param sch The schedule to be modified.
   * \param block The specific block to apply the schedule rule.
   * \return The list of schedules generated by applying the schedule rule.
   */
  virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
                                              const tir::BlockRV& block) = 0;

  /*!
   * \brief Deep clone the schedule rule.
   * \return The cloned schedule rule.
   */
  virtual ScheduleRule Clone() const = 0;

  static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
  TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};
72
73/*!
74 * \brief Managed reference to ScheduleRuleNode
75 * \sa ScheduleRuleNode
76 */
77class ScheduleRule : public runtime::ObjectRef {
78 public:
79 /*!
80 * \brief The function type of `InitializeWithTuneContext` method.
81 * \param context The tuning context for initialization.
82 */
83 using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
84 /*!
85 * \brief The function type of `Apply` method.
86 * \param sch The schedule to be modified.
87 * \param block The specific block to apply the schedule rule.
88 * \return The list of schedules generated by applying the schedule rule.
89 */
90 using FApply =
91 runtime::TypedPackedFunc<Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
92 /*!
93 * \brief Get the schedule rule as string with name.
94 * \return The string of the schedule rule.
95 */
96 using FAsString = runtime::TypedPackedFunc<String()>;
97 /*!
98 * \brief The function type of `Clone` method.
99 * \return The cloned schedule rule.
100 */
101 using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
102 /*!
103 * \brief Create a rule that applies customized rules registered using block attribute
104 * `schedule_rule`. The rule will be dispatched according to target keys.
105 * \return The created schedule rule.
106 */
107 TVM_DLL static ScheduleRule ApplyCustomRule();
108 /*! \brief Check if the rule is `ApplyCustomRule` */
109 TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
110 /*!
111 * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
112 * \param into_producer If allows to inline a block into its producer
113 * \param into_consumer If allows to inline a block into its consumer
114 * \param inline_const_tensor Always inline constant tensors
115 * \param disallow_if_then_else Always disallow if-then-else-like constructs
116 * \param require_ordered Always require the read-to-write mapping to be ordered
117 * \param require_injective Always require the read-to-write mapping to be injective
118 * \param disallow_op The operators that are disallowed in auto inline
119 * \return The schedule rule created
120 */
121 TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
122 bool into_consumer, //
123 bool inline_const_tensor, //
124 bool disallow_if_then_else, //
125 bool require_injective, //
126 bool require_ordered, //
127 Optional<Array<String>> disallow_op);
128
129 /*!
130 * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of
131 * ReverseComputeInline during AutoInline, since they are also counted as a producer block
132 * unless they are inlined first. So it is recommended to run InlineConstantScalars before
133 * AutoInline.
134 * \return The schedule rule created
135 */
136 TVM_DLL static ScheduleRule InlineConstantScalars();
137
138 /*!
139 * \brief Create a mega rule: multi-level tiling with data reuse
140 * \param structure The tiling structure. Recommended:
141 * - 'SSRSRS' on CPU
142 * - 'SSSRRSRS' on GPU
143 * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
144 * - NullOpt on CPU
145 * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
146 * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
147 * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
148 * NullOpt means disable vectorization
149 * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
150 * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
151 * \param filter_fn A function that can be passed to overwrite the default condition for applying
152 * MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool.
153 * This is useful if there is a need to apply MultiLevelTiling to an operation / block which is
154 * ignored by default. This function should return True for a block that should be tiled.
155 * \return The schedule rule created
156 */
157 TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
158 Optional<Array<String>> tile_binds, //
159 Optional<Integer> max_innermost_factor, //
160 Optional<Array<Integer>> vector_load_lens, //
161 Optional<Map<String, ObjectRef>> reuse_read, //
162 Optional<Map<String, ObjectRef>> reuse_write,
163 Optional<runtime::PackedFunc> filter_fn = NullOpt);
164
165 /*!
166 * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
167 * \param intrin_name The name of a tensor intrinsic, must be registered via
168 * TensorIntrin.register(...) beforehand
169 * \param structure The tiling structure. Recommended:
170 * - 'SSRSRS' on CPU
171 * - 'SSSRRSRS' on GPU
172 * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
173 * - NullOpt on CPU
174 * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
175 * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
176 * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
177 * NullOpt means disable vectorization
178 * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
179 * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
180 * \return The schedule rule created
181 */
182 TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
183 String intrin_name, String structure, Optional<Array<String>> tile_binds,
184 Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
185 Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
186
187 /*!
188 * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
189 * tensor core intrinsics
190 * \param intrin_groups A list of groups of tensor core intrinsics. The map should contains key
191 * "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
192 * initialization, loading operand A, loading operand B, tensor core computation, storing the
193 * result. The value of the map should be names of tensor intrinsics, must be registered via
194 * TensorIntrin.register(...) beforehand
195 * \param structure The tiling structure. Recommended:
196 * - 'SSSRRSRS' on GPU
197 * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
198 * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
199 * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
200 * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
201 * NullOpt means disable vectorization
202 * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
203 * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
204 * \param use_software_pipeline Whether use the software pipeline.
205 * \return The schedule rule created
206 */
207 TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
208 Array<Map<String, String>> intrin_groups, String structure,
209 Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
210 Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
211 Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
212
213 /*!
214 * \brief Extension of MultiLevelTiling for backends with wide vectors.
215 * The loop over the innermost spatial axis of the output buffer is always vectorized with the
216 * maximum vector length.
217 * \param structure The tiling structure. 'SSRSRS' is recommended.
218 * \param vector_length_in_bits The length of a vector register in bits.
219 * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
220 * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
221 * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
222 * \return The schedule rule created
223 */
224 TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
225 String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
226 Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
227
228 /*!
229 * \brief Create a rule: add-rfactor to some blocks if needed
230 * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
231 * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
232 * parallelism.
233 * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
234 * \return The schedule rule created
235 */
236 TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
237 Optional<Integer> max_innermost_factor);
238 /*!
239 * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
240 * correspondingly when needed
241 * \param thread_extents Candidates of thread axis extent (values are required to be positive).
242 * \return The schedule rule created
243 */
244 TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
245 /*!
246 * \brief A rule that randomly select a compute-at location for a free block
247 * \return The schedule rule created
248 */
249 TVM_DLL static ScheduleRule RandomComputeLocation();
250 /*!
251 * \brief Mark parallelize, vectorize and unroll to the root block. The mark will be applied to
252 * each block in a follow-up post processor
253 * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
254 * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
255 * parallelism.
256 * \param max_vectorize_extent The maximum extent to be vectorized.
257 * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
258 * \param unroll_max_steps The options of the maximum number of unroll steps to be done.
259 * Use an empty array to disable unroll.
260 * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
261 * \return The schedule rule created
262 */
263 TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264 int max_vectorize_extent, //
265 Array<Integer> unroll_max_steps, //
266 bool unroll_explicit);
267 /*!
268 * \brief Auto bind loops around the block to BlockIdx and ThreadIdx
269 * \param max_threadblocks The maximum number of threadblock on GPU
270 * \param thread_extents Candidates of thread axis extent.
271 * \param max_threads_per_block The maximum number of threads per block, if it is known
272 * when this schedule rule is created.
273 * \return The schedule rule created
274 */
275 TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
276 int max_threads_per_block = -1);
277 /*!
278 * \brief Create a schedule rule with customized methods on the python-side.
279 * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
280 * \param f_apply The packed function of `Apply`.
281 * \param f_clone The packed function of `Clone`.
282 * \param f_as_string The packed function of `AsString`.
283 * \return The schedule rule created.
284 */
285 TVM_DLL static ScheduleRule PyScheduleRule(
286 FInitializeWithTuneContext f_initialize_with_tune_context, //
287 FApply f_apply, //
288 FClone f_clone, //
289 FAsString f_as_string);
290
291 /*! \brief Create default schedule rules for LLVM */
292 TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
293 /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */
294 TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
295 /*! \brief Create default schedule rules for CUDA */
296 TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
297 /*! \brief Create default postprocessors for CUDA with TensorCore */
298 TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
299 /*! \brief Create default schedule rules for Hexagon */
300 TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
301 /*! \brief Create default schedule rules for Micro */
302 TVM_DLL static Array<ScheduleRule, void> DefaultMicro();
303
304 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
305};
306
/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
 public:
  using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
  using FApply = ScheduleRule::FApply;
  using FClone = ScheduleRule::FClone;
  using FAsString = ScheduleRule::FAsString;

  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
  FInitializeWithTuneContext f_initialize_with_tune_context;
  /*! \brief The packed function to the `Apply` function. */
  FApply f_apply;
  /*! \brief The packed function to the `AsString` function. */
  FAsString f_as_string;
  /*! \brief The packed function to the `Clone` function. */
  FClone f_clone;

  /*! \brief Visit attributes for reflection. Packed functions are intentionally not visited. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    // `f_initialize_with_tune_context` is not visited
    // `f_apply` is not visited
    // `f_as_string` is not visited
    // `f_clone` is not visited
  }

  void InitializeWithTuneContext(const TuneContext& context) final;
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
  ScheduleRule Clone() const final;

  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};
338
339} // namespace meta_schedule
340} // namespace tvm
341
342#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
343