1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <mutex> |
20 | #include <unordered_map> |
21 | |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace meta_schedule { |
26 | |
27 | using tir::Instruction; |
28 | using tir::InstructionKind; |
29 | using tir::Trace; |
30 | |
31 | /*! |
32 | * \brief Downcast the decision of Sample-Perfect-Tile to an array of integers |
33 | * \param decision The decision of Sample-Perfect-Tile |
34 | * \return The result of downcast |
35 | */ |
36 | std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) { |
37 | const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode); |
38 | return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr)); |
39 | } |
40 | |
/*!
 * \brief Multiply together all elements of a vector
 * \param array The elements to multiply
 * \return The product of the elements; 1 for an empty vector
 */
int64_t Product(const std::vector<int64_t>& array) {
  int64_t product = 1;
  for (std::size_t i = 0; i < array.size(); ++i) {
    product *= array[i];
  }
  return product;
}
53 | |
54 | /*! \brief A mutator that mutates the tile size */ |
55 | class MutateTileSizeNode : public MutatorNode { |
56 | public: |
57 | void VisitAttrs(tvm::AttrVisitor* v) {} |
58 | static constexpr const char* _type_key = "meta_schedule.MutateTileSize" ; |
59 | TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode); |
60 | |
61 | public: |
62 | // Inherit from `MutatorNode` |
63 | void InitializeWithTuneContext(const TuneContext& context) final {} |
64 | // Inherit from `MutatorNode` |
65 | Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; |
66 | // Inherit from `MutatorNode` |
67 | Mutator Clone() const final { |
68 | ObjectPtr<MutateTileSizeNode> n = make_object<MutateTileSizeNode>(*this); |
69 | return Mutator(n); |
70 | } |
71 | }; |
72 | |
73 | /*! |
74 | * \brief Find a sample-perfect-tile decision in the trace |
75 | * \param trace The trace |
76 | * \param rand_state The random state |
77 | * \param inst The instruction selected |
78 | * \param decision The decision selected |
79 | * \return Whether a decision is found |
80 | */ |
81 | void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst, |
82 | std::vector<std::vector<int64_t>>* decision) { |
83 | static const InstructionKind& inst_sample_perfect_tile = |
84 | InstructionKind::Get("SamplePerfectTile" ); |
85 | std::vector<Instruction>& instructions = *inst; |
86 | std::vector<std::vector<int64_t>>& decisions = *decision; |
87 | instructions.reserve(trace->decisions.size()); |
88 | decisions.reserve(trace->decisions.size()); |
89 | for (const auto& kv : trace->decisions) { |
90 | const Instruction& inst = kv.first; |
91 | const ObjectRef& decision = kv.second; |
92 | if (inst->kind.same_as(inst_sample_perfect_tile)) { |
93 | std::vector<int64_t> tiles = DowncastTilingDecision(decision); |
94 | if (tiles.size() >= 2 && Product(tiles) >= 2) { |
95 | instructions.push_back(inst); |
96 | decisions.push_back(tiles); |
97 | } |
98 | } |
99 | } |
100 | } |
101 | |
102 | void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, |
103 | std::vector<int64_t>* decision) { |
104 | static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical" ); |
105 | static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate" ); |
106 | std::vector<Instruction>& instructions = *inst; |
107 | std::vector<int64_t>& decisions = *decision; |
108 | std::unordered_set<const Object*> annotated; |
109 | instructions.reserve(trace->decisions.size()); |
110 | decisions.reserve(trace->decisions.size()); |
111 | annotated.reserve(trace->decisions.size()); |
112 | // Find annotation with `meta_schedule_cooperative_fetch` |
113 | for (const Instruction& inst : trace->insts) { |
114 | if (inst->kind.same_as(inst_annotate)) { |
115 | ICHECK_EQ(inst->attrs.size(), 1); |
116 | ICHECK_EQ(inst->inputs.size(), 2); |
117 | if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { |
118 | const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>(); |
119 | ICHECK(ann_val); |
120 | annotated.insert(ann_val); |
121 | } |
122 | } |
123 | } |
124 | // Find sampling instruction that generates the annotation |
125 | for (const auto& kv : trace->decisions) { |
126 | const Instruction& inst = kv.first; |
127 | const ObjectRef& decision = kv.second; |
128 | if (inst->kind.same_as(inst_sample_categorical)) { |
129 | ICHECK_EQ(inst->outputs.size(), 1); |
130 | if (annotated.count(inst->outputs[0].get())) { |
131 | const auto* d = TVM_TYPE_AS(decision, IntImmNode); |
132 | instructions.push_back(inst); |
133 | decisions.push_back(d->value); |
134 | } |
135 | } |
136 | } |
137 | } |
138 | |
/*!
 * \brief Process-wide cache of divisor lists; every access to the map is guarded by a mutex
 */
struct FactorMemo {
  /*!
   * \brief Compute every positive divisor of `n`, reusing a cached result when present
   * \param n The integer to factorize
   * \return The divisors of `n` in ascending order (empty when `n` < 1)
   */
  static std::vector<int> Factorize(int n) {
    FactorMemo* memo = Global();
    if (const std::vector<int>* cached = memo->Query(n)) {
      return *cached;
    }
    // Divisors up to sqrt(n) appear in ascending order and their cofactors in descending
    // order, so appending the cofactors reversed yields a fully sorted list.
    std::vector<int> divisors;
    std::vector<int> cofactors;
    for (int64_t d = 1; d * d <= n; ++d) {
      if (n % d != 0) {
        continue;
      }
      divisors.push_back(static_cast<int>(d));
      const int64_t q = n / d;
      if (q != d) {
        cofactors.push_back(static_cast<int>(q));
      }
    }
    divisors.insert(divisors.end(), cofactors.rbegin(), cofactors.rend());
    memo->Add(n, divisors);
    return divisors;
  }

 private:
  /*! \brief Look up a cached divisor list; nullptr on a miss */
  const std::vector<int>* Query(int n) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto found = memo_.find(n);
    return found == memo_.end() ? nullptr : &found->second;
  }

  /*! \brief Store a divisor list; an existing entry for `n` is left untouched */
  void Add(int n, std::vector<int> result) {
    std::lock_guard<std::mutex> guard(mutex_);
    memo_.emplace(n, std::move(result));
  }

  /*! \brief The singleton, constructed on first use */
  static FactorMemo* Global() {
    static FactorMemo instance;
    return &instance;
  }

  std::unordered_map<int, std::vector<int>> memo_;
  std::mutex mutex_;
};
181 | |
182 | Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst, |
183 | std::vector<int64_t> tiles, TRandState* rand_state) { |
184 | int n_splits = tiles.size(); |
185 | // Step 1. Choose two loops, `x` and `y` |
186 | int x, y; |
187 | // select source |
188 | while (true) { |
189 | x = tir::SampleInt(rand_state, 0, n_splits); |
190 | if (tiles[x] <= 1) { |
191 | continue; |
192 | } |
193 | y = tir::SampleInt(rand_state, 0, n_splits - 1); |
194 | if (y >= x) { |
195 | ++y; |
196 | } |
197 | std::vector<int> factors = FactorMemo::Factorize(tiles[x]); |
198 | // Step 2. Choose the divide factor |
199 | int64_t divide_factor; |
200 | if (y != n_splits - 1) { |
201 | divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())]; |
202 | } else { |
203 | int64_t limit = Downcast<Integer>(inst->attrs[1])->value; |
204 | int max_factor_index = static_cast<int>(factors.size()) - 1; |
205 | for (; max_factor_index >= 1; max_factor_index--) { |
206 | if (factors[max_factor_index] * tiles[y] <= limit) { |
207 | break; |
208 | } |
209 | } |
210 | if (max_factor_index == 0) { |
211 | if (n_splits <= 2) { |
212 | return NullOpt; |
213 | } |
214 | // Failed on this dst_idx, try next one. |
215 | continue; |
216 | } |
217 | divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)]; |
218 | } |
219 | tiles[x] /= divide_factor; |
220 | tiles[y] *= divide_factor; |
221 | return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles), |
222 | /*remove_postproc=*/true); |
223 | } |
224 | } |
225 | |
226 | Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst, |
227 | int64_t original_decision, TRandState* rand_state) { |
228 | ICHECK_EQ(inst->attrs.size(), 2); |
229 | std::vector<double> probs = |
230 | support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1])); |
231 | probs.erase(probs.begin() + original_decision); |
232 | int result = tir::MakeMultinomialSampler(rand_state, probs)(); |
233 | if (result >= original_decision) { |
234 | result += 1; |
235 | } |
236 | return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); |
237 | } |
238 | |
239 | Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { |
240 | std::vector<Instruction> sample_perfect_tile_insts; |
241 | std::vector<Instruction> sample_vectorize_insts; |
242 | std::vector<std::vector<int64_t>> sample_perfect_tile_tiles; |
243 | std::vector<int64_t> sample_vectorize_decisions; |
244 | FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles); |
245 | FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions); |
246 | int size_a = sample_perfect_tile_insts.size(); |
247 | int size_b = sample_vectorize_insts.size(); |
248 | if (size_a == 0 && size_b == 0) { |
249 | return NullOpt; |
250 | } |
251 | int n = tir::SampleInt(rand_state, 0, size_a + size_b); |
252 | if (n < size_a) { |
253 | return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n], |
254 | rand_state); |
255 | } else { |
256 | n -= size_a; |
257 | return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n], |
258 | rand_state); |
259 | } |
260 | } |
261 | |
262 | Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); } |
263 | |
264 | TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); |
265 | TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize" ).set_body_typed(Mutator::MutateTileSize); |
266 | |
267 | } // namespace meta_schedule |
268 | } // namespace tvm |
269 | |