1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_policy/sketch_policy_rules.h
22 * \brief Rules for generating the sketches, sampling the initial population, and mutating the
23 * population in SketchPolicy.
24 */
25
26#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
27#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
28
29#include <tvm/auto_scheduler/loop_state.h>
30#include <tvm/auto_scheduler/search_task.h>
31
#include <random>
#include <string>
#include <utility>
#include <vector>
35
36#include "utils.h"
37
38namespace tvm {
39namespace auto_scheduler {
40
41class SketchPolicyNode;
42
43/********** Sketch Generation Rule **********/
44
45/*! \brief The base class for derivation rules used in the sketch generation. */
46class SketchGenerationRule {
47 public:
48 /*! \brief Result enumeration of the condition function. */
49 enum class ConditionKind : int {
50 /*! \brief Skip this rule and continue to try the next rules. */
51 kSkip = 0,
52 /*! \brief Apply this rule and continue to try the next rules. */
53 kApply = 1,
54 /*! \brief Apply this rule and skip the rest rules. */
55 kApplyAndSkipRest = 2
56 };
57
58 /*!
59 * \brief Condition check function of this rule.
60 * \param policy The SketchPolicyNode of this rule, some information may be used during
61 * the condition checking.
62 * \param state The original state to be checked.
63 * \param stage_id The index of the stage to process this condition check.
64 * \return The condition check result of this rule.
65 */
66 virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
67 int stage_id) const = 0;
68
69 /*!
70 * \brief Apply function of this rule.
71 * \param policy The SketchPolicyNode of this rule, some information may be used during
72 * the rule applying.
73 * \param state The original state to apply this rule.
74 * \param stage_id The index of the next stage to apply this rule.
75 * \return The state after applying this rule, and index of the next stage.
76 */
77 virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
78 const State& state, int stage_id) const = 0;
79
80 /*!
81 * \brief Get the name of this rule.
82 * \return A string of the rule name.
83 */
84 virtual std::string GetRuleName() const = 0;
85};
86
/*!
 * \brief Declare a SketchGenerationRule subclass named `rule_name`.
 * \note Only GetRuleName is defined inline (it returns the stringized class name);
 *  MeetCondition and Apply are merely declared and must be defined elsewhere.
 */
#define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                             \
  class rule_name : public SketchGenerationRule {                                            \
   public:                                                                                   \
    std::string GetRuleName() const final { return #rule_name; }                            \
    SketchGenerationRule::ConditionKind MeetCondition(const SketchPolicyNode& policy,        \
                                                      const State& state,                    \
                                                      int stage_id) const final;             \
    std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,                 \
                                             const State& state, int stage_id) const final; \
  };
96
/*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves
 * to the next stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage);

/*! \brief The rule that inlines simple elementwise ops.
 * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
 * inlineable will have a chance to try different compute-at locations in InitPopulation later.
 */
DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline);

/*! \brief The rule that performs multi-level tiling. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling);

/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion);

/*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching.
 * Currently it only supports 1-to-1 matching cache reads. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead);

/*! \brief The rule that adds a cache write stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite);

/*! \brief The rule that adds an rfactor stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor);

/*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
 * This kind of op comes from the winograd transformation. */
DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor);

/*! \brief The rule that uses cross-thread reduction for GPU. */
DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);

