1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_MUTATOR_H_ |
21 | #define TVM_META_SCHEDULE_MUTATOR_H_ |
22 | |
23 | #include <tvm/node/reflection.h> |
24 | #include <tvm/runtime/container/optional.h> |
25 | #include <tvm/runtime/object.h> |
26 | #include <tvm/runtime/packed_func.h> |
27 | #include <tvm/support/random_engine.h> |
28 | #include <tvm/tir/schedule/schedule.h> |
29 | #include <tvm/tir/schedule/trace.h> |
30 | |
31 | namespace tvm { |
32 | namespace meta_schedule { |
33 | |
// Forward declarations: TuneContext is only used through `const TuneContext&` in this header,
// and Mutator is needed as the return type of MutatorNode::Clone before Mutator is defined below.
class TuneContext;
class Mutator;
36 | |
37 | /*! \brief Mutator is designed to mutate the trace to explore the design space. */ |
38 | class MutatorNode : public runtime::Object { |
39 | public: |
40 | /*! \brief Virtual destructor. */ |
41 | virtual ~MutatorNode() = default; |
42 | |
43 | void VisitAttrs(tvm::AttrVisitor* v) {} |
44 | |
45 | /*! |
46 | * \brief Initialize the design space generator with tuning context. |
47 | * \param context The tuning context for initialization. |
48 | * \note This method is supposed to be called only once before every other method. |
49 | */ |
50 | virtual void InitializeWithTuneContext(const TuneContext& context) = 0; |
51 | |
52 | /*! |
53 | * \brief Apply the mutator function to the given trace. |
54 | * \param trace The given trace for mutation. |
55 | * \param rand_state The random state for mutation. |
56 | * \return None if mutator failed, otherwise return the mutated trace. |
57 | */ |
58 | virtual Optional<tir::Trace> Apply(const tir::Trace& trace, |
59 | support::LinearCongruentialEngine::TRandState* rand_state) = 0; |
60 | |
61 | /*! |
62 | * \brief Clone the mutator. |
63 | * \return The cloned mutator. |
64 | */ |
65 | virtual Mutator Clone() const = 0; |
66 | |
67 | static constexpr const char* _type_key = "meta_schedule.Mutator" ; |
68 | TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); |
69 | }; |
70 | |
71 | /*! |
72 | * \brief Managed reference to MutatorNode |
73 | * \sa MutatorNode |
74 | */ |
75 | class Mutator : public runtime::ObjectRef { |
76 | public: |
77 | /*! |
78 | * \brief The function type of `InitializeWithTuneContext` method. |
79 | * \param context The tuning context for initialization. |
80 | */ |
81 | using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>; |
82 | /*! |
83 | * \brief Apply the mutator function to the given trace. |
84 | * \param trace The given trace for mutation. |
85 | * \return None if mutator failed, otherwise return the mutated trace. |
86 | */ |
87 | using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>( |
88 | const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; |
89 | /*! |
90 | * \brief Clone the mutator. |
91 | * \return The cloned mutator. |
92 | */ |
93 | using FClone = runtime::TypedPackedFunc<Mutator()>; |
94 | /*! |
95 | * \brief Get the mutator as string with name. |
96 | * \return The string of the mutator. |
97 | */ |
98 | using FAsString = runtime::TypedPackedFunc<String()>; |
99 | /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ |
100 | TVM_DLL static Mutator MutateTileSize(); |
101 | /*! |
102 | * \brief Create a Mutator that mutates the parallel extent |
103 | * \param max_jobs_per_core The maximum number of parallel jobs per core. |
104 | * \return The created mutator. |
105 | */ |
106 | TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core); |
107 | /*! |
108 | * \brief Create a Mutator that mutates auto unroll step |
109 | * \return The mutator created |
110 | */ |
111 | TVM_DLL static Mutator MutateUnroll(); |
112 | /*! |
113 | * \brief Create a Mutator that mutates the outcome of SampleComputeLocation |
114 | * \return The mutator created |
115 | */ |
116 | TVM_DLL static Mutator MutateComputeLocation(); |
117 | /*! |
118 | * \brief Create a Mutator that mutates auto thread binding. |
119 | * \return The mutator created |
120 | */ |
121 | TVM_DLL static Mutator MutateThreadBinding(); |
122 | /*! |
123 | * \brief Create a mutator with customized methods on the python-side. |
124 | * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. |
125 | * \param f_apply The packed function of `Apply`. |
126 | * \param f_clone The packed function of `Clone`. |
127 | * \param f_as_string The packed function of `AsString`. |
128 | * \return The mutator created. |
129 | */ |
130 | TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, |
131 | FApply f_apply, FClone f_clone, FAsString f_as_string); |
132 | /*! \brief Create default mutators for LLVM */ |
133 | TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM(); |
134 | /*! \brief Create default mutators for CUDA */ |
135 | TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA(); |
136 | /*! \brief Create default mutators for CUDA with TensorCore */ |
137 | TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDATensorCore(); |
138 | /*! \brief Create default mutators for Hexagon */ |
139 | TVM_DLL static Map<Mutator, FloatImm, void> DefaultHexagon(); |
140 | /*! \brief Create default mutators for Micro */ |
141 | TVM_DLL static Map<Mutator, FloatImm, void> DefaultMicro(); |
142 | |
143 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); |
144 | }; |
145 | |
/*!
 * \brief The mutator with customized methods on the python-side.
 * \note Each virtual method of MutatorNode is backed by one of the packed-function fields below.
 *       NOTE(review): the forwarding itself lives in the out-of-line definitions, which are not
 *       visible in this header — confirm there.
 */
class PyMutatorNode : public MutatorNode {
 public:
  using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
  using FApply = Mutator::FApply;
  using FClone = Mutator::FClone;
  using FAsString = Mutator::FAsString;
  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
  FInitializeWithTuneContext f_initialize_with_tune_context;
  /*! \brief The packed function to the `Apply` function. */
  FApply f_apply;
  /*! \brief The packed function to the `Clone` function. */
  FClone f_clone;
  /*! \brief The packed function to the `AsString` function. */
  FAsString f_as_string;

  /*!
   * \brief Reflection hook. None of the packed-function fields are exposed.
   * \param v The attribute visitor (unused).
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    // NOTE(review): presumably the packed functions are skipped because they are
    // python-side callbacks rather than serializable attributes; confirm.
    // `f_initialize_with_tune_context` is not visited
    // `f_apply` is not visited
    // `f_clone` is not visited
    // `f_as_string` is not visited
  }

  /*!
   * \brief Initialize the mutator with tuning context (python-side override).
   * \param context The tuning context for initialization.
   * \note Presumably delegates to `f_initialize_with_tune_context` — definition not visible here.
   */
  void InitializeWithTuneContext(const TuneContext& context) final;
  /*!
   * \brief Apply the mutator to the given trace (python-side override).
   * \param trace The given trace for mutation.
   * \param rand_state The random state for mutation. `FApply` takes the state by value, so the
   *        implementation presumably dereferences this pointer — confirm in the definition.
   * \return None if mutator failed, otherwise the mutated trace.
   */
  Optional<tir::Trace> Apply(const tir::Trace& trace,
                             support::LinearCongruentialEngine::TRandState* rand_state) final;
  /*!
   * \brief Clone the mutator (python-side override).
   * \return The cloned mutator.
   * \note Presumably delegates to `f_clone` — definition not visible here.
   */
  Mutator Clone() const final;

  /*! \brief The type key registered for this node type in the TVM object system. */
  static constexpr const char* _type_key = "meta_schedule.PyMutator" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};
177 | |
178 | } // namespace meta_schedule |
179 | } // namespace tvm |
180 | |
181 | #endif // TVM_META_SCHEDULE_MUTATOR_H_ |
182 | |