1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/sketch_policy_rules.h |
22 | * \brief Rules for generating the sketches, sampling the initial population, and mutating the |
23 | * population in SketchPolicy. |
24 | */ |
25 | |
26 | #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |
27 | #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |
28 | |
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_task.h>

#include <random>
#include <string>
#include <utility>
#include <vector>

#include "utils.h"
37 | |
38 | namespace tvm { |
39 | namespace auto_scheduler { |
40 | |
// Forward declaration: every rule in this header takes SketchPolicyNode only by reference or by
// pointer, so its full definition is not needed here.
class SketchPolicyNode;
42 | |
43 | /********** Sketch Generation Rule **********/ |
44 | |
45 | /*! \brief The base class for derivation rules used in the sketch generation. */ |
46 | class SketchGenerationRule { |
47 | public: |
48 | /*! \brief Result enumeration of the condition function. */ |
49 | enum class ConditionKind : int { |
50 | /*! \brief Skip this rule and continue to try the next rules. */ |
51 | kSkip = 0, |
52 | /*! \brief Apply this rule and continue to try the next rules. */ |
53 | kApply = 1, |
54 | /*! \brief Apply this rule and skip the rest rules. */ |
55 | kApplyAndSkipRest = 2 |
56 | }; |
57 | |
58 | /*! |
59 | * \brief Condition check function of this rule. |
60 | * \param policy The SketchPolicyNode of this rule, some information may be used during |
61 | * the condition checking. |
62 | * \param state The original state to be checked. |
63 | * \param stage_id The index of the stage to process this condition check. |
64 | * \return The condition check result of this rule. |
65 | */ |
66 | virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, |
67 | int stage_id) const = 0; |
68 | |
69 | /*! |
70 | * \brief Apply function of this rule. |
71 | * \param policy The SketchPolicyNode of this rule, some information may be used during |
72 | * the rule applying. |
73 | * \param state The original state to apply this rule. |
74 | * \param stage_id The index of the next stage to apply this rule. |
75 | * \return The state after applying this rule, and index of the next stage. |
76 | */ |
77 | virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, |
78 | const State& state, int stage_id) const = 0; |
79 | |
80 | /*! |
81 | * \brief Get the name of this rule. |
82 | * \return A string of the rule name. |
83 | */ |
84 | virtual std::string GetRuleName() const = 0; |
85 | }; |
86 | |
// A helper to define sketch generation rules. The generated class reports its own class name as
// the rule name; MeetCondition and Apply are only declared here and must be defined elsewhere.
#define DEFINE_SKETCH_GENERATION_RULE(RuleClass)                                                 \
  class RuleClass : public SketchGenerationRule {                                                \
   public:                                                                                       \
    std::string GetRuleName() const final { return #RuleClass; }                                 \
    ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,              \
                                int stage_id) const final;                                       \
    std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, \
                                             int stage_id) const final;                          \
  };
96 | |
/*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves
 * to the next stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage);

/*! \brief The rule that inlines simple elementwise ops.
 * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
 * inlineable will have a chance to try different compute-at locations in InitPopulation later.
 */
DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline);

/*! \brief The rule that performs multi-level tiling. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling);

/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion);

/*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching.
 * Currently only supports one-to-one matched cache reads. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead);

/*! \brief The rule that adds a cache write stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite);

/*! \brief The rule that adds an rfactor stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor);

/*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
 * This kind of op comes from the Winograd transformation. */
DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor);

/*! \brief The rule that uses cross-thread reduction for GPU. */
DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);