/*! \brief Handle special cases in the Winograd transformation for GPU. We need to change the
 * compute location of the producers of compute ops that perform "fake reduction" with const
 * tensors. */
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
133
134/*! \brief The rule that allows users to generate custom sketches. */
135class RuleCustomSketch : public SketchGenerationRule {
136 public:
137 RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
138 String rule_name = "CustomSketchRule")
139 : meet_condition_func_(std::move(meet_condition_func)),
140 apply_func_(std::move(apply_func)),
141 rule_name_(std::move(rule_name)) {}
142
143 ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
144 int stage_id) const final;
145
146 std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
147 int stage_id) const final;
148
149 std::string GetRuleName() const final { return rule_name_; }
150
151 private:
152 PackedFunc meet_condition_func_;
153 PackedFunc apply_func_;
154 String rule_name_;
155};
156
157/********** Init Population **********/
158
159/*! \brief The base class for rules used to annotate the sketches to get the initial population. */
160class PopulationGenerationRule {
161 public:
162 /*! \brief Result enumeration of the apply function. */
163 enum class ResultKind : int { kValid = 0, kInvalid = 1 };
164
165 /*!
166 * \brief Apply function of this rule.
167 * \param policy The SketchPolicyNode of this rule, some member may get changed during the
168 * rule applying. (e.g. random number generator)
169 * \param state The state to apply this rule, update inplace.
170 * \return The result of this rule, indicate if there's any valid state generated.
171 */
172 virtual ResultKind Apply(SketchPolicyNode* policy, State* state,
173 std::mt19937* rand_gen) const = 0;
174
175 /*! \brief The deconstructor */
176 virtual ~PopulationGenerationRule() = default;
177};
178
/*!
 * \brief Declare a PopulationGenerationRule subclass named `rule_name` for population
 *  initialization. Its Apply is only declared here and must be defined elsewhere.
 */
#define DEFINE_INIT_POPULATION_RULE(rule_name)                                        \
  class rule_name : public PopulationGenerationRule {                                 \
   public:                                                                            \
    PopulationGenerationRule::ResultKind Apply(SketchPolicyNode* policy, State* state, \
                                               std::mt19937* rand_gen) const final;   \
  };
185
/*! \brief The rule that fills the incomplete SplitSteps. */
DEFINE_INIT_POPULATION_RULE(InitFillTileSize);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);

/*! \brief The rule that annotates parallelization for CPU. */
DEFINE_INIT_POPULATION_RULE(InitParallel);

/*! \brief The rule that annotates unrolling. */
DEFINE_INIT_POPULATION_RULE(InitUnroll);

/*! \brief The rule that annotates vectorization. */
DEFINE_INIT_POPULATION_RULE(InitVectorization);

/*! \brief The rule that annotates thread binding for GPU. */
DEFINE_INIT_POPULATION_RULE(InitThreadBind);
204
205/********** Mutation **********/
206
207/*! \brief The base class for mutation rules used in the evolutionary search. */
208class PopulationMutationRule : public PopulationGenerationRule {
209 public:
210 /* \brief The constructor
211 * \param selection_weight the probabiliy of applying this rule is
212 * proportional to this weight
213 */
214 explicit PopulationMutationRule(double selection_weight) : weight(selection_weight) {}
215
216 /* \brief The weight of this rule */
217 double weight;
218};
219
/*!
 * \brief Declare a PopulationMutationRule subclass named `rule_name` for the evolutionary
 *  search. The constructor forwards the selection weight to the base; Apply is only declared
 *  here and must be defined elsewhere.
 */
#define DEFINE_MUTATE_POPULATION_RULE(rule_name)                                       \
  class rule_name : public PopulationMutationRule {                                    \
   public:                                                                             \
    explicit rule_name(double selection_weight)                                        \
        : PopulationMutationRule{selection_weight} {}                                  \
    PopulationGenerationRule::ResultKind Apply(SketchPolicyNode* policy, State* state, \
                                               std::mt19937* rand_gen) const final;    \
  };
227
/*! \brief The rule that mutates tile sizes by randomly dividing one tile size by a factor
 * and multiplying another tile size by that factor. */
DEFINE_MUTATE_POPULATION_RULE(MutateTileSize);

/*! \brief The rule that mutates the number of fused outer iterators annotated by parallel. */
DEFINE_MUTATE_POPULATION_RULE(MutateParallel);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation);

/*! \brief The rule that mutates the value of a randomly selected auto unroll pragma step. */
DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll);
241
242} // namespace auto_scheduler
243} // namespace tvm
244
245#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
246