1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <random> |
21 | |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | struct PrimeTable { |
28 | /*! \brief The table contains prime numbers in [2, kMaxPrime) */ |
29 | static constexpr const int32_t kMaxPrime = 65536; |
30 | /*! \brief The exact number of prime numbers in the table */ |
31 | static constexpr const int32_t kNumPrimes = 6542; |
32 | /*! |
33 | * \brief For each number in [2, kMaxPrime), the index of its min factor. |
34 | * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. |
35 | */ |
36 | int32_t min_factor_idx[kMaxPrime]; |
37 | /*! \brief The prime numbers in [2, kMaxPrime) */ |
38 | std::vector<int32_t> primes; |
39 | /*! |
40 | * \brief The power of each prime number. |
41 | * pow_table[i, j] stores the result of pow(prime[i], j + 1) |
42 | */ |
43 | std::vector<std::vector<int32_t>> pow_tab; |
44 | |
45 | /*! \brief Get a global instance of the prime table */ |
46 | static const PrimeTable* Global() { |
47 | static const PrimeTable table; |
48 | return &table; |
49 | } |
50 | |
51 | /*! \brief Constructor, pre-computes all info in the prime table */ |
52 | PrimeTable() { |
53 | constexpr const int64_t int_max = std::numeric_limits<int32_t>::max(); |
54 | // Euler's sieve: prime number in linear time |
55 | for (int32_t i = 0; i < kMaxPrime; ++i) { |
56 | min_factor_idx[i] = -1; |
57 | } |
58 | primes.reserve(kNumPrimes); |
59 | for (int32_t x = 2; x < kMaxPrime; ++x) { |
60 | if (min_factor_idx[x] == -1) { |
61 | min_factor_idx[x] = primes.size(); |
62 | primes.push_back(x); |
63 | } |
64 | for (size_t i = 0; i < primes.size(); ++i) { |
65 | int64_t factor = primes[i]; |
66 | int64_t y = x * factor; |
67 | if (y >= kMaxPrime) { |
68 | break; |
69 | } |
70 | min_factor_idx[y] = i; |
71 | if (x % factor == 0) { |
72 | break; |
73 | } |
74 | } |
75 | } |
76 | ICHECK_EQ(static_cast<int32_t>(primes.size()), static_cast<int32_t>(kNumPrimes)); |
77 | // Calculate the power table for each prime number |
78 | pow_tab.reserve(primes.size()); |
79 | for (int32_t prime : primes) { |
80 | std::vector<int32_t> tab; |
81 | tab.reserve(32); |
82 | for (int64_t pow = prime; pow <= int_max; pow *= prime) { |
83 | tab.push_back(pow); |
84 | } |
85 | tab.shrink_to_fit(); |
86 | pow_tab.emplace_back(std::move(tab)); |
87 | } |
88 | } |
89 | /*! |
90 | * \brief Factorize a number n, and return in a cryptic format |
91 | * \param n The number to be factorized |
92 | * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] |
93 | * For each pair (i, j), we define |
94 | * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) |
95 | * (primes[i], j) if i != -1 |
96 | * Then the factorization is |
97 | * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) |
98 | */ |
99 | std::vector<std::pair<int32_t, int32_t>> Factorize(int32_t n) const { |
100 | std::vector<std::pair<int32_t, int32_t>> result; |
101 | result.reserve(16); |
102 | int32_t i = 0, n_primes = primes.size(); |
103 | // Phase 1: n >= kMaxPrime |
104 | for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { |
105 | for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { |
106 | } |
107 | if (j != 0) { |
108 | result.emplace_back(i, j); |
109 | } |
110 | } |
111 | // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number |
112 | if (n >= kMaxPrime) { |
113 | result.emplace_back(-1, n); |
114 | return result; |
115 | } |
116 | // Phase 2: n < kMaxPrime |
117 | for (int32_t j; n > 1;) { |
118 | int32_t i = min_factor_idx[n]; |
119 | for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { |
120 | } |
121 | result.emplace_back(i, j); |
122 | } |
123 | return result; |
124 | } |
125 | }; |
126 | |
127 | int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, |
128 | int32_t max_exclusive) { |
129 | CHECK(min_inclusive < max_exclusive) |
130 | << "ValueError: max_exclusive must be greater than min_inclusive." ; |
131 | if (min_inclusive + 1 == max_exclusive) { |
132 | return min_inclusive; |
133 | } |
134 | support::LinearCongruentialEngine rand_(rand_state); |
135 | std::uniform_int_distribution<int32_t> dist(min_inclusive, max_exclusive - 1); |
136 | return dist(rand_); |
137 | } |
138 | |
139 | std::vector<int32_t> SampleWithoutReplacement( |
140 | support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) { |
141 | if (k == 1) { |
142 | return {SampleInt(rand_state, 0, n)}; |
143 | } |
144 | if (k == 2) { |
145 | int32_t result0 = SampleInt(rand_state, 0, n); |
146 | int32_t result1 = SampleInt(rand_state, 0, n - 1); |
147 | if (result1 >= result0) { |
148 | result1 += 1; |
149 | } |
150 | return {result0, result1}; |
151 | } |
152 | std::vector<int32_t> order(n); |
153 | for (int32_t i = 0; i < n; ++i) { |
154 | order[i] = i; |
155 | } |
156 | for (int32_t i = 0; i < k; ++i) { |
157 | int32_t j = SampleInt(rand_state, i, n); |
158 | if (i != j) { |
159 | std::swap(order[i], order[j]); |
160 | } |
161 | } |
162 | return {order.begin(), order.begin() + k}; |
163 | } |
164 | |
165 | int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, |
166 | const Array<Integer>& candidates, const Array<FloatImm>& probs, |
167 | Optional<Integer>* decision) { |
168 | CHECK(candidates.size() == probs.size()) |
169 | << "ValueError: number of candidates does not match number of probabilities." ; |
170 | int32_t i = -1; |
171 | int32_t n = candidates.size(); |
172 | if (decision->defined()) { |
173 | const auto* int_imm = decision->as<IntImmNode>(); |
174 | i = int_imm->value; |
175 | CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n |
176 | << ", but decision is: " << i; |
177 | } else { |
178 | std::vector<double> weights = support::AsVector<FloatImm, double>(probs); |
179 | std::discrete_distribution<int32_t> dist(weights.begin(), weights.end()); |
180 | support::LinearCongruentialEngine rand_(rand_state); |
181 | i = dist(rand_); |
182 | ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n |
183 | << ", but decision is: " << i; |
184 | } |
185 | |
186 | *decision = Integer(i); // decision is guaranteed not to be nullptr. |
187 | return candidates[i].IntValue(); |
188 | } |
189 | |
190 | std::function<int32_t()> MakeMultinomialSampler( |
191 | support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights) { |
192 | ICHECK(!weights.empty()); |
193 | std::vector<double> sums; |
194 | sums.reserve(weights.size()); |
195 | double sum = 0.0; |
196 | for (double w : weights) { |
197 | sums.push_back(sum += w); |
198 | } |
199 | return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), |
200 | dist = std::uniform_real_distribution<double>(0.0, sum), |
201 | sums = std::move(sums)]() mutable -> int32_t { |
202 | support::LinearCongruentialEngine rand_(&rng); |
203 | double p = dist(rand_); |
204 | int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); |
205 | int32_t n = sums.size(); |
206 | CHECK_LE(0, idx); |
207 | CHECK_LE(idx, n); |
208 | return (idx == n) ? (n - 1) : idx; |
209 | }; |
210 | } |
211 | |
212 | std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, |
213 | int32_t extent, int32_t n_splits) { |
214 | CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent" ; |
215 | CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits" ; |
216 | // Handle special case that we can potentially accelerate |
217 | if (n_splits == 1) { |
218 | return {extent}; |
219 | } |
220 | if (extent == 1) { |
221 | return std::vector<int64_t>(n_splits, 1); |
222 | } |
223 | // Enumerate each pair (i, j), we define |
224 | // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) |
225 | // (primes[i], j) if i != -1 |
226 | // Then the factorization is |
227 | // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) |
228 | const PrimeTable* prime_tab = PrimeTable::Global(); |
229 | std::vector<std::pair<int32_t, int32_t>> factorized = prime_tab->Factorize(extent); |
230 | if (n_splits == 2) { |
231 | // n_splits = 2, this can be taken special care of, |
232 | // because general reservoir sampling can be avoided to accelerate the sampling |
233 | int32_t result0 = 1; |
234 | int32_t result1 = 1; |
235 | for (const std::pair<int32_t, int32_t>& ij : factorized) { |
236 | // Case 1: (a, p) = (j, 1), where j is a prime number |
237 | if (ij.first == -1) { |
238 | (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; |
239 | continue; |
240 | } |
241 | // Case 2: (a = primes[i], p = 1) |
242 | int32_t p = ij.second; |
243 | const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; |
244 | int32_t x1 = SampleInt(rand_state, 0, p + 1); |
245 | int32_t x2 = p - x1; |
246 | if (x1 != 0) { |
247 | result0 *= pow[x1]; |
248 | } |
249 | if (x2 != 0) { |
250 | result1 *= pow[x2]; |
251 | } |
252 | } |
253 | return {result0, result1}; |
254 | } |
255 | // Data range: |
256 | // 2 <= extent <= 2^31 - 1 |
257 | // 3 <= n_splits <= max tiling splits |
258 | // 1 <= p <= 31 |
259 | std::vector<int64_t> result(n_splits, 1); |
260 | for (const std::pair<int32_t, int32_t>& ij : factorized) { |
261 | // Handle special cases to accelerate sampling |
262 | // Case 1: (a, p) = (j, 1), where j is a prime number |
263 | if (ij.first == -1) { |
264 | result[SampleInt(rand_state, 0, n_splits)] *= ij.second; |
265 | continue; |
266 | } |
267 | // Case 2: (a = primes[i], p = 1) |
268 | int32_t p = ij.second; |
269 | if (p == 1) { |
270 | result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; |
271 | continue; |
272 | } |
273 | // The general case. We have to sample uniformly from the solution of: |
274 | // x_1 + x_2 + ... + x_{n_splits} = p |
275 | // where x_i >= 0 |
276 | // Data range: |
277 | // 2 <= p <= 31 |
278 | // 3 <= n_splits <= max tiling splits |
279 | std::vector<int32_t> sampled = |
280 | SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); |
281 | std::sort(sampled.begin(), sampled.end()); |
282 | sampled.push_back(p + n_splits - 1); |
283 | const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; |
284 | for (int32_t i = 0, last = -1; i < n_splits; ++i) { |
285 | int32_t x = sampled[i] - last - 1; |
286 | last = sampled[i]; |
287 | if (x != 0) { |
288 | result[i] *= pow[x]; |
289 | } |
290 | } |
291 | } |
292 | return result; |
293 | } |
294 | |
295 | std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, |
296 | int32_t extent, int32_t n_splits, |
297 | int32_t max_innermost_factor) { |
298 | if (max_innermost_factor == -1) { |
299 | return SamplePerfectTile(rand_state, extent, n_splits); |
300 | } |
301 | CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits" ; |
302 | while (true) { |
303 | std::vector<int64_t> result = SamplePerfectTile(rand_state, extent, n_splits); |
304 | if (result.back() <= max_innermost_factor) { |
305 | return result; |
306 | } |
307 | } |
308 | } |
309 | |
310 | std::vector<int64_t> SamplePerfectTile( |
311 | support::LinearCongruentialEngine::TRandState* rand_state, // |
312 | const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, |
313 | Optional<Array<Integer>>* decision) { |
314 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
315 | const int64_t* extent = GetLoopIntExtent(loop); |
316 | std::vector<int64_t> result; |
317 | if (extent == nullptr) { |
318 | // Case 1. Handle loops with non-constant length |
319 | result = std::vector<int64_t>(n_splits, 1); |
320 | result[0] = -1; |
321 | } else if (decision->defined()) { |
322 | // Case 2. Use previous decision |
323 | result = support::AsVector<Integer, int64_t>(decision->value()); |
324 | int n = result.size(); |
325 | ICHECK_GE(n, 2); |
326 | int64_t len = *extent; |
327 | for (int i = n - 1; i > 0; --i) { |
328 | int64_t& l = result[i]; |
329 | // A previous decision could become invalid because of the change of outer tiles |
330 | // To handle this case properly, we check if the tiling strategy is still perfect. |
331 | // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles |
332 | if (len % l != 0) { |
333 | l = len; |
334 | } |
335 | len /= l; |
336 | } |
337 | result[0] = len; |
338 | } else { |
339 | // Case 3. Use fresh new sampling result |
340 | result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); |
341 | if (max_innermost_factor != -1) { |
342 | ICHECK_LE(result.back(), max_innermost_factor); |
343 | } |
344 | } |
345 | *decision = support::AsArray<int64_t, Integer>(result); |
346 | return result; |
347 | } |
348 | |
349 | tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, |
350 | support::LinearCongruentialEngine::TRandState* rand_state, |
351 | const StmtSRef& block_sref, Optional<Integer>* decision) { |
352 | // Step 1. Collect all possible compute-at locations. |
353 | auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); |
354 | ICHECK_EQ(location_srefs.size(), location_indices.size()); |
355 | |
356 | // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the |
357 | // location candidates. Otherwise, pick the location before the previous decision. |
358 | // Step 3. If there was not a previous decision, sample a decision from the collected locations. |
359 | if (decision->defined()) { |
360 | int64_t old_decision = Downcast<Integer>(*decision)->value; |
361 | auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision); |
362 | int idx = it - location_indices.begin(); |
363 | |
364 | if (it != location_indices.end() && *it == old_decision) { |
365 | *decision = Integer(old_decision); |
366 | return location_srefs[idx]; |
367 | } else if (it != location_indices.begin()) { |
368 | *decision = Integer(location_indices[idx - 1]); |
369 | return location_srefs[idx - 1]; |
370 | } else { |
371 | *decision = Integer(-1); |
372 | return StmtSRef::RootMark(); |
373 | } |
374 | } else { |
375 | int sampled_idx = SampleInt(rand_state, 0, location_indices.size()); |
376 | *decision = Integer(location_indices[sampled_idx]); |
377 | return location_srefs[sampled_idx]; |
378 | } |
379 | } |
380 | |
381 | /******** InstructionKind Registration ********/ |
382 | |
383 | struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTraits> { |
384 | static constexpr const char* kName = "SampleCategorical" ; |
385 | static constexpr bool kIsPure = true; |
386 | |
387 | private: |
388 | static constexpr size_t kNumInputs = 0; |
389 | static constexpr size_t kNumAttrs = 2; |
390 | static constexpr size_t kNumDecisions = 1; |
391 | |
392 | static ExprRV UnpackedApplyToSchedule(Schedule sch, // |
393 | Array<Integer> candidates, // |
394 | Array<FloatImm> probs, // |
395 | Optional<Integer> decision) { |
396 | return sch->SampleCategorical(candidates, probs, decision); |
397 | } |
398 | |
399 | static String UnpackedAsPython(Array<String> outputs, // |
400 | Array<Integer> candidates, // |
401 | Array<FloatImm> probs, // |
402 | Optional<Integer> decision) { |
403 | PythonAPICall py("sample_categorical" ); |
404 | py.Input("candidates" , candidates); |
405 | py.Input("probs" , probs); |
406 | py.Decision(decision); |
407 | py.SingleOutput(outputs); |
408 | return py.Str(); |
409 | } |
410 | |
411 | template <typename> |
412 | friend struct ::tvm::tir::UnpackedInstTraits; |
413 | }; |
414 | |
415 | struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> { |
416 | static constexpr const char* kName = "SamplePerfectTile" ; |
417 | static constexpr bool kIsPure = true; |
418 | |
419 | private: |
420 | static constexpr size_t kNumInputs = 1; |
421 | static constexpr size_t kNumAttrs = 2; |
422 | static constexpr size_t kNumDecisions = 1; |
423 | |
424 | static Array<ExprRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, |
425 | Integer max_innermost_factor, |
426 | Optional<Array<Integer>> decision) { |
427 | return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); |
428 | } |
429 | |
430 | static String UnpackedAsPython(Array<String> outputs, String loop_rv, Integer n, |
431 | Integer max_innermost_factor, Optional<Array<Integer>> decision) { |
432 | PythonAPICall py("sample_perfect_tile" ); |
433 | py.Input("loop" , loop_rv); |
434 | py.Input("n" , n->value); |
435 | py.Input("max_innermost_factor" , max_innermost_factor->value); |
436 | py.Decision(decision); |
437 | py.OutputList(outputs); |
438 | return py.Str(); |
439 | } |
440 | |
441 | template <typename> |
442 | friend struct ::tvm::tir::UnpackedInstTraits; |
443 | }; |
444 | |
445 | struct SampleComputeLocationTraits : public UnpackedInstTraits<SampleComputeLocationTraits> { |
446 | static constexpr const char* kName = "SampleComputeLocation" ; |
447 | static constexpr bool kIsPure = true; |
448 | |
449 | private: |
450 | static constexpr size_t kNumInputs = 1; |
451 | static constexpr size_t kNumAttrs = 0; |
452 | static constexpr size_t kNumDecisions = 1; |
453 | |
454 | static LoopRV UnpackedApplyToSchedule(Schedule sch, // |
455 | BlockRV block_rv, // |
456 | Optional<Integer> decision) { |
457 | return sch->SampleComputeLocation(block_rv, decision); |
458 | } |
459 | |
460 | static String UnpackedAsPython(Array<String> outputs, // |
461 | String block_rv, // |
462 | Optional<Integer> decision) { |
463 | PythonAPICall py("sample_compute_location" ); |
464 | py.Input("block" , block_rv); |
465 | py.Decision(decision); |
466 | py.SingleOutput(outputs); |
467 | return py.Str(); |
468 | } |
469 | |
470 | template <typename> |
471 | friend struct ::tvm::tir::UnpackedInstTraits; |
472 | }; |
473 | |
// Register the three sampling instruction kinds defined in this file. The macro presumably
// exposes each one under its kName to the schedule-trace machinery, which replays it via
// UnpackedApplyToSchedule and prints it via UnpackedAsPython (TODO confirm against the macro).
TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits);
TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits);
TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits);
477 | |
478 | } // namespace tir |
479 | } // namespace tvm |
480 | |