/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

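// PyScheduleRuleNode wraps a schedule rule whose behavior is supplied as packed-function
// callbacks (typically implemented on the Python side). Each method below checks that the
// corresponding callback was provided and forwards the call to it.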
void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) {
  ICHECK(f_initialize_with_tune_context != nullptr)
      << "PyScheduleRule's InitializeWithTuneContext method not implemented!";
  f_initialize_with_tune_context(context);
}

Array<tir::Schedule> PyScheduleRuleNode::Apply(const tir::Schedule& sch,
                                               const tir::BlockRV& block) {
  ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
  return f_apply(sch, block);
}

ScheduleRule PyScheduleRuleNode::Clone() const {
  ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!";
  return f_clone();
}

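// Bundles the user-provided callbacks into a PyScheduleRule object. A minimal usage sketch,
// assuming callbacks `my_init`, `my_apply`, `my_clone`, and `my_repr` already exist
// (the names are illustrative, not part of this file):
//
//   ScheduleRule rule = ScheduleRule::PyScheduleRule(my_init, my_apply, my_clone, my_repr);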
ScheduleRule ScheduleRule::PyScheduleRule(
    PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
    PyScheduleRuleNode::FApply f_apply,                                              //
    PyScheduleRuleNode::FClone f_clone,                                              //
    PyScheduleRuleNode::FAsString f_as_string) {
  ObjectPtr<PyScheduleRuleNode> n = make_object<PyScheduleRuleNode>();
  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
  n->f_apply = std::move(f_apply);
  n->f_clone = std::move(f_clone);
  n->f_as_string = std::move(f_as_string);
  return ScheduleRule(n);
}

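// Default schedule rules for generic LLVM CPU targets: constant-scalar inlining, auto-inlining,
// AddRFactor for reductions, "SSRSRS" multi-level tiling, parallelize/vectorize/unroll, and
// random compute location.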
Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
  return {
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/false,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/true,
          /*require_injective=*/true,
          /*require_ordered=*/true,
          /*disallow_op=*/Array<String>{"tir.exp"}),
      ScheduleRule::AddRFactor(
          /*max_jobs_per_core=*/16,
          /*max_innermost_factor=*/Integer(64)),
      ScheduleRule::MultiLevelTiling(
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(64),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
      ScheduleRule::ParallelizeVectorizeUnroll(
          /*max_jobs_per_core=*/16,
          /*max_vectorize_extent=*/64,
          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
          /*unroll_explicit=*/true),
      ScheduleRule::RandomComputeLocation(),
  };
}

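// Default schedule rules for x86 CPUs. `type` selects the tensorized dot-product intrinsic
// ("vnni" or "avx512"); the rule set otherwise mirrors DefaultLLVM, with a
// MultiLevelTilingWithIntrin rule placed before the plain MultiLevelTiling fallback.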
Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) {
  static const Map<String, String> intrins = {{"vnni", "dot_16x4_vnni"},
                                              {"avx512", "dot_16x4_avx512"}};
  return {
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/false,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/true,
          /*require_injective=*/true,
          /*require_ordered=*/true,
          /*disallow_op=*/Array<String>{"tir.exp"}),
      ScheduleRule::AddRFactor(
          /*max_jobs_per_core=*/16,
          /*max_innermost_factor=*/Integer(64)),
      ScheduleRule::MultiLevelTilingWithIntrin(
          /*intrin_name=*/intrins[type],
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(64),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
      ScheduleRule::MultiLevelTiling(
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(64),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
      ScheduleRule::ParallelizeVectorizeUnroll(
          /*max_jobs_per_core=*/16,
          /*max_vectorize_extent=*/64,
          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
          /*unroll_explicit=*/true),
      ScheduleRule::RandomComputeLocation(),
  };
}

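// Default schedule rules for CUDA GPUs: multi-level tiling bound to
// blockIdx.x/vthread.x/threadIdx.x with shared/local memory reuse, aggressive auto-inlining,
// cross-thread reduction, unrolling, and automatic thread binding.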
Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
  return {
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::MultiLevelTiling(
          /*structure=*/"SSSRRSRS",
          /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
          /*max_innermost_factor=*/Integer(64),
          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
          /*reuse_read=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{4}},  //
                                 {"scope", String("shared")}},
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{3}},  //
                                 {"scope", String("local")}}),
      ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/true,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/false,
          /*require_injective=*/false,
          /*require_ordered=*/false,
          /*disallow_op=*/Array<String>{}),
      ScheduleRule::CrossThreadReduction(
          /*thread_extents=*/Array<Integer>{4, 8, 16, 32, 64, 128, 256, 512}),
      ScheduleRule::ParallelizeVectorizeUnroll(
          /*max_jobs_per_core=*/-1,
          /*max_vectorize_extent=*/-1,
          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512, 1024},
          /*unroll_explicit=*/true),
      ScheduleRule::AutoBind(
          /*max_threadblocks=*/256,
          /*thread_extents=*/Array<Integer>{32, 64, 128, 256, 512, 1024}),
  };
}

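// Default schedule rules for CUDA targets with Tensor Cores. A MultiLevelTilingTensorCore rule
// covering the wmma intrinsic groups below (f32 and f16 accumulation over f16 inputs, and s32
// accumulation over s8 inputs) is listed first, followed by the regular CUDA defaults
// (skipping their leading ApplyCustomRule, which is already present here).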
Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
  Array<Map<String, String>> intrin_groups = {
      // Tensor Cores f32 += f16 * f16
      {
          {"init", "wmma_fill_16x16x16_f32"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f32"},
          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_f32"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f32_trans"},
          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
      },
      // Tensor Cores f16 += f16 * f16
      {
          {"init", "wmma_fill_16x16x16_f16"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f16"},
          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_f16"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
      },
      // Tensor Cores s32 += s8 * s8
      {
          {"init", "wmma_fill_16x16x16_s32"},
          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_s8s8s32"},
          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_s32"},
          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
      },
  };
  Array<ScheduleRule> results{
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::MultiLevelTilingTensorCore(
          /*intrin_groups=*/intrin_groups,
          /*structure=*/"SSSRRSRS",
          /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
          /*max_innermost_factor=*/Integer(4),
          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
          /*reuse_read=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{4}},  //
                                 {"scope", String("shared.dyn")}},
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{2}},  //
                                 {"scope", String("shared.dyn")}},
          /*use_software_pipeline=*/false)  //
  };
  Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
  results.insert(results.end(), append.begin() + 1, append.end());
  return results;
}

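// Default schedule rules for Hexagon: similar to the CPU defaults, but tiling targets wide
// 1024-bit vectors via MultiLevelTilingWideVector, with a correspondingly larger vectorize
// extent and innermost tile factor.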
Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
  return {
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/false,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/true,
          /*require_injective=*/true,
          /*require_ordered=*/true,
          /*disallow_op=*/Array<String>{"tir.exp"}),
      ScheduleRule::MultiLevelTilingWideVector(
          /*structure=*/"SRSRS",
          /*vector_length_in_bits=*/1024,
          /*max_innermost_factor=*/Integer(128),
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
      ScheduleRule::ParallelizeVectorizeUnroll(
          /*max_jobs_per_core=*/16,
          /*max_vectorize_extent=*/128,
          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
          /*unroll_explicit=*/true),
  };
}

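// Default schedule rules for bare-metal micro targets: inlining plus plain multi-level tiling,
// without rfactor, parallelization, or vectorization rules.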
Array<ScheduleRule> ScheduleRule::DefaultMicro() {
  return {
      ScheduleRule::ApplyCustomRule(),
      ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/false,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/true,
          /*require_injective=*/true,
          /*require_ordered=*/true,
          /*disallow_op=*/Array<String>{"tir.exp"}),
      ScheduleRule::MultiLevelTiling(
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(64),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
  };
}

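// Pretty-printing a PyScheduleRule delegates to the user-provided f_as_string callback.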
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
      const auto* self = n.as<PyScheduleRuleNode>();
      ICHECK(self);
      PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string;
      ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!";
      p->stream << f_as_string();
    });

TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode);
TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode);

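// FFI registrations exposing the ScheduleRule methods, the PyScheduleRule constructor, and the
// default rule sets under "meta_schedule.*" global names for use from the language bindings.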
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext")
    .set_body_method<ScheduleRule>(&ScheduleRuleNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply")
    .set_body_method<ScheduleRule>(&ScheduleRuleNode::Apply);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone")
    .set_body_method<ScheduleRule>(&ScheduleRuleNode::Clone);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule")
    .set_body_typed(ScheduleRule::PyScheduleRule);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM")
    .set_body_typed(ScheduleRule::DefaultLLVM);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA")
    .set_body_typed(ScheduleRule::DefaultCUDA);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore")
    .set_body_typed(ScheduleRule::DefaultCUDATensorCore);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
    .set_body_typed(ScheduleRule::DefaultHexagon);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
    .set_body_typed(ScheduleRule::DefaultMicro);

}  // namespace meta_schedule
}  // namespace tvm