1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_POSTPROC_H_ |
21 | #define TVM_META_SCHEDULE_POSTPROC_H_ |
22 | |
23 | #include <tvm/node/reflection.h> |
24 | #include <tvm/runtime/object.h> |
25 | #include <tvm/runtime/packed_func.h> |
26 | #include <tvm/tir/schedule/schedule.h> |
27 | |
28 | namespace tvm { |
29 | namespace meta_schedule { |
30 | |
// Forward declarations.
// `Postproc` is the return type of `PostprocNode::Clone`, which is declared before the
// `Postproc` class itself is defined further down in this header.
class Postproc;
// `TuneContext` is only referred to by const reference in this header.
class TuneContext;
33 | |
34 | /*! |
35 | * \brief Rules to apply a postprocessor to a schedule. |
36 | */ |
37 | class PostprocNode : public runtime::Object { |
38 | public: |
39 | /*! \brief Virtual destructor. */ |
40 | virtual ~PostprocNode() = default; |
41 | |
42 | void VisitAttrs(tvm::AttrVisitor* v) {} |
43 | |
44 | /*! |
45 | * \brief Initialize the design space generator with tuning context. |
46 | * \param context The tuning context for initialization. |
47 | * \note This method is supposed to be called only once before every other method. |
48 | */ |
49 | virtual void InitializeWithTuneContext(const TuneContext& context) = 0; |
50 | |
51 | /*! |
52 | * \brief Apply a postprocessor to the given schedule. |
53 | * \param sch The schedule to be post processed. |
54 | * \return Whether the postprocessor was successfully applied. |
55 | */ |
56 | virtual bool Apply(const tir::Schedule& sch) = 0; |
57 | |
58 | /*! |
59 | * \brief Clone the postprocessor. |
60 | * \return The cloned postprocessor. |
61 | */ |
62 | virtual Postproc Clone() const = 0; |
63 | |
64 | static constexpr const char* _type_key = "meta_schedule.Postproc" ; |
65 | TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); |
66 | }; |
67 | |
68 | /*! |
69 | * \brief Managed reference to PostprocNode |
70 | * \sa PostprocNode |
71 | */ |
72 | class Postproc : public runtime::ObjectRef { |
73 | public: |
74 | /*! |
75 | * \brief The function type of `InitializeWithTuneContext` method. |
76 | * \param context The tuning context for initialization. |
77 | */ |
78 | using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>; |
79 | /*! |
80 | * \brief Apply a postprocessor to the given schedule. |
81 | * \param sch The schedule to be post processed. |
82 | * \return Whether the postprocessor was successfully applied. |
83 | */ |
84 | using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>; |
85 | /*! |
86 | * \brief Clone the postprocessor. |
87 | * \return The cloned postprocessor. |
88 | */ |
89 | using FClone = runtime::TypedPackedFunc<Postproc()>; |
90 | /*! |
91 | * \brief Get the postprocessor function as string with name. |
92 | * \return The string of the postprocessor function. |
93 | */ |
94 | using FAsString = runtime::TypedPackedFunc<String()>; |
95 | /*! |
96 | * \brief Create a postprocessor with customized methods on the python-side. |
97 | * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. |
98 | * \param f_apply The packed function of `Apply`. |
99 | * \param f_clone The packed function of `Clone`. |
100 | * \param f_as_string The packed function of `AsString`. |
101 | * \return The postprocessor created. |
102 | */ |
103 | TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, // |
104 | FApply f_apply, // |
105 | FClone f_clone, // |
106 | FAsString f_as_string); |
107 | /*! |
108 | * \brief Create a postprocessor that checks if all loops are static |
109 | * \return The postprocessor created |
110 | */ |
111 | TVM_DLL static Postproc DisallowDynamicLoop(); |
112 | /*! |
113 | * \brief Create a postprocessor that checks if all async mem copies are not strided. |
114 | * \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope. |
115 | * \return The postprocessor created |
116 | */ |
117 | TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true); |
118 | /*! |
119 | * \brief Create a postprocessor that rewrites the cooperative fetch annotation to |
120 | * actual vectorized cooperative fetching in loop bindings. |
121 | * \return The postprocessor created. |
122 | */ |
123 | TVM_DLL static Postproc RewriteCooperativeFetch(); |
124 | /*! |
125 | * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling |
126 | * according to the annotation of each block |
127 | * \return The postprocessor created |
128 | */ |
129 | TVM_DLL static Postproc RewriteParallelVectorizeUnroll(); |
130 | /*! |
131 | * \brief Create a postprocessor that rewrites reduction block by moving the init block out. |
132 | * \return The postprocessor created. |
133 | */ |
134 | TVM_DLL static Postproc RewriteReductionBlock(); |
135 | /*! |
136 | * \brief Create a postprocessor that adds thread binding to unbound blocks |
137 | * \param max_threadblocks The max number of threadblocks in the cuda device. |
138 | * \return The postprocessor created. |
139 | */ |
140 | TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblocks); |
141 | /*! |
142 | * \brief Create a postprocessor that applies tensorization to annotated blocks |
143 | * \param vectorize_init_loop Whether or not vectorize the initialization loop produced by |
144 | * DecomposeReduction |
145 | * \return The postprocessor created. |
146 | */ |
147 | TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false); |
148 | /*! |
149 | * \brief Creates a postprocessor that verifies if the GPU code is correct |
150 | * \return The postprocessor created |
151 | */ |
152 | TVM_DLL static Postproc VerifyGPUCode(); |
153 | /*! |
154 | * \brief Verifies that the VTCM usage of a given schedule is within the provided limit. |
155 | * \return The postprocessor created |
156 | */ |
157 | TVM_DLL static Postproc VerifyVTCMLimit(); |
158 | /*! |
159 | * \brief Creates a postprocessor that rewrites the layout of input tensor |
160 | * \note Weight layout rewrite is supported so far, activation layout rewrite will be added. |
161 | * \return The postprocessor created |
162 | */ |
163 | TVM_DLL static Postproc RewriteLayout(); |
164 | /*! \brief Create default postprocessors for LLVM */ |
165 | TVM_DLL static Array<Postproc, void> DefaultLLVM(); |
166 | /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */ |
167 | TVM_DLL static Array<Postproc, void> DefaultCPUTensorization(); |
168 | /*! \brief Create default postprocessors for CUDA */ |
169 | TVM_DLL static Array<Postproc, void> DefaultCUDA(); |
170 | /*! \brief Create default postprocessors for CUDA with TensorCore */ |
171 | TVM_DLL static Array<Postproc, void> DefaultCUDATensorCore(); |
172 | /*! \brief Create default postprocessors for Hexagon */ |
173 | TVM_DLL static Array<Postproc, void> DefaultHexagon(); |
174 | /*! \brief Create default postprocessors for Micro */ |
175 | TVM_DLL static Array<Postproc, void> DefaultMicro(); |
176 | |
177 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); |
178 | }; |
179 | |
180 | /*! \brief The postprocessor with customized methods on the python-side. */ |
181 | class PyPostprocNode : public PostprocNode { |
182 | public: |
183 | using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext; |
184 | using FApply = Postproc::FApply; |
185 | using FClone = Postproc::FClone; |
186 | using FAsString = Postproc::FAsString; |
187 | /*! \brief The packed function to the `InitializeWithTuneContext` function. */ |
188 | FInitializeWithTuneContext f_initialize_with_tune_context; |
189 | /*! \brief The packed function to the `Apply` function. */ |
190 | FApply f_apply; |
191 | /*! \brief The packed function to the `Clone` function. */ |
192 | FClone f_clone; |
193 | /*! \brief The packed function to the `AsString` function. */ |
194 | FAsString f_as_string; |
195 | |
196 | void VisitAttrs(tvm::AttrVisitor* v) { |
197 | // `f_initialize_with_tune_context` is not visited |
198 | // `f_apply` is not visited |
199 | // `f_clone` is not visited |
200 | // `f_as_string` is not visited |
201 | } |
202 | |
203 | void InitializeWithTuneContext(const TuneContext& context) final; |
204 | bool Apply(const tir::Schedule& sch) final; |
205 | Postproc Clone() const final; |
206 | |
207 | static constexpr const char* _type_key = "meta_schedule.PyPostproc" ; |
208 | TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); |
209 | }; |
210 | |
211 | } // namespace meta_schedule |
212 | } // namespace tvm |
213 | |
214 | #endif // TVM_META_SCHEDULE_POSTPROC_H_ |
215 | |