1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
20#define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
21
22#include <tvm/tir/schedule/schedule.h>
23
24#include <algorithm>
25#include <limits>
26#include <utility>
27
28namespace tvm {
29namespace meta_schedule {
30
31/*!
 * \brief Given candidate thread extents, make a sampler that uses `sch->SampleCategorical`
 * to return a random thread extent.
 * \param sch The schedule.
 * \param thread_extents The candidate thread extents.
 * \return A sampler that returns a random thread extent. The sampler is a callable that takes
 * a single int64_t and returns a tir::ExprRV.
 * \note NOTE(review): the sampler's int64_t argument is unnamed in this declaration; the
 * `get_factor` parameter of `BindBlockThreadIdx` in this header names the same argument
 * `max_extent`, so it is presumably an upper bound on the extent -- confirm against the
 * definition.
 * \note NOTE(review): `std::function` is used here, but <functional> is not among the includes
 * visible in this header -- presumably it arrives transitively; confirm.
 */
38std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
39 Array<Integer> thread_extents);
40
41/*!
 * \brief Bind blockIdx.x and threadIdx.x to the given loop.
 * \param sch The schedule.
 * \param loop The loop to be bound.
 * \param max_threadblocks The maximum number of threadblocks allowed.
 * \param max_threads_per_block The maximum number of threads allowed per threadblock.
 * \param get_factor A function that returns the tiling factor. It takes an int64_t and returns
 * a tir::ExprRV. Defaults to nullptr.
 * \return The bound loops, in the order of blockIdx.x, threadIdx.x, and the rest.
 * \note NOTE(review): how a null `get_factor` is handled is decided by the definition, which is
 * not part of this header -- confirm there before relying on the default.
 */
50Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
51 int64_t max_threadblocks, int64_t max_threads_per_block,
52 std::function<tir::ExprRV(int64_t)> get_factor = nullptr);
53
54/*!
 * \brief Bind the given block if it is not bound to blockIdx or threadIdx.
 * \param sch The schedule.
 * \param block The block to be bound.
 * \param max_threadblocks The maximum number of threadblocks allowed.
 * \param max_threads_per_block The maximum number of threads allowed per threadblock.
 * \param get_factor A function that returns the tiling factor. It takes an int64_t, named
 * `max_extent` in the signature, and returns a tir::ExprRV. Defaults to nullptr.
 * \note NOTE(review): how a null `get_factor` is handled is decided by the definition, which is
 * not part of this header -- confirm there before relying on the default.
 */
62void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, //
63 int64_t max_threadblocks, int64_t max_threads_per_block,
64 std::function<tir::ExprRV(int64_t max_extent)> get_factor = nullptr);
65
66} // namespace meta_schedule
67} // namespace tvm
68
69#endif // TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
70