1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { |
25 | ICHECK(f_initialize_with_tune_context != nullptr) |
26 | << "PyScheduleRule's InitializeWithTuneContext method not implemented!" ; |
27 | f_initialize_with_tune_context(context); |
28 | } |
29 | |
30 | Array<tir::Schedule> PyScheduleRuleNode::Apply(const tir::Schedule& sch, |
31 | const tir::BlockRV& block) { |
32 | ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!" ; |
33 | return f_apply(sch, block); |
34 | } |
35 | |
36 | ScheduleRule PyScheduleRuleNode::Clone() const { |
37 | ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!" ; |
38 | return f_clone(); |
39 | } |
40 | |
41 | ScheduleRule ScheduleRule::PyScheduleRule( |
42 | PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // |
43 | PyScheduleRuleNode::FApply f_apply, // |
44 | PyScheduleRuleNode::FClone f_clone, // |
45 | PyScheduleRuleNode::FAsString f_as_string) { |
46 | ObjectPtr<PyScheduleRuleNode> n = make_object<PyScheduleRuleNode>(); |
47 | n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); |
48 | n->f_apply = std::move(f_apply); |
49 | n->f_clone = std::move(f_clone); |
50 | n->f_as_string = std::move(f_as_string); |
51 | return ScheduleRule(n); |
52 | } |
53 | |
54 | Array<ScheduleRule> ScheduleRule::DefaultLLVM() { |
55 | return { |
56 | ScheduleRule::ApplyCustomRule(), |
57 | ScheduleRule::InlineConstantScalars(), |
58 | ScheduleRule::AutoInline( |
59 | /*into_producer=*/false, |
60 | /*into_consumer=*/true, |
61 | /*inline_const_tensor=*/true, |
62 | /*disallow_if_then_else=*/true, |
63 | /*require_injective=*/true, |
64 | /*require_ordered=*/true, |
65 | /*disallow_op=*/Array<String>{"tir.exp" }), |
66 | ScheduleRule::AddRFactor( |
67 | /*max_jobs_per_core=*/16, |
68 | /*max_innermost_factor=*/Integer(64)), |
69 | ScheduleRule::MultiLevelTiling( |
70 | /*structure=*/"SSRSRS" , |
71 | /*tile_binds=*/NullOpt, |
72 | /*max_innermost_factor=*/Integer(64), |
73 | /*vector_load_lens=*/NullOpt, |
74 | /*reuse_read=*/NullOpt, |
75 | /*reuse_write=*/ |
76 | Map<String, ObjectRef>{{"req" , String("may" )}, |
77 | {"levels" , Array<Integer>{1, 2}}, |
78 | {"scope" , String("global" )}}), |
79 | ScheduleRule::ParallelizeVectorizeUnroll( |
80 | /*max_jobs_per_core=*/16, |
81 | /*max_vectorize_extent=*/64, |
82 | /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512}, |
83 | /*unroll_explicit=*/true), |
84 | ScheduleRule::RandomComputeLocation(), |
85 | }; |
86 | } |
87 | |
88 | Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) { |
89 | static const Map<String, String> intrins = {{"vnni" , "dot_16x4_vnni" }, |
90 | {"avx512" , "dot_16x4_avx512" }}; |
91 | return { |
92 | ScheduleRule::ApplyCustomRule(), |
93 | ScheduleRule::InlineConstantScalars(), |
94 | ScheduleRule::AutoInline( |
95 | /*into_producer=*/false, |
96 | /*into_consumer=*/true, |
97 | /*inline_const_tensor=*/true, |
98 | /*disallow_if_then_else=*/true, |
99 | /*require_injective=*/true, |
100 | /*require_ordered=*/true, |
101 | /*disallow_op=*/Array<String>{"tir.exp" }), |
102 | ScheduleRule::AddRFactor( |
103 | /*max_jobs_per_core=*/16, |
104 | /*max_innermost_factor=*/Integer(64)), |
105 | ScheduleRule::MultiLevelTilingWithIntrin( |
106 | /*intrin_name=*/intrins[type], |
107 | /*structure=*/"SSRSRS" , |
108 | /*tile_binds=*/NullOpt, |
109 | /*max_innermost_factor=*/Integer(64), |
110 | /*vector_load_lens=*/NullOpt, |
111 | /*reuse_read=*/NullOpt, |
112 | /*reuse_write=*/ |
113 | Map<String, ObjectRef>{{"req" , String("may" )}, |
114 | {"levels" , Array<Integer>{1, 2}}, |
115 | {"scope" , String("global" )}}), |
116 | ScheduleRule::MultiLevelTiling( |
117 | /*structure=*/"SSRSRS" , |
118 | /*tile_binds=*/NullOpt, |
119 | /*max_innermost_factor=*/Integer(64), |
120 | /*vector_load_lens=*/NullOpt, |
121 | /*reuse_read=*/NullOpt, |
122 | /*reuse_write=*/ |
123 | Map<String, ObjectRef>{{"req" , String("may" )}, |
124 | {"levels" , Array<Integer>{1, 2}}, |
125 | {"scope" , String("global" )}}), |
126 | ScheduleRule::ParallelizeVectorizeUnroll( |
127 | /*max_jobs_per_core=*/16, |
128 | /*max_vectorize_extent=*/64, |
129 | /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512}, |
130 | /*unroll_explicit=*/true), |
131 | ScheduleRule::RandomComputeLocation(), |
132 | }; |
133 | } |
134 | |
135 | Array<ScheduleRule> ScheduleRule::DefaultCUDA() { |
136 | return { |
137 | ScheduleRule::ApplyCustomRule(), |
138 | ScheduleRule::MultiLevelTiling( |
139 | /*structure=*/"SSSRRSRS" , |
140 | /*tile_binds=*/Array<String>{"blockIdx.x" , "vthread.x" , "threadIdx.x" }, |
141 | /*max_innermost_factor=*/Integer(64), |
142 | /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16}, |
143 | /*reuse_read=*/ |
144 | Map<String, ObjectRef>{{"req" , String("must" )}, |
145 | {"levels" , Array<Integer>{4}}, // |
146 | {"scope" , String("shared" )}}, |
147 | /*reuse_write=*/ |
148 | Map<String, ObjectRef>{{"req" , String("must" )}, |
149 | {"levels" , Array<Integer>{3}}, // |
150 | {"scope" , String("local" )}}), |
151 | ScheduleRule::InlineConstantScalars(), |
152 | ScheduleRule::AutoInline( |
153 | /*into_producer=*/true, |
154 | /*into_consumer=*/true, |
155 | /*inline_const_tensor=*/true, |
156 | /*disallow_if_then_else=*/false, |
157 | /*require_injective=*/false, |
158 | /*require_ordered=*/false, |
159 | /*disallow_op=*/Array<String>{}), |
160 | ScheduleRule::CrossThreadReduction( |
161 | /*thread_extents=*/Array<Integer>{4, 8, 16, 32, 64, 128, 256, 512}), |
162 | ScheduleRule::ParallelizeVectorizeUnroll( |
163 | /*max_jobs_per_core=*/-1, |
164 | /*max_vectorize_extent=*/-1, |
165 | /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512, 1024}, |
166 | /*unroll_explicit=*/true), |
167 | ScheduleRule::AutoBind( |
168 | /*max_threadblocks=*/256, |
169 | /*thread_extents*/ Array<Integer>{32, 64, 128, 256, 512, 1024}), |
170 | }; |
171 | } |
172 | |
Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
  // Candidate WMMA intrinsic groups, keyed by pipeline stage
  // (init / load_a / load_b / compute / store). Each group targets one dtype
  // combination; "_trans" variants consume a transposed B operand. The tiling
  // rule picks whichever group matches the block being scheduled.
  Array<Map<String, String>> intrin_groups = {
      // Tensor Cores f32 += f16 * f16
      {
          {"init", "wmma_fill_16x16x16_f32"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f32"},
          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_f32"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f32_trans"},
          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
      },
      // Tensor Cores f16 += f16 * f16
      {
          {"init", "wmma_fill_16x16x16_f16"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f16"},
          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_f16"},
          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
      },
      // Tensor Cores s32 += s8 * s8
      {
          {"init", "wmma_fill_16x16x16_s32"},
          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_s8s8s32"},
          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
      },
      {
          {"init", "wmma_fill_16x16x16_s32"},
          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
          {"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"},
          {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
      },
  };
  Array<ScheduleRule> results{
      ScheduleRule::ApplyCustomRule(),
      // Tensor-core tiling goes first so it gets a chance before the generic
      // CUDA rules appended below.
      ScheduleRule::MultiLevelTilingTensorCore(
          /*intrin_groups=*/intrin_groups,
          /*structure=*/"SSSRRSRS",
          /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
          /*max_innermost_factor=*/Integer(4),
          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
          /*reuse_read=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{4}},  //
                                 {"scope", String("shared.dyn")}},
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("must")},
                                 {"levels", Array<Integer>{2}},  //
                                 {"scope", String("shared.dyn")}},
          /*use_software_pipeline=*/false)  //
  };
  Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
  // Reuse the generic CUDA rules, skipping append[0] (ApplyCustomRule), which
  // is already the first entry of `results`.
  results.insert(results.end(), append.begin() + 1, append.end());
  return results;
}
243 | |
244 | Array<ScheduleRule> ScheduleRule::DefaultHexagon() { |
245 | return { |
246 | ScheduleRule::ApplyCustomRule(), |
247 | ScheduleRule::InlineConstantScalars(), |
248 | ScheduleRule::AutoInline( |
249 | /*into_producer=*/false, |
250 | /*into_consumer=*/true, |
251 | /*inline_const_tensor=*/true, |
252 | /*disallow_if_then_else=*/true, |
253 | /*require_injective=*/true, |
254 | /*require_ordered=*/true, |
255 | /*disallow_op=*/Array<String>{"tir.exp" }), |
256 | ScheduleRule::MultiLevelTilingWideVector( |
257 | /*structure=*/"SRSRS" , |
258 | /*vector_length_in_bits=*/1024, |
259 | /*max_innermost_factor=*/Integer(128), |
260 | /*reuse_read=*/NullOpt, |
261 | /*reuse_write=*/ |
262 | Map<String, ObjectRef>{{"req" , String("may" )}, |
263 | {"levels" , Array<Integer>{1, 2}}, |
264 | {"scope" , String("global" )}}), |
265 | ScheduleRule::ParallelizeVectorizeUnroll( |
266 | /*max_jobs_per_core=*/16, |
267 | /*max_vectorize_extent=*/128, |
268 | /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512}, |
269 | /*unroll_explicit=*/true), |
270 | }; |
271 | } |
272 | |
273 | Array<ScheduleRule> ScheduleRule::DefaultMicro() { |
274 | return { |
275 | ScheduleRule::ApplyCustomRule(), |
276 | ScheduleRule::InlineConstantScalars(), |
277 | ScheduleRule::AutoInline( |
278 | /*into_producer=*/false, |
279 | /*into_consumer=*/true, |
280 | /*inline_const_tensor=*/true, |
281 | /*disallow_if_then_else=*/true, |
282 | /*require_injective=*/true, |
283 | /*require_ordered=*/true, |
284 | /*disallow_op=*/Array<String>{"tir.exp" }), |
285 | ScheduleRule::MultiLevelTiling( |
286 | /*structure=*/"SSRSRS" , |
287 | /*tile_binds=*/NullOpt, |
288 | /*max_innermost_factor=*/Integer(64), |
289 | /*vector_load_lens=*/NullOpt, |
290 | /*reuse_read=*/NullOpt, |
291 | /*reuse_write=*/ |
292 | Map<String, ObjectRef>{{"req" , String("may" )}, |
293 | {"levels" , Array<Integer>{1, 2}}, |
294 | {"scope" , String("global" )}}), |
295 | }; |
296 | } |
297 | |
298 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
299 | .set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) { |
300 | const auto* self = n.as<PyScheduleRuleNode>(); |
301 | ICHECK(self); |
302 | PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; |
303 | ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!" ; |
304 | p->stream << f_as_string(); |
305 | }); |
306 | |
307 | TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); |
308 | TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); |
309 | |
310 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext" ) |
311 | .set_body_method<ScheduleRule>(&ScheduleRuleNode::InitializeWithTuneContext); |
312 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply" ) |
313 | .set_body_method<ScheduleRule>(&ScheduleRuleNode::Apply); |
314 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone" ) |
315 | .set_body_method<ScheduleRule>(&ScheduleRuleNode::Clone); |
316 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule" ) |
317 | .set_body_typed(ScheduleRule::PyScheduleRule); |
318 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM" ) |
319 | .set_body_typed(ScheduleRule::DefaultLLVM); |
320 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA" ) |
321 | .set_body_typed(ScheduleRule::DefaultCUDA); |
322 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore" ) |
323 | .set_body_typed(ScheduleRule::DefaultCUDATensorCore); |
324 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon" ) |
325 | .set_body_typed(ScheduleRule::DefaultHexagon); |
326 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro" ) |
327 | .set_body_typed(ScheduleRule::DefaultMicro); |
328 | |
329 | } // namespace meta_schedule |
330 | } // namespace tvm |
331 | |