1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/meta_schedule/schedule/cuda/thread_bind.h> |
20 | #include <tvm/tir/op.h> |
21 | #include <tvm/tir/schedule/schedule.h> |
22 | |
23 | #include <algorithm> |
24 | #include <limits> |
25 | #include <utility> |
26 | |
27 | #include "../../utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace meta_schedule { |
31 | |
32 | using namespace tvm::tir; |
33 | |
34 | std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<Integer> thread_extents) { |
35 | return [sch = std::move(sch), |
36 | thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { |
37 | Array<Integer> extents; |
38 | extents.reserve(thread_extents.size()); |
39 | for (const Integer extent : thread_extents) { |
40 | if (extent->value <= max_extent) { |
41 | extents.push_back(extent); |
42 | } |
43 | } |
44 | int n = extents.size(); |
45 | if (n == 0) { |
46 | return Integer(max_extent); |
47 | } |
48 | if (n == 1) { |
49 | return Integer(extents[0]); |
50 | } |
51 | Array<FloatImm> probs(n, FloatImm(DataType::Float(64), 1.0 / n)); |
52 | return sch->SampleCategorical(extents, probs); |
53 | }; |
54 | } |
55 | |
56 | Array<LoopRV> BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, |
57 | int64_t max_threads_per_block, |
58 | std::function<ExprRV(int64_t)> get_factor) { |
59 | int64_t extent = -1; |
60 | if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) { |
61 | extent = *e; |
62 | } else { |
63 | extent = std::numeric_limits<int64_t>::max(); |
64 | } |
65 | if (extent <= max_threadblocks * max_threads_per_block) { |
66 | if (!get_factor) { |
67 | get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); |
68 | } |
69 | ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); |
70 | Array<LoopRV> splits = sch->Split(loop, {NullOpt, factor}); |
71 | ICHECK_EQ(splits.size(), 2); |
72 | sch->Bind(splits[0], "blockIdx.x" ); |
73 | sch->Bind(splits[1], "threadIdx.x" ); |
74 | return {splits[0], splits[1]}; |
75 | } else { |
76 | Array<LoopRV> splits = sch->Split(loop, {NullOpt, |
77 | Integer(max_threadblocks), // |
78 | Integer(max_threads_per_block)}); |
79 | ICHECK_EQ(splits.size(), 3); |
80 | sch->Reorder({splits[1], splits[2], splits[0]}); |
81 | sch->Bind(splits[1], "blockIdx.x" ); |
82 | sch->Bind(splits[2], "threadIdx.x" ); |
83 | return {splits[1], splits[2], splits[0]}; |
84 | } |
85 | } |
86 | |
87 | void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // |
88 | int64_t max_threadblocks, int64_t max_threads_per_block, |
89 | std::function<tir::ExprRV(int64_t)> get_factor) { |
90 | using namespace tvm::tir; |
91 | StmtSRef block_sref = sch->GetSRef(block_rv); |
92 | if (block_sref->parent == nullptr) { |
93 | return; |
94 | } |
95 | if (tir::HasBeenMultiLevelTiled(block_sref)) { |
96 | return; |
97 | } |
98 | Array<StmtSRef> loops = tir::GetLoops(block_sref); |
99 | int n = loops.size(); |
100 | int i_block_idx = -1; |
101 | int i_thread_idx = -1; |
102 | int i_multi_child = -1; |
103 | int i_spatial_loop = -1; |
104 | for (int i = 0; i < n; ++i) { |
105 | const StmtSRef& loop_sref = loops[i]; |
106 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
107 | runtime::ThreadScope thread_scope = GetThreadScope(loop); |
108 | if (IsBlockIdx(thread_scope)) { |
109 | if (i_block_idx == -1) { |
110 | i_block_idx = i; |
111 | } |
112 | } |
113 | if (IsThreadIdx(thread_scope)) { |
114 | if (i_thread_idx == -1) { |
115 | i_thread_idx = i; |
116 | } |
117 | } |
118 | if (loop->kind != ForKind::kSerial) { |
119 | if (i_multi_child == -1) { |
120 | i_multi_child = i; |
121 | } |
122 | } |
123 | if (!IsSingleStmt(loop->body)) { |
124 | if (i_multi_child == -1) { |
125 | i_multi_child = i + 1; |
126 | } |
127 | } |
128 | if (GetLoopIterType(loop_sref) == IterVarType::kDataPar) { |
129 | if (i_spatial_loop == i - 1) { |
130 | ++i_spatial_loop; |
131 | } |
132 | } |
133 | } |
134 | if (i_multi_child == -1) { |
135 | i_multi_child = n; |
136 | } |
137 | if (i_block_idx != -1 && i_thread_idx != -1) { |
138 | return; |
139 | } |
140 | if (i_block_idx != -1 && i_thread_idx == -1) { |
141 | ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not" ; |
142 | throw; |
143 | } |
144 | LoopRV loop_rv{nullptr}; |
145 | { |
146 | Array<LoopRV> loop_rvs = sch->GetLoops(block_rv); |
147 | if (i_spatial_loop == -1) { |
148 | LoopRV spatial_loop_rv{nullptr}; |
149 | if (loop_rvs.empty()) { |
150 | spatial_loop_rv = sch->AddUnitLoop(block_rv); |
151 | } else { |
152 | spatial_loop_rv = sch->AddUnitLoop(loop_rvs[0]); |
153 | } |
154 | loop_rvs.insert(loop_rvs.begin(), spatial_loop_rv); |
155 | i_spatial_loop = 0; |
156 | if (i_block_idx != -1) { |
157 | i_block_idx += 1; |
158 | } |
159 | if (i_thread_idx != -1) { |
160 | i_thread_idx += 1; |
161 | } |
162 | if (i_multi_child != -1) { |
163 | i_multi_child += 1; |
164 | } |
165 | } |
166 | if (i_block_idx == -1 && i_thread_idx != -1) { |
167 | int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); |
168 | Array<LoopRV> loop_rvs = sch->GetLoops(block_rv); |
169 | loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); |
170 | sch->Bind(loop_rv, "blockIdx.x" ); |
171 | return; |
172 | } else { // i_block_idx == -1 && i_thread_idx == -1 |
173 | int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); |
174 | loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); |
175 | } |
176 | } |
177 | BindSpatialLoop(sch, loop_rv, max_threadblocks, max_threads_per_block, get_factor); |
178 | } |
179 | |
180 | } // namespace meta_schedule |
181 | } // namespace tvm |
182 | |