1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./multi_level_tiling.h" |
20 | |
21 | #include <tvm/meta_schedule/schedule_rule.h> |
22 | |
23 | #include <algorithm> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "../utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace tir { |
31 | /*! |
32 | * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction |
33 | * buffers' dimensions as -1 |
34 | * \param block_sref The block to be processed |
35 | * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers |
36 | * \note The method is not designed for generic analysis and relies on assumptions in the scenario |
37 | * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header |
38 | */ |
39 | std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) { |
40 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
41 | const BufferNode* write_buffer = block->writes[0]->buffer.get(); |
42 | int n = block->reads.size(); |
43 | std::vector<int> results(n, -1); |
44 | for (int i = 0; i < n; ++i) { |
45 | const BufferNode* read_buffer = block->reads[i]->buffer.get(); |
46 | if (read_buffer != write_buffer) { |
47 | results[i] = read_buffer->shape.size(); |
48 | } |
49 | } |
50 | return results; |
51 | } |
52 | |
53 | } // namespace tir |
54 | } // namespace tvm |
55 | |
56 | namespace tvm { |
57 | namespace meta_schedule { |
58 | |
59 | using tir::BlockRV; |
60 | using tir::IterVarType; |
61 | using tir::LoopRV; |
62 | using tir::Schedule; |
63 | |
// Register StateNode with TVM's object system so it can be reflected and reference-counted.
TVM_REGISTER_OBJECT_TYPE(StateNode);
65 | |
66 | State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) { |
67 | ObjectPtr<StateNode> node = make_object<StateNode>(); |
68 | node->sch = std::move(sch); |
69 | node->block_rv = std::move(block_rv); |
70 | node->tiles = std::move(tiles); |
71 | data_ = std::move(node); |
72 | } |
73 | |
74 | State StateNode::Copy() const { |
75 | ObjectPtr<StateNode> node = make_object<StateNode>(*this); |
76 | node->sch = sch->Copy(); |
77 | return State(node); |
78 | } |
79 | |
80 | // Do nothing; Inherited from ScheduleRuleNode |
81 | void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) { |
82 | if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block" )) { |
83 | this->max_threads_per_block_ = v.value()->value; |
84 | if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size" )) { |
85 | this->thread_warp_size_ = v.value()->value; |
86 | } else { |
87 | TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target" ; |
88 | } |
89 | } |
90 | logger = context->logger; |
91 | } |
92 | |
93 | // Entry of the mega rule; Inherited from ScheduleRuleNode |
94 | Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { |
95 | if ((filter_fn_ && filter_fn_.value()(sch, sch->GetSRef(block_rv))) || |
96 | NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { |
97 | sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); |
98 | |
99 | Array<Schedule> results; |
100 | for (auto&& state : ApplySubRules({State(sch, block_rv)})) { |
101 | results.push_back(std::move(state->sch)); |
102 | } |
103 | return results; |
104 | } |
105 | return {sch}; |
106 | } |
107 | |
108 | // Inherited from ScheduleRuleNode |
109 | ScheduleRule MultiLevelTilingNode::Clone() const { |
110 | ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(*this); |
111 | return ScheduleRule(n); |
112 | } |
113 | |
114 | std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) { |
115 | states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); }); |
116 | states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); }); |
117 | states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); }); |
118 | return states; |
119 | } |
120 | |
// Sub-rule: optionally materialize a write cache for the tiled block.
// Returns one candidate state per applicable placement (plus, for kMayReuse, the
// no-cache variant). Note the fall-through: for kMayReuse without an existing cache,
// Case 2 records a copy and then Case 3 mutates `state` itself.
std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
  const ReuseConfig& config = this->reuse_write_;
  if (config.req == ReuseType::kNoReuse) {
    return {std::move(state)};
  }
  // A block-level `meta_schedule.write_cache_level` annotation overrides the rule
  // config: reuse becomes mandatory and the levels come from the annotation.
  std::vector<int> levels = config.levels;
  ReuseType req = config.req;
  if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
          state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
    req = ReuseType::kMustReuse;
    levels.clear();
    std::transform(ann.value().begin(), ann.value().end(), std::back_inserter(levels),
                   [](auto&& v) { return v.IntValue(); });
  }
  std::vector<State> results;
  if (req == ReuseType::kMayReuse) {
    // Case 1. If the write cache is already there, we don't need to add another.
    Array<BlockRV> consumer_rvs = state->sch->GetConsumers(state->block_rv);
    if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) {
      // One candidate per level: compute the existing cache back at that tile level.
      for (int level : levels) {
        State new_state = state->Copy();
        const LoopRV& loop_rv = new_state->tiles[level - 1].back();
        new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
        results.push_back(std::move(new_state));
      }
      // Also keep the unmodified state, recording the consumer as the write cache.
      state->write_reuse.emplace(0, consumer_rvs[0]);
      results.push_back(state);
      return results;
    } else {
      // Case 2. No write cache is added
      State new_state = state->Copy();
      results.emplace_back(std::move(new_state));
    }
  }

  // Case 3. Add one write cache
  // (For kMayReuse this runs in addition to Case 2; for kMustReuse it is the only case.)
  BlockRV write_cache =
      state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*write_buffer_index=*/0,
                             /*storage_scope=*/config.scope);
  state->write_reuse.emplace(0, write_cache);
  for (int level : levels) {
    State new_state = state->Copy();
    const LoopRV& loop_rv = new_state->tiles[level - 1].back();
    new_state->sch->ReverseComputeAt(write_cache, loop_rv, true);
    results.push_back(std::move(new_state));
  }
  return results;
}
169 | |
170 | Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, |
171 | int n_tiles) const { |
172 | Array<tir::ExprRV> factors = sch->SamplePerfectTile( |
173 | /*loop=*/loop, |
174 | /*n=*/n_tiles, |
175 | /*max_innermost_factor=*/max_innermost_factor); |
176 | Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop, |
177 | /*factors=*/{factors.begin(), factors.end()}); |
178 | return splits; |
179 | } |
180 | |
// Sub-rule: split every spatial/reduction loop of the block into the configured tile
// levels, reorder the tiles into the canonical structure, and bind the outer levels
// to threads. Always produces exactly one state.
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
  Schedule& sch = state->sch;
  const BlockRV& block_rv = state->block_rv;
  // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
  Array<LoopRV> loops = sch->GetLoops(block_rv);
  std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
  ICHECK_EQ(loops.size(), iter_types.size());
  // Step 2. For each loop axis, tile it
  // `spatial_loop_product` accumulates the product of spatial loop extents; it is set
  // to -1 (unknown) as soon as any spatial extent is not a compile-time constant.
  int64_t spatial_loop_product = 1;
  std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
  for (int i = 0, n = loops.size(); i < n; ++i) {
    LoopRV loop = loops[i];
    // `idx` selects the tile slots for this loop kind: spatial loops fill the
    // `s_indices_` slots, reduction loops the `r_indices_` slots.
    const std::vector<int>* idx = nullptr;

    if (iter_types[i] == IterVarType::kDataPar) {
      idx = &s_indices_;
      if (spatial_loop_product != -1) {
        if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
          spatial_loop_product *= *extent;
        } else {
          spatial_loop_product = -1;
        }
      }
    } else if (iter_types[i] == IterVarType::kCommReduce) {
      idx = &r_indices_;
    } else {
      // Loops that are neither spatial nor reduction are left untouched.
      continue;
    }

    const int n_tiles = idx->size();

    if (n_tiles == 1) {
      // A single slot needs no split: the loop itself fills it.
      tiles[idx->at(0)].push_back(loop);
    } else {
      auto splits = SplitLoop(sch, block_rv, loop, n_tiles);

      // Put every tile to its slot
      for (int j = 0; j < n_tiles; ++j) {
        tiles[idx->at(j)].push_back(splits[j]);
      }
    }
  }
  // Step 3. Reorder to organize the tiles
  sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
  // Step 4. Bind the tiles to threads
  int n_binds = std::min(tile_binds.size(), tiles.size());
  for (int i = 0; i < n_binds; ++i) {
    LoopRV fused = sch->Fuse(tiles[i]);
    sch->Bind(fused, tile_binds[i]);
    // After binding, the slot collapses to the single fused loop.
    tiles[i] = {fused};
  }
  state->tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
  if (this->thread_warp_size_ != -1) {
    // Record thread-extent hints for later postprocessing: blocks with a spatial domain
    // larger than two warps should use at least one warp worth of threads.
    int64_t low_inclusive = 1;
    int64_t high_inclusive = this->max_threads_per_block_;
    if (spatial_loop_product > 2 * this->thread_warp_size_) {
      low_inclusive = this->thread_warp_size_;
    }
    sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive,
                  Integer(low_inclusive));
    sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive,
                  Integer(high_inclusive));
  }
  return {state};
}
246 | |
// Sub-rule: insert a cache_read stage for every non-reduction read buffer, producing
// one candidate state per configured cache level.
std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
  const ReuseConfig& config = this->reuse_read_;
  if (config.req == ReuseType::kNoReuse) {
    return {std::move(state)};
  }
  // Read reuse is never optional: the config either disables it or demands it.
  ICHECK(config.req != ReuseType::kMayReuse);
  const BlockRV& block_rv = state->block_rv;
  std::vector<State> results;
  results.reserve(config.levels.size());
  for (int level : config.levels) {
    State new_state = state->Copy();
    Schedule& sch = new_state->sch;
    // NOTE(review): indexing `state->tiles` (not `new_state->tiles`) relies on loop RVs
    // staying valid across StateNode::Copy / Schedule::Copy — confirm this invariant.
    const LoopRV& loop_rv = state->tiles[level - 1].back();
    // Enumerate all buffers that are read but not written
    std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv));
    for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) {
      int buffer_ndim = read_buffer_ndims[i];
      if (buffer_ndim == -1) {
        // -1 marks a buffer that is also written (reduction); skip it.
        continue;
      }
      // Do cache_read
      BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope, {block_rv});
      // Insert cache_read block to the proper place
      sch->ComputeAt(cache_read_block, loop_rv, true);
      // Fuse the iterators of the cache_read
      Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
      sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim,  //
                              buffer_loops.end()});
      AnnotateCooperativeFetching(&sch, cache_read_block);
      new_state->read_reuse.emplace(i, cache_read_block);
    }
    results.push_back(std::move(new_state));
  }
  return results;
}
282 | |
283 | void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, |
284 | const tir::BlockRV& block) const { |
285 | // Filter out invalid vector lanes according to the data type. |
286 | const tir::BlockNode* block_node = (*sch)->GetSRef(block)->StmtAs<tir::BlockNode>(); |
287 | ICHECK_EQ(block_node->writes.size(), 1); |
288 | const runtime::DataType dtype = block_node->writes[0]->buffer->dtype; |
289 | std::function<bool(int)> f_filter = nullptr; |
290 | if (dtype == runtime::DataType::Float(32)) { |
291 | f_filter = [&](int vector_len) { return vector_len <= 4; }; |
292 | } else if (dtype == runtime::DataType::Float(16)) { |
293 | f_filter = [&](int vector_len) { |
294 | return (vector_len == 1 || vector_len % 2 == 0) && vector_len <= 8; |
295 | }; |
296 | } else if (dtype == runtime::DataType::Int(8)) { |
297 | f_filter = [&](int vector_len) { return vector_len <= 16; }; |
298 | } |
299 | std::vector<int> valid_vector_lens; |
300 | valid_vector_lens.reserve(vector_load_lens.size()); |
301 | if (f_filter != nullptr) { |
302 | std::copy_if(vector_load_lens.begin(), vector_load_lens.end(), |
303 | std::back_inserter(valid_vector_lens), f_filter); |
304 | } else { |
305 | valid_vector_lens = vector_load_lens; |
306 | } |
307 | |
308 | if (!valid_vector_lens.empty()) { |
309 | int n = valid_vector_lens.size(); |
310 | double prob = 1.0 / n; |
311 | tir::ExprRV vector_load_len = |
312 | (*sch)->SampleCategorical(support::AsArray<int, Integer>(valid_vector_lens), |
313 | Array<FloatImm>(n, FloatImm(DataType::Float(64), prob))); |
314 | (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); |
315 | } |
316 | } |
317 | |
318 | // Constructor |
319 | |
320 | ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds, |
321 | Optional<Integer> max_innermost_factor, |
322 | Optional<Array<Integer>> vector_load_lens, |
323 | Optional<Map<String, ObjectRef>> reuse_read, |
324 | Optional<Map<String, ObjectRef>> reuse_write, |
325 | Optional<runtime::PackedFunc> filter_fn) { |
326 | auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>( |
327 | structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); |
328 | node->filter_fn_ = filter_fn; |
329 | return ScheduleRule(node); |
330 | } |
331 | |
// Register the node for reflection and expose the constructor over the FFI as
// `meta_schedule.ScheduleRuleMultiLevelTiling` (used by the Python bindings).
TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling")
    .set_body_typed(ScheduleRule::MultiLevelTiling);
335 | |
336 | } // namespace meta_schedule |
337 | } // namespace tvm |
338 | |