1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <random>
21
22#include "../utils.h"
23
24namespace tvm {
25namespace tir {
26
27struct PrimeTable {
28 /*! \brief The table contains prime numbers in [2, kMaxPrime) */
29 static constexpr const int32_t kMaxPrime = 65536;
30 /*! \brief The exact number of prime numbers in the table */
31 static constexpr const int32_t kNumPrimes = 6542;
32 /*!
33 * \brief For each number in [2, kMaxPrime), the index of its min factor.
34 * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i].
35 */
36 int32_t min_factor_idx[kMaxPrime];
37 /*! \brief The prime numbers in [2, kMaxPrime) */
38 std::vector<int32_t> primes;
39 /*!
40 * \brief The power of each prime number.
41 * pow_table[i, j] stores the result of pow(prime[i], j + 1)
42 */
43 std::vector<std::vector<int32_t>> pow_tab;
44
45 /*! \brief Get a global instance of the prime table */
46 static const PrimeTable* Global() {
47 static const PrimeTable table;
48 return &table;
49 }
50
51 /*! \brief Constructor, pre-computes all info in the prime table */
52 PrimeTable() {
53 constexpr const int64_t int_max = std::numeric_limits<int32_t>::max();
54 // Euler's sieve: prime number in linear time
55 for (int32_t i = 0; i < kMaxPrime; ++i) {
56 min_factor_idx[i] = -1;
57 }
58 primes.reserve(kNumPrimes);
59 for (int32_t x = 2; x < kMaxPrime; ++x) {
60 if (min_factor_idx[x] == -1) {
61 min_factor_idx[x] = primes.size();
62 primes.push_back(x);
63 }
64 for (size_t i = 0; i < primes.size(); ++i) {
65 int64_t factor = primes[i];
66 int64_t y = x * factor;
67 if (y >= kMaxPrime) {
68 break;
69 }
70 min_factor_idx[y] = i;
71 if (x % factor == 0) {
72 break;
73 }
74 }
75 }
76 ICHECK_EQ(static_cast<int32_t>(primes.size()), static_cast<int32_t>(kNumPrimes));
77 // Calculate the power table for each prime number
78 pow_tab.reserve(primes.size());
79 for (int32_t prime : primes) {
80 std::vector<int32_t> tab;
81 tab.reserve(32);
82 for (int64_t pow = prime; pow <= int_max; pow *= prime) {
83 tab.push_back(pow);
84 }
85 tab.shrink_to_fit();
86 pow_tab.emplace_back(std::move(tab));
87 }
88 }
89 /*!
90 * \brief Factorize a number n, and return in a cryptic format
91 * \param n The number to be factorized
92 * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)]
93 * For each pair (i, j), we define
94 * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number)
95 * (primes[i], j) if i != -1
96 * Then the factorization is
97 * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l)
98 */
99 std::vector<std::pair<int32_t, int32_t>> Factorize(int32_t n) const {
100 std::vector<std::pair<int32_t, int32_t>> result;
101 result.reserve(16);
102 int32_t i = 0, n_primes = primes.size();
103 // Phase 1: n >= kMaxPrime
104 for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) {
105 for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) {
106 }
107 if (j != 0) {
108 result.emplace_back(i, j);
109 }
110 }
111 // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number
112 if (n >= kMaxPrime) {
113 result.emplace_back(-1, n);
114 return result;
115 }
116 // Phase 2: n < kMaxPrime
117 for (int32_t j; n > 1;) {
118 int32_t i = min_factor_idx[n];
119 for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) {
120 }
121 result.emplace_back(i, j);
122 }
123 return result;
124 }
125};
126
127int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive,
128 int32_t max_exclusive) {
129 CHECK(min_inclusive < max_exclusive)
130 << "ValueError: max_exclusive must be greater than min_inclusive.";
131 if (min_inclusive + 1 == max_exclusive) {
132 return min_inclusive;
133 }
134 support::LinearCongruentialEngine rand_(rand_state);
135 std::uniform_int_distribution<int32_t> dist(min_inclusive, max_exclusive - 1);
136 return dist(rand_);
137}
138
139std::vector<int32_t> SampleWithoutReplacement(
140 support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) {
141 if (k == 1) {
142 return {SampleInt(rand_state, 0, n)};
143 }
144 if (k == 2) {
145 int32_t result0 = SampleInt(rand_state, 0, n);
146 int32_t result1 = SampleInt(rand_state, 0, n - 1);
147 if (result1 >= result0) {
148 result1 += 1;
149 }
150 return {result0, result1};
151 }
152 std::vector<int32_t> order(n);
153 for (int32_t i = 0; i < n; ++i) {
154 order[i] = i;
155 }
156 for (int32_t i = 0; i < k; ++i) {
157 int32_t j = SampleInt(rand_state, i, n);
158 if (i != j) {
159 std::swap(order[i], order[j]);
160 }
161 }
162 return {order.begin(), order.begin() + k};
163}
164
165int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
166 const Array<Integer>& candidates, const Array<FloatImm>& probs,
167 Optional<Integer>* decision) {
168 CHECK(candidates.size() == probs.size())
169 << "ValueError: number of candidates does not match number of probabilities.";
170 int32_t i = -1;
171 int32_t n = candidates.size();
172 if (decision->defined()) {
173 const auto* int_imm = decision->as<IntImmNode>();
174 i = int_imm->value;
175 CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n
176 << ", but decision is: " << i;
177 } else {
178 std::vector<double> weights = support::AsVector<FloatImm, double>(probs);
179 std::discrete_distribution<int32_t> dist(weights.begin(), weights.end());
180 support::LinearCongruentialEngine rand_(rand_state);
181 i = dist(rand_);
182 ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n
183 << ", but decision is: " << i;
184 }
185
186 *decision = Integer(i); // decision is guaranteed not to be nullptr.
187 return candidates[i].IntValue();
188}
189
190std::function<int32_t()> MakeMultinomialSampler(
191 support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights) {
192 ICHECK(!weights.empty());
193 std::vector<double> sums;
194 sums.reserve(weights.size());
195 double sum = 0.0;
196 for (double w : weights) {
197 sums.push_back(sum += w);
198 }
199 return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(),
200 dist = std::uniform_real_distribution<double>(0.0, sum),
201 sums = std::move(sums)]() mutable -> int32_t {
202 support::LinearCongruentialEngine rand_(&rng);
203 double p = dist(rand_);
204 int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin();
205 int32_t n = sums.size();
206 CHECK_LE(0, idx);
207 CHECK_LE(idx, n);
208 return (idx == n) ? (n - 1) : idx;
209 };
210}
211
212std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
213 int32_t extent, int32_t n_splits) {
214 CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent";
215 CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits";
216 // Handle special case that we can potentially accelerate
217 if (n_splits == 1) {
218 return {extent};
219 }
220 if (extent == 1) {
221 return std::vector<int64_t>(n_splits, 1);
222 }
223 // Enumerate each pair (i, j), we define
224 // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number)
225 // (primes[i], j) if i != -1
226 // Then the factorization is
227 // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l)
228 const PrimeTable* prime_tab = PrimeTable::Global();
229 std::vector<std::pair<int32_t, int32_t>> factorized = prime_tab->Factorize(extent);
230 if (n_splits == 2) {
231 // n_splits = 2, this can be taken special care of,
232 // because general reservoir sampling can be avoided to accelerate the sampling
233 int32_t result0 = 1;
234 int32_t result1 = 1;
235 for (const std::pair<int32_t, int32_t>& ij : factorized) {
236 // Case 1: (a, p) = (j, 1), where j is a prime number
237 if (ij.first == -1) {
238 (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second;
239 continue;
240 }
241 // Case 2: (a = primes[i], p = 1)
242 int32_t p = ij.second;
243 const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1;
244 int32_t x1 = SampleInt(rand_state, 0, p + 1);
245 int32_t x2 = p - x1;
246 if (x1 != 0) {
247 result0 *= pow[x1];
248 }
249 if (x2 != 0) {
250 result1 *= pow[x2];
251 }
252 }
253 return {result0, result1};
254 }
255 // Data range:
256 // 2 <= extent <= 2^31 - 1
257 // 3 <= n_splits <= max tiling splits
258 // 1 <= p <= 31
259 std::vector<int64_t> result(n_splits, 1);
260 for (const std::pair<int32_t, int32_t>& ij : factorized) {
261 // Handle special cases to accelerate sampling
262 // Case 1: (a, p) = (j, 1), where j is a prime number
263 if (ij.first == -1) {
264 result[SampleInt(rand_state, 0, n_splits)] *= ij.second;
265 continue;
266 }
267 // Case 2: (a = primes[i], p = 1)
268 int32_t p = ij.second;
269 if (p == 1) {
270 result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first];
271 continue;
272 }
273 // The general case. We have to sample uniformly from the solution of:
274 // x_1 + x_2 + ... + x_{n_splits} = p
275 // where x_i >= 0
276 // Data range:
277 // 2 <= p <= 31
278 // 3 <= n_splits <= max tiling splits
279 std::vector<int32_t> sampled =
280 SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1);
281 std::sort(sampled.begin(), sampled.end());
282 sampled.push_back(p + n_splits - 1);
283 const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1;
284 for (int32_t i = 0, last = -1; i < n_splits; ++i) {
285 int32_t x = sampled[i] - last - 1;
286 last = sampled[i];
287 if (x != 0) {
288 result[i] *= pow[x];
289 }
290 }
291 }
292 return result;
293}
294
295std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
296 int32_t extent, int32_t n_splits,
297 int32_t max_innermost_factor) {
298 if (max_innermost_factor == -1) {
299 return SamplePerfectTile(rand_state, extent, n_splits);
300 }
301 CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits";
302 while (true) {
303 std::vector<int64_t> result = SamplePerfectTile(rand_state, extent, n_splits);
304 if (result.back() <= max_innermost_factor) {
305 return result;
306 }
307 }
308}
309
310std::vector<int64_t> SamplePerfectTile(
311 support::LinearCongruentialEngine::TRandState* rand_state, //
312 const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor,
313 Optional<Array<Integer>>* decision) {
314 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
315 const int64_t* extent = GetLoopIntExtent(loop);
316 std::vector<int64_t> result;
317 if (extent == nullptr) {
318 // Case 1. Handle loops with non-constant length
319 result = std::vector<int64_t>(n_splits, 1);
320 result[0] = -1;
321 } else if (decision->defined()) {
322 // Case 2. Use previous decision
323 result = support::AsVector<Integer, int64_t>(decision->value());
324 int n = result.size();
325 ICHECK_GE(n, 2);
326 int64_t len = *extent;
327 for (int i = n - 1; i > 0; --i) {
328 int64_t& l = result[i];
329 // A previous decision could become invalid because of the change of outer tiles
330 // To handle this case properly, we check if the tiling strategy is still perfect.
331 // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles
332 if (len % l != 0) {
333 l = len;
334 }
335 len /= l;
336 }
337 result[0] = len;
338 } else {
339 // Case 3. Use fresh new sampling result
340 result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor);
341 if (max_innermost_factor != -1) {
342 ICHECK_LE(result.back(), max_innermost_factor);
343 }
344 }
345 *decision = support::AsArray<int64_t, Integer>(result);
346 return result;
347}
348
349tir::StmtSRef SampleComputeLocation(tir::ScheduleState self,
350 support::LinearCongruentialEngine::TRandState* rand_state,
351 const StmtSRef& block_sref, Optional<Integer>* decision) {
352 // Step 1. Collect all possible compute-at locations.
353 auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref);
354 ICHECK_EQ(location_srefs.size(), location_indices.size());
355
356 // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the
357 // location candidates. Otherwise, pick the location before the previous decision.
358 // Step 3. If there was not a previous decision, sample a decision from the collected locations.
359 if (decision->defined()) {
360 int64_t old_decision = Downcast<Integer>(*decision)->value;
361 auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision);
362 int idx = it - location_indices.begin();
363
364 if (it != location_indices.end() && *it == old_decision) {
365 *decision = Integer(old_decision);
366 return location_srefs[idx];
367 } else if (it != location_indices.begin()) {
368 *decision = Integer(location_indices[idx - 1]);
369 return location_srefs[idx - 1];
370 } else {
371 *decision = Integer(-1);
372 return StmtSRef::RootMark();
373 }
374 } else {
375 int sampled_idx = SampleInt(rand_state, 0, location_indices.size());
376 *decision = Integer(location_indices[sampled_idx]);
377 return location_srefs[sampled_idx];
378 }
379}
380
381/******** InstructionKind Registration ********/
382
383struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTraits> {
384 static constexpr const char* kName = "SampleCategorical";
385 static constexpr bool kIsPure = true;
386
387 private:
388 static constexpr size_t kNumInputs = 0;
389 static constexpr size_t kNumAttrs = 2;
390 static constexpr size_t kNumDecisions = 1;
391
392 static ExprRV UnpackedApplyToSchedule(Schedule sch, //
393 Array<Integer> candidates, //
394 Array<FloatImm> probs, //
395 Optional<Integer> decision) {
396 return sch->SampleCategorical(candidates, probs, decision);
397 }
398
399 static String UnpackedAsPython(Array<String> outputs, //
400 Array<Integer> candidates, //
401 Array<FloatImm> probs, //
402 Optional<Integer> decision) {
403 PythonAPICall py("sample_categorical");
404 py.Input("candidates", candidates);
405 py.Input("probs", probs);
406 py.Decision(decision);
407 py.SingleOutput(outputs);
408 return py.Str();
409 }
410
411 template <typename>
412 friend struct ::tvm::tir::UnpackedInstTraits;
413};
414
415struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> {
416 static constexpr const char* kName = "SamplePerfectTile";
417 static constexpr bool kIsPure = true;
418
419 private:
420 static constexpr size_t kNumInputs = 1;
421 static constexpr size_t kNumAttrs = 2;
422 static constexpr size_t kNumDecisions = 1;
423
424 static Array<ExprRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n,
425 Integer max_innermost_factor,
426 Optional<Array<Integer>> decision) {
427 return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision);
428 }
429
430 static String UnpackedAsPython(Array<String> outputs, String loop_rv, Integer n,
431 Integer max_innermost_factor, Optional<Array<Integer>> decision) {
432 PythonAPICall py("sample_perfect_tile");
433 py.Input("loop", loop_rv);
434 py.Input("n", n->value);
435 py.Input("max_innermost_factor", max_innermost_factor->value);
436 py.Decision(decision);
437 py.OutputList(outputs);
438 return py.Str();
439 }
440
441 template <typename>
442 friend struct ::tvm::tir::UnpackedInstTraits;
443};
444
445struct SampleComputeLocationTraits : public UnpackedInstTraits<SampleComputeLocationTraits> {
446 static constexpr const char* kName = "SampleComputeLocation";
447 static constexpr bool kIsPure = true;
448
449 private:
450 static constexpr size_t kNumInputs = 1;
451 static constexpr size_t kNumAttrs = 0;
452 static constexpr size_t kNumDecisions = 1;
453
454 static LoopRV UnpackedApplyToSchedule(Schedule sch, //
455 BlockRV block_rv, //
456 Optional<Integer> decision) {
457 return sch->SampleComputeLocation(block_rv, decision);
458 }
459
460 static String UnpackedAsPython(Array<String> outputs, //
461 String block_rv, //
462 Optional<Integer> decision) {
463 PythonAPICall py("sample_compute_location");
464 py.Input("block", block_rv);
465 py.Decision(decision);
466 py.SingleOutput(outputs);
467 return py.Str();
468 }
469
470 template <typename>
471 friend struct ::tvm::tir::UnpackedInstTraits;
472};
473
// Register the three sampling traits structs defined in this file as instruction kinds.
// NOTE(review): presumably this makes each kind discoverable by its kName string through the
// instruction-kind registry; confirm against the TVM_REGISTER_INST_KIND_TRAITS definition.
TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits);
TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits);
TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits);
477
478} // namespace tir
479} // namespace tvm
480