1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/meta_schedule/schedule/cuda/thread_bind.h>
20#include <tvm/tir/op.h>
21#include <tvm/tir/schedule/schedule.h>
22
23#include <algorithm>
24#include <limits>
25#include <utility>
26
27#include "../../utils.h"
28
29namespace tvm {
30namespace meta_schedule {
31
32using namespace tvm::tir;
33
34std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<Integer> thread_extents) {
35 return [sch = std::move(sch),
36 thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV {
37 Array<Integer> extents;
38 extents.reserve(thread_extents.size());
39 for (const Integer extent : thread_extents) {
40 if (extent->value <= max_extent) {
41 extents.push_back(extent);
42 }
43 }
44 int n = extents.size();
45 if (n == 0) {
46 return Integer(max_extent);
47 }
48 if (n == 1) {
49 return Integer(extents[0]);
50 }
51 Array<FloatImm> probs(n, FloatImm(DataType::Float(64), 1.0 / n));
52 return sch->SampleCategorical(extents, probs);
53 };
54}
55
56Array<LoopRV> BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks,
57 int64_t max_threads_per_block,
58 std::function<ExprRV(int64_t)> get_factor) {
59 int64_t extent = -1;
60 if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) {
61 extent = *e;
62 } else {
63 extent = std::numeric_limits<int64_t>::max();
64 }
65 if (extent <= max_threadblocks * max_threads_per_block) {
66 if (!get_factor) {
67 get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
68 }
69 ExprRV factor = get_factor(std::min(extent, max_threads_per_block));
70 Array<LoopRV> splits = sch->Split(loop, {NullOpt, factor});
71 ICHECK_EQ(splits.size(), 2);
72 sch->Bind(splits[0], "blockIdx.x");
73 sch->Bind(splits[1], "threadIdx.x");
74 return {splits[0], splits[1]};
75 } else {
76 Array<LoopRV> splits = sch->Split(loop, {NullOpt,
77 Integer(max_threadblocks), //
78 Integer(max_threads_per_block)});
79 ICHECK_EQ(splits.size(), 3);
80 sch->Reorder({splits[1], splits[2], splits[0]});
81 sch->Bind(splits[1], "blockIdx.x");
82 sch->Bind(splits[2], "threadIdx.x");
83 return {splits[1], splits[2], splits[0]};
84 }
85}
86
87void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, //
88 int64_t max_threadblocks, int64_t max_threads_per_block,
89 std::function<tir::ExprRV(int64_t)> get_factor) {
90 using namespace tvm::tir;
91 StmtSRef block_sref = sch->GetSRef(block_rv);
92 if (block_sref->parent == nullptr) {
93 return;
94 }
95 if (tir::HasBeenMultiLevelTiled(block_sref)) {
96 return;
97 }
98 Array<StmtSRef> loops = tir::GetLoops(block_sref);
99 int n = loops.size();
100 int i_block_idx = -1;
101 int i_thread_idx = -1;
102 int i_multi_child = -1;
103 int i_spatial_loop = -1;
104 for (int i = 0; i < n; ++i) {
105 const StmtSRef& loop_sref = loops[i];
106 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
107 runtime::ThreadScope thread_scope = GetThreadScope(loop);
108 if (IsBlockIdx(thread_scope)) {
109 if (i_block_idx == -1) {
110 i_block_idx = i;
111 }
112 }
113 if (IsThreadIdx(thread_scope)) {
114 if (i_thread_idx == -1) {
115 i_thread_idx = i;
116 }
117 }
118 if (loop->kind != ForKind::kSerial) {
119 if (i_multi_child == -1) {
120 i_multi_child = i;
121 }
122 }
123 if (!IsSingleStmt(loop->body)) {
124 if (i_multi_child == -1) {
125 i_multi_child = i + 1;
126 }
127 }
128 if (GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
129 if (i_spatial_loop == i - 1) {
130 ++i_spatial_loop;
131 }
132 }
133 }
134 if (i_multi_child == -1) {
135 i_multi_child = n;
136 }
137 if (i_block_idx != -1 && i_thread_idx != -1) {
138 return;
139 }
140 if (i_block_idx != -1 && i_thread_idx == -1) {
141 ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
142 throw;
143 }
144 LoopRV loop_rv{nullptr};
145 {
146 Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
147 if (i_spatial_loop == -1) {
148 LoopRV spatial_loop_rv{nullptr};
149 if (loop_rvs.empty()) {
150 spatial_loop_rv = sch->AddUnitLoop(block_rv);
151 } else {
152 spatial_loop_rv = sch->AddUnitLoop(loop_rvs[0]);
153 }
154 loop_rvs.insert(loop_rvs.begin(), spatial_loop_rv);
155 i_spatial_loop = 0;
156 if (i_block_idx != -1) {
157 i_block_idx += 1;
158 }
159 if (i_thread_idx != -1) {
160 i_thread_idx += 1;
161 }
162 if (i_multi_child != -1) {
163 i_multi_child += 1;
164 }
165 }
166 if (i_block_idx == -1 && i_thread_idx != -1) {
167 int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
168 Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
169 loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
170 sch->Bind(loop_rv, "blockIdx.x");
171 return;
172 } else { // i_block_idx == -1 && i_thread_idx == -1
173 int num_fuse = std::min(i_multi_child, i_spatial_loop + 1);
174 loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
175 }
176 }
177 BindSpatialLoop(sch, loop_rv, max_threadblocks, max_threads_per_block, get_factor);
178}
179
180} // namespace meta_schedule
181} // namespace tvm
182