/*! \brief The rule that handles special cases in the Winograd transformation for GPU: it changes
 * the compute location of the producers of compute ops that perform "fake reduction" with const
 * tensors. */
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
133 | |
134 | /*! \brief The rule that allows users to generate custom sketches. */ |
135 | class RuleCustomSketch : public SketchGenerationRule { |
136 | public: |
137 | RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func, |
138 | String rule_name = "CustomSketchRule" ) |
139 | : meet_condition_func_(std::move(meet_condition_func)), |
140 | apply_func_(std::move(apply_func)), |
141 | rule_name_(std::move(rule_name)) {} |
142 | |
143 | ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, |
144 | int stage_id) const final; |
145 | |
146 | std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, |
147 | int stage_id) const final; |
148 | |
149 | std::string GetRuleName() const final { return rule_name_; } |
150 | |
151 | private: |
152 | PackedFunc meet_condition_func_; |
153 | PackedFunc apply_func_; |
154 | String rule_name_; |
155 | }; |
156 | |
157 | /********** Init Population **********/ |
158 | |
159 | /*! \brief The base class for rules used to annotate the sketches to get the initial population. */ |
160 | class PopulationGenerationRule { |
161 | public: |
162 | /*! \brief Result enumeration of the apply function. */ |
163 | enum class ResultKind : int { kValid = 0, kInvalid = 1 }; |
164 | |
165 | /*! |
166 | * \brief Apply function of this rule. |
167 | * \param policy The SketchPolicyNode of this rule, some member may get changed during the |
168 | * rule applying. (e.g. random number generator) |
169 | * \param state The state to apply this rule, update inplace. |
170 | * \return The result of this rule, indicate if there's any valid state generated. |
171 | */ |
172 | virtual ResultKind Apply(SketchPolicyNode* policy, State* state, |
173 | std::mt19937* rand_gen) const = 0; |
174 | |
175 | /*! \brief The deconstructor */ |
176 | virtual ~PopulationGenerationRule() = default; |
177 | }; |
178 | |
// A helper to define population initialization rules. Only the Apply override is declared here;
// its definition must be provided elsewhere.
#define DEFINE_INIT_POPULATION_RULE(RuleClass)              \
  class RuleClass : public PopulationGenerationRule {       \
   public:                                                  \
    ResultKind Apply(SketchPolicyNode* policy, State* state, \
                     std::mt19937* rand_gen) const final;   \
  };
185 | |
/*! \brief The rule that fills the incomplete SplitSteps. */
DEFINE_INIT_POPULATION_RULE(InitFillTileSize);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);

/*! \brief The rule that adds parallel annotations for CPU. */
DEFINE_INIT_POPULATION_RULE(InitParallel);

/*! \brief The rule that adds unroll annotations. */
DEFINE_INIT_POPULATION_RULE(InitUnroll);

/*! \brief The rule that adds vectorization annotations. */
DEFINE_INIT_POPULATION_RULE(InitVectorization);

/*! \brief The rule that adds thread binding annotations for GPU. */
DEFINE_INIT_POPULATION_RULE(InitThreadBind);
204 | |
205 | /********** Mutation **********/ |
206 | |
207 | /*! \brief The base class for mutation rules used in the evolutionary search. */ |
208 | class PopulationMutationRule : public PopulationGenerationRule { |
209 | public: |
210 | /* \brief The constructor |
211 | * \param selection_weight the probabiliy of applying this rule is |
212 | * proportional to this weight |
213 | */ |
214 | explicit PopulationMutationRule(double selection_weight) : weight(selection_weight) {} |
215 | |
216 | /* \brief The weight of this rule */ |
217 | double weight; |
218 | }; |
219 | |
// A helper to define mutation rules used in the evolutionary search. The generated class forwards
// its selection weight to PopulationMutationRule and declares the Apply override, whose
// definition must be provided elsewhere.
#define DEFINE_MUTATE_POPULATION_RULE(RuleClass)                                          \
  class RuleClass : public PopulationMutationRule {                                       \
   public:                                                                                \
    explicit RuleClass(double selection_weight) : PopulationMutationRule(selection_weight) {} \
    ResultKind Apply(SketchPolicyNode* policy, State* state,                              \
                     std::mt19937* rand_gen) const final;                                 \
  };
227 | |
/*! \brief The rule that mutates tile sizes by randomly dividing one tile size by a factor
 * and multiplying another tile size by that factor. */
DEFINE_MUTATE_POPULATION_RULE(MutateTileSize);

/*! \brief The rule that mutates the number of fused outer iterators annotated by parallel. */
DEFINE_MUTATE_POPULATION_RULE(MutateParallel);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation);

/*! \brief The rule that mutates the value of a randomly selected auto unroll pragma step. */
DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll);
241 | |
242 | } // namespace auto_scheduler |
243 | } // namespace tvm |
244 | |
245 | #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |
246 | |