1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <mutex>
20#include <unordered_map>
21
22#include "../utils.h"
23
24namespace tvm {
25namespace meta_schedule {
26
27using tir::Instruction;
28using tir::InstructionKind;
29using tir::Trace;
30
31/*!
32 * \brief Downcast the decision of Sample-Perfect-Tile to an array of integers
33 * \param decision The decision of Sample-Perfect-Tile
34 * \return The result of downcast
35 */
36std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
37 const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode);
38 return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
39}
40
/*!
 * \brief Multiply all elements of an array together
 * \param array The factors to multiply
 * \return The product of all elements; 1 when the array is empty
 */
int64_t Product(const std::vector<int64_t>& array) {
  int64_t accumulated = 1;
  for (auto it = array.cbegin(); it != array.cend(); ++it) {
    accumulated *= *it;
  }
  return accumulated;
}
53
54/*! \brief A mutator that mutates the tile size */
55class MutateTileSizeNode : public MutatorNode {
56 public:
57 void VisitAttrs(tvm::AttrVisitor* v) {}
58 static constexpr const char* _type_key = "meta_schedule.MutateTileSize";
59 TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode);
60
61 public:
62 // Inherit from `MutatorNode`
63 void InitializeWithTuneContext(const TuneContext& context) final {}
64 // Inherit from `MutatorNode`
65 Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
66 // Inherit from `MutatorNode`
67 Mutator Clone() const final {
68 ObjectPtr<MutateTileSizeNode> n = make_object<MutateTileSizeNode>(*this);
69 return Mutator(n);
70 }
71};
72
73/*!
74 * \brief Find a sample-perfect-tile decision in the trace
75 * \param trace The trace
76 * \param rand_state The random state
77 * \param inst The instruction selected
78 * \param decision The decision selected
79 * \return Whether a decision is found
80 */
81void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
82 std::vector<std::vector<int64_t>>* decision) {
83 static const InstructionKind& inst_sample_perfect_tile =
84 InstructionKind::Get("SamplePerfectTile");
85 std::vector<Instruction>& instructions = *inst;
86 std::vector<std::vector<int64_t>>& decisions = *decision;
87 instructions.reserve(trace->decisions.size());
88 decisions.reserve(trace->decisions.size());
89 for (const auto& kv : trace->decisions) {
90 const Instruction& inst = kv.first;
91 const ObjectRef& decision = kv.second;
92 if (inst->kind.same_as(inst_sample_perfect_tile)) {
93 std::vector<int64_t> tiles = DowncastTilingDecision(decision);
94 if (tiles.size() >= 2 && Product(tiles) >= 2) {
95 instructions.push_back(inst);
96 decisions.push_back(tiles);
97 }
98 }
99 }
100}
101
102void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
103 std::vector<int64_t>* decision) {
104 static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
105 static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
106 std::vector<Instruction>& instructions = *inst;
107 std::vector<int64_t>& decisions = *decision;
108 std::unordered_set<const Object*> annotated;
109 instructions.reserve(trace->decisions.size());
110 decisions.reserve(trace->decisions.size());
111 annotated.reserve(trace->decisions.size());
112 // Find annotation with `meta_schedule_cooperative_fetch`
113 for (const Instruction& inst : trace->insts) {
114 if (inst->kind.same_as(inst_annotate)) {
115 ICHECK_EQ(inst->attrs.size(), 1);
116 ICHECK_EQ(inst->inputs.size(), 2);
117 if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
118 const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
119 ICHECK(ann_val);
120 annotated.insert(ann_val);
121 }
122 }
123 }
124 // Find sampling instruction that generates the annotation
125 for (const auto& kv : trace->decisions) {
126 const Instruction& inst = kv.first;
127 const ObjectRef& decision = kv.second;
128 if (inst->kind.same_as(inst_sample_categorical)) {
129 ICHECK_EQ(inst->outputs.size(), 1);
130 if (annotated.count(inst->outputs[0].get())) {
131 const auto* d = TVM_TYPE_AS(decision, IntImmNode);
132 instructions.push_back(inst);
133 decisions.push_back(d->value);
134 }
135 }
136 }
137}
138
/*! \brief A memo table, shared through a single static instance, of the divisors of integers */
struct FactorMemo {
  /*!
   * \brief Enumerate the positive divisors of a number, served from the memo when available
   * \param n The number to factorize
   * \return All positive divisors of `n` in ascending order; empty when `n` is less than 1
   */
  static std::vector<int> Factorize(int n) {
    FactorMemo* cache = Global();
    if (const std::vector<int>* hit = cache->Query(n)) {
      return *hit;
    }
    // Walk the divisors d with d * d <= n. `d` ascends, so the co-divisors n / d descend;
    // appending them in reverse afterwards yields one ascending list.
    std::vector<int> divisors;
    std::vector<int> co_divisors;
    for (int64_t d = 1; d * d <= n; ++d) {
      if (n % d != 0) {
        continue;
      }
      divisors.push_back(static_cast<int>(d));
      if (d * d != n) {
        co_divisors.push_back(static_cast<int>(n / d));
      }
    }
    divisors.insert(divisors.end(), co_divisors.rbegin(), co_divisors.rend());
    cache->Add(n, divisors);
    return divisors;
  }

 private:
  /*!
   * \brief Look up a previously stored result while holding the mutex
   * \param n The number to look up
   * \return Pointer to the stored divisors, or nullptr when `n` has not been stored
   */
  const std::vector<int>* Query(int n) {
    std::lock_guard<std::mutex> guard(mutex_);
    const auto entry = memo_.find(n);
    return entry == memo_.end() ? nullptr : &entry->second;
  }

  /*!
   * \brief Store a result while holding the mutex; an existing entry for `n` is left untouched
   * \param n The number that was factorized
   * \param result The divisors of `n`
   */
  void Add(int n, std::vector<int> result) {
    std::lock_guard<std::mutex> guard(mutex_);
    memo_.emplace(n, std::move(result));
  }

  /*! \brief Access the single static instance */
  static FactorMemo* Global() {
    static FactorMemo instance;
    return &instance;
  }

  /*! \brief Maps a number to its divisors */
  std::unordered_map<int, std::vector<int>> memo_;
  /*! \brief Guards every access to `memo_` */
  std::mutex mutex_;
};
181
182Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
183 std::vector<int64_t> tiles, TRandState* rand_state) {
184 int n_splits = tiles.size();
185 // Step 1. Choose two loops, `x` and `y`
186 int x, y;
187 // select source
188 while (true) {
189 x = tir::SampleInt(rand_state, 0, n_splits);
190 if (tiles[x] <= 1) {
191 continue;
192 }
193 y = tir::SampleInt(rand_state, 0, n_splits - 1);
194 if (y >= x) {
195 ++y;
196 }
197 std::vector<int> factors = FactorMemo::Factorize(tiles[x]);
198 // Step 2. Choose the divide factor
199 int64_t divide_factor;
200 if (y != n_splits - 1) {
201 divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())];
202 } else {
203 int64_t limit = Downcast<Integer>(inst->attrs[1])->value;
204 int max_factor_index = static_cast<int>(factors.size()) - 1;
205 for (; max_factor_index >= 1; max_factor_index--) {
206 if (factors[max_factor_index] * tiles[y] <= limit) {
207 break;
208 }
209 }
210 if (max_factor_index == 0) {
211 if (n_splits <= 2) {
212 return NullOpt;
213 }
214 // Failed on this dst_idx, try next one.
215 continue;
216 }
217 divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)];
218 }
219 tiles[x] /= divide_factor;
220 tiles[y] *= divide_factor;
221 return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
222 /*remove_postproc=*/true);
223 }
224}
225
226Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
227 int64_t original_decision, TRandState* rand_state) {
228 ICHECK_EQ(inst->attrs.size(), 2);
229 std::vector<double> probs =
230 support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
231 probs.erase(probs.begin() + original_decision);
232 int result = tir::MakeMultinomialSampler(rand_state, probs)();
233 if (result >= original_decision) {
234 result += 1;
235 }
236 return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
237}
238
239Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
240 std::vector<Instruction> sample_perfect_tile_insts;
241 std::vector<Instruction> sample_vectorize_insts;
242 std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;
243 std::vector<int64_t> sample_vectorize_decisions;
244 FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
245 FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
246 int size_a = sample_perfect_tile_insts.size();
247 int size_b = sample_vectorize_insts.size();
248 if (size_a == 0 && size_b == 0) {
249 return NullOpt;
250 }
251 int n = tir::SampleInt(rand_state, 0, size_a + size_b);
252 if (n < size_a) {
253 return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
254 rand_state);
255 } else {
256 n -= size_a;
257 return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
258 rand_state);
259 }
260}
261
262Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }
263
264TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
265TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize);
266
267} // namespace meta_schedule
268} // namespace tvm
269