1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef TVM_META_SCHEDULE_POSTPROC_H_
21#define TVM_META_SCHEDULE_POSTPROC_H_
22
23#include <tvm/node/reflection.h>
24#include <tvm/runtime/object.h>
25#include <tvm/runtime/packed_func.h>
26#include <tvm/tir/schedule/schedule.h>
27
28namespace tvm {
29namespace meta_schedule {
30
// Forward declarations. `PostprocNode::Clone` returns a `Postproc` before that class is defined
// further down in this header, and `TuneContext` is only ever referenced by const reference here.
class Postproc;
class TuneContext;
33
34/*!
35 * \brief Rules to apply a postprocessor to a schedule.
36 */
37class PostprocNode : public runtime::Object {
38 public:
39 /*! \brief Virtual destructor. */
40 virtual ~PostprocNode() = default;
41
42 void VisitAttrs(tvm::AttrVisitor* v) {}
43
44 /*!
45 * \brief Initialize the design space generator with tuning context.
46 * \param context The tuning context for initialization.
47 * \note This method is supposed to be called only once before every other method.
48 */
49 virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
50
51 /*!
52 * \brief Apply a postprocessor to the given schedule.
53 * \param sch The schedule to be post processed.
54 * \return Whether the postprocessor was successfully applied.
55 */
56 virtual bool Apply(const tir::Schedule& sch) = 0;
57
58 /*!
59 * \brief Clone the postprocessor.
60 * \return The cloned postprocessor.
61 */
62 virtual Postproc Clone() const = 0;
63
64 static constexpr const char* _type_key = "meta_schedule.Postproc";
65 TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
66};
67
68/*!
69 * \brief Managed reference to PostprocNode
70 * \sa PostprocNode
71 */
72class Postproc : public runtime::ObjectRef {
73 public:
74 /*!
75 * \brief The function type of `InitializeWithTuneContext` method.
76 * \param context The tuning context for initialization.
77 */
78 using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
79 /*!
80 * \brief Apply a postprocessor to the given schedule.
81 * \param sch The schedule to be post processed.
82 * \return Whether the postprocessor was successfully applied.
83 */
84 using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
85 /*!
86 * \brief Clone the postprocessor.
87 * \return The cloned postprocessor.
88 */
89 using FClone = runtime::TypedPackedFunc<Postproc()>;
90 /*!
91 * \brief Get the postprocessor function as string with name.
92 * \return The string of the postprocessor function.
93 */
94 using FAsString = runtime::TypedPackedFunc<String()>;
95 /*!
96 * \brief Create a postprocessor with customized methods on the python-side.
97 * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
98 * \param f_apply The packed function of `Apply`.
99 * \param f_clone The packed function of `Clone`.
100 * \param f_as_string The packed function of `AsString`.
101 * \return The postprocessor created.
102 */
103 TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, //
104 FApply f_apply, //
105 FClone f_clone, //
106 FAsString f_as_string);
107 /*!
108 * \brief Create a postprocessor that checks if all loops are static
109 * \return The postprocessor created
110 */
111 TVM_DLL static Postproc DisallowDynamicLoop();
112 /*!
113 * \brief Create a postprocessor that checks if all async mem copies are not strided.
114 * \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
115 * \return The postprocessor created
116 */
117 TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
118 /*!
119 * \brief Create a postprocessor that rewrites the cooperative fetch annotation to
120 * actual vectorized cooperative fetching in loop bindings.
121 * \return The postprocessor created.
122 */
123 TVM_DLL static Postproc RewriteCooperativeFetch();
124 /*!
125 * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling
126 * according to the annotation of each block
127 * \return The postprocessor created
128 */
129 TVM_DLL static Postproc RewriteParallelVectorizeUnroll();
130 /*!
131 * \brief Create a postprocessor that rewrites reduction block by moving the init block out.
132 * \return The postprocessor created.
133 */
134 TVM_DLL static Postproc RewriteReductionBlock();
135 /*!
136 * \brief Create a postprocessor that adds thread binding to unbound blocks
137 * \param max_threadblocks The max number of threadblocks in the cuda device.
138 * \return The postprocessor created.
139 */
140 TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblocks);
141 /*!
142 * \brief Create a postprocessor that applies tensorization to annotated blocks
143 * \param vectorize_init_loop Whether or not vectorize the initialization loop produced by
144 * DecomposeReduction
145 * \return The postprocessor created.
146 */
147 TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);
148 /*!
149 * \brief Creates a postprocessor that verifies if the GPU code is correct
150 * \return The postprocessor created
151 */
152 TVM_DLL static Postproc VerifyGPUCode();
153 /*!
154 * \brief Verifies that the VTCM usage of a given schedule is within the provided limit.
155 * \return The postprocessor created
156 */
157 TVM_DLL static Postproc VerifyVTCMLimit();
158 /*!
159 * \brief Creates a postprocessor that rewrites the layout of input tensor
160 * \note Weight layout rewrite is supported so far, activation layout rewrite will be added.
161 * \return The postprocessor created
162 */
163 TVM_DLL static Postproc RewriteLayout();
164 /*! \brief Create default postprocessors for LLVM */
165 TVM_DLL static Array<Postproc, void> DefaultLLVM();
166 /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
167 TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
168 /*! \brief Create default postprocessors for CUDA */
169 TVM_DLL static Array<Postproc, void> DefaultCUDA();
170 /*! \brief Create default postprocessors for CUDA with TensorCore */
171 TVM_DLL static Array<Postproc, void> DefaultCUDATensorCore();
172 /*! \brief Create default postprocessors for Hexagon */
173 TVM_DLL static Array<Postproc, void> DefaultHexagon();
174 /*! \brief Create default postprocessors for Micro */
175 TVM_DLL static Array<Postproc, void> DefaultMicro();
176
177 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
178};
179
180/*! \brief The postprocessor with customized methods on the python-side. */
181class PyPostprocNode : public PostprocNode {
182 public:
183 using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
184 using FApply = Postproc::FApply;
185 using FClone = Postproc::FClone;
186 using FAsString = Postproc::FAsString;
187 /*! \brief The packed function to the `InitializeWithTuneContext` function. */
188 FInitializeWithTuneContext f_initialize_with_tune_context;
189 /*! \brief The packed function to the `Apply` function. */
190 FApply f_apply;
191 /*! \brief The packed function to the `Clone` function. */
192 FClone f_clone;
193 /*! \brief The packed function to the `AsString` function. */
194 FAsString f_as_string;
195
196 void VisitAttrs(tvm::AttrVisitor* v) {
197 // `f_initialize_with_tune_context` is not visited
198 // `f_apply` is not visited
199 // `f_clone` is not visited
200 // `f_as_string` is not visited
201 }
202
203 void InitializeWithTuneContext(const TuneContext& context) final;
204 bool Apply(const tir::Schedule& sch) final;
205 Postproc Clone() const final;
206
207 static constexpr const char* _type_key = "meta_schedule.PyPostproc";
208 TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
209};
210
211} // namespace meta_schedule
212} // namespace tvm
213
214#endif // TVM_META_SCHEDULE_POSTPROC_H_
215