/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24void PyPostprocNode::InitializeWithTuneContext(const TuneContext& context) {
25 ICHECK(f_initialize_with_tune_context != nullptr)
26 << "PyPostproc's InitializeWithTuneContext method not implemented!";
27 f_initialize_with_tune_context(context);
28}
29
30bool PyPostprocNode::Apply(const tir::Schedule& sch) {
31 ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
32 return f_apply(sch);
33}
34
35Postproc PyPostprocNode::Clone() const {
36 ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!";
37 return f_clone();
38}
39
40Postproc Postproc::PyPostproc(
41 PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
42 PyPostprocNode::FApply f_apply, //
43 PyPostprocNode::FClone f_clone, //
44 PyPostprocNode::FAsString f_as_string) {
45 ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
46 n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
47 n->f_apply = std::move(f_apply);
48 n->f_clone = std::move(f_clone);
49 n->f_as_string = std::move(f_as_string);
50 return Postproc(n);
51}
52
53Array<Postproc> Postproc::DefaultLLVM() {
54 return Array<Postproc>{
55 Postproc::DisallowDynamicLoop(),
56 Postproc::RewriteParallelVectorizeUnroll(),
57 Postproc::RewriteReductionBlock(),
58 Postproc::RewriteLayout(),
59 };
60}
61
62Array<Postproc> Postproc::DefaultCPUTensorization() {
63 return Array<Postproc>{
64 Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
65 Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
66 Postproc::RewriteLayout(),
67 };
68}
69
70Array<Postproc> Postproc::DefaultCUDA() {
71 return Array<Postproc>{
72 Postproc::DisallowDynamicLoop(),
73 Postproc::RewriteCooperativeFetch(),
74 Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256),
75 Postproc::RewriteParallelVectorizeUnroll(),
76 Postproc::RewriteReductionBlock(),
77 Postproc::VerifyGPUCode(),
78 };
79}
80
81Array<Postproc> Postproc::DefaultCUDATensorCore() {
82 return Array<Postproc>{
83 Postproc::DisallowDynamicLoop(),
84 Postproc::RewriteCooperativeFetch(),
85 Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256),
86 Postproc::RewriteParallelVectorizeUnroll(),
87 Postproc::RewriteReductionBlock(),
88 Postproc::VerifyGPUCode(),
89 // RewriteTensorize is relatively expensive and it doesn't affect the validity of a sample, so
90 // run it only on samples that have passed VerifyGPUCode.
91 Postproc::RewriteTensorize(/*vectorize_init_loop=*/false),
92 };
93}
94
95Array<Postproc> Postproc::DefaultHexagon() {
96 return Array<Postproc>{
97 Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
98 Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(),
99 Postproc::VerifyVTCMLimit(),
100 };
101}
102
103Array<Postproc> Postproc::DefaultMicro() {
104 return Array<Postproc>{
105 Postproc::DisallowDynamicLoop(),
106 Postproc::RewriteParallelVectorizeUnroll(),
107 Postproc::RewriteReductionBlock(),
108 };
109}
110
111TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
112 .set_dispatch<PyPostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
113 const auto* self = n.as<PyPostprocNode>();
114 ICHECK(self);
115 PyPostprocNode::FAsString f_as_string = (*self).f_as_string;
116 ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!";
117 p->stream << f_as_string();
118 });
119
120TVM_REGISTER_OBJECT_TYPE(PostprocNode);
121TVM_REGISTER_NODE_TYPE(PyPostprocNode);
122
123TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
124 .set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
125TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
126TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method<Postproc>(&PostprocNode::Clone);
127TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);
128TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM);
129TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA);
130TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore")
131 .set_body_typed(Postproc::DefaultCUDATensorCore);
132TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon")
133 .set_body_typed(Postproc::DefaultHexagon);
134
135} // namespace meta_schedule
136} // namespace tvm
137