/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "./multi_level_tiling.h"

#include <tvm/meta_schedule/schedule_rule.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../utils.h"

namespace tvm {
namespace tir {
/*!
 * \brief Get the number of dimensions of each read buffer of a block, but mark the reduction
 * buffers' dimensions as -1
 * \param block_sref The block to be processed
 * \return The number of dimensions of each read buffer of the block, with reduction buffers
 * marked as -1
 * \note The method is not designed for generic analysis and relies on assumptions specific to
 * multi-level tiling, so it is intentionally kept inside this file rather than in the analysis
 * header
 */
std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) {
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  const BufferNode* write_buffer = block->writes[0]->buffer.get();
  int n = block->reads.size();
  std::vector<int> results(n, -1);
  for (int i = 0; i < n; ++i) {
    const BufferNode* read_buffer = block->reads[i]->buffer.get();
    if (read_buffer != write_buffer) {
      results[i] = read_buffer->shape.size();
    }
  }
  return results;
}

}  // namespace tir
}  // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

TVM_REGISTER_OBJECT_TYPE(StateNode);

State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
  ObjectPtr<StateNode> node = make_object<StateNode>();
  node->sch = std::move(sch);
  node->block_rv = std::move(block_rv);
  node->tiles = std::move(tiles);
  data_ = std::move(node);
}

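// Copy the state, deep-copying the underlying schedule so that later transformations
// on the copy do not affect the original state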
State StateNode::Copy() const {
  ObjectPtr<StateNode> node = make_object<StateNode>(*this);
  node->sch = sch->Copy();
  return State(node);
}

// Inherited from ScheduleRuleNode; reads the target's thread limits when they are defined
void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
  if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
    this->max_threads_per_block_ = v.value()->value;
    if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
      this->thread_warp_size_ = v.value()->value;
    } else {
      TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
    }
  }
  logger = context->logger;
}

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) {
  if ((filter_fn_ && filter_fn_.value()(sch, sch->GetSRef(block_rv))) ||
      NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
    sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);

    Array<Schedule> results;
    for (auto&& state : ApplySubRules({State(sch, block_rv)})) {
      results.push_back(std::move(state->sch));
    }
    return results;
  }
  return {sch};
}

// Inherited from ScheduleRuleNode
ScheduleRule MultiLevelTilingNode::Clone() const {
  ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(*this);
  return ScheduleRule(n);
}

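// Apply the sub-rules in order: tile the loop nest first, then optionally add write-cache and
// read-cache reuse stages to each tiled candidate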
std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
  return states;
}

std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
  const ReuseConfig& config = this->reuse_write_;
  if (config.req == ReuseType::kNoReuse) {
    return {std::move(state)};
  }
  std::vector<int> levels = config.levels;
  ReuseType req = config.req;
  if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
          state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
    req = ReuseType::kMustReuse;
    levels.clear();
    std::transform(ann.value().begin(), ann.value().end(), std::back_inserter(levels),
                   [](auto&& v) { return v.IntValue(); });
  }
  std::vector<State> results;
  if (req == ReuseType::kMayReuse) {
    // Case 1. If the write cache is already there, we don't need to add another.
    Array<BlockRV> consumer_rvs = state->sch->GetConsumers(state->block_rv);
    if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) {
      for (int level : levels) {
        State new_state = state->Copy();
        const LoopRV& loop_rv = new_state->tiles[level - 1].back();
        new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
        results.push_back(std::move(new_state));
      }
      state->write_reuse.emplace(0, consumer_rvs[0]);
      results.push_back(state);
      return results;
    } else {
      // Case 2. No write cache is added
      State new_state = state->Copy();
      results.emplace_back(std::move(new_state));
    }
  }

  // Case 3. Add one write cache
  BlockRV write_cache =
      state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*write_buffer_index=*/0,
                             /*storage_scope=*/config.scope);
  state->write_reuse.emplace(0, write_cache);
  for (int level : levels) {
    State new_state = state->Copy();
    const LoopRV& loop_rv = new_state->tiles[level - 1].back();
    new_state->sch->ReverseComputeAt(write_cache, loop_rv, true);
    results.push_back(std::move(new_state));
  }
  return results;
}

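// Split a loop into `n_tiles` nested loops, with the split factors sampled so that their
// product exactly covers the loop extent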
Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
                                                   int n_tiles) const {
  Array<tir::ExprRV> factors = sch->SamplePerfectTile(
      /*loop=*/loop,
      /*n=*/n_tiles,
      /*max_innermost_factor=*/max_innermost_factor);
  Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
                                         /*factors=*/{factors.begin(), factors.end()});
  return splits;
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
  Schedule& sch = state->sch;
  const BlockRV& block_rv = state->block_rv;
  // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
  Array<LoopRV> loops = sch->GetLoops(block_rv);
  std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
  ICHECK_EQ(loops.size(), iter_types.size());
  // Step 2. For each loop axis, tile it
  int64_t spatial_loop_product = 1;
  std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
  for (int i = 0, n = loops.size(); i < n; ++i) {
    LoopRV loop = loops[i];
    const std::vector<int>* idx = nullptr;

    if (iter_types[i] == IterVarType::kDataPar) {
      idx = &s_indices_;
      if (spatial_loop_product != -1) {
        if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
          spatial_loop_product *= *extent;
        } else {
          spatial_loop_product = -1;
        }
      }
    } else if (iter_types[i] == IterVarType::kCommReduce) {
      idx = &r_indices_;
    } else {
      continue;
    }

    const int n_tiles = idx->size();

    if (n_tiles == 1) {
      tiles[idx->at(0)].push_back(loop);
    } else {
      auto splits = SplitLoop(sch, block_rv, loop, n_tiles);

      // Put every tile to its slot
      for (int j = 0; j < n_tiles; ++j) {
        tiles[idx->at(j)].push_back(splits[j]);
      }
    }
  }
  // Step 3. Reorder to organize the tiles
  sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
  // Step 4. Bind the tiles to threads
  int n_binds = std::min(tile_binds.size(), tiles.size());
  for (int i = 0; i < n_binds; ++i) {
    LoopRV fused = sch->Fuse(tiles[i]);
    sch->Bind(fused, tile_binds[i]);
    tiles[i] = {fused};
  }
  state->tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
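  // Step 5. If the target defines a warp size, record the valid thread-extent range for this
  // block: at least one warp when the spatial iteration space exceeds two warps, and at most the
  // target's max threads per block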
  if (this->thread_warp_size_ != -1) {
    int64_t low_inclusive = 1;
    int64_t high_inclusive = this->max_threads_per_block_;
    if (spatial_loop_product > 2 * this->thread_warp_size_) {
      low_inclusive = this->thread_warp_size_;
    }
    sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive,
                  Integer(low_inclusive));
    sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive,
                  Integer(high_inclusive));
  }
  return {state};
}

std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
  const ReuseConfig& config = this->reuse_read_;
  if (config.req == ReuseType::kNoReuse) {
    return {std::move(state)};
  }
  ICHECK(config.req != ReuseType::kMayReuse);
  const BlockRV& block_rv = state->block_rv;
  std::vector<State> results;
  results.reserve(config.levels.size());
  for (int level : config.levels) {
    State new_state = state->Copy();
    Schedule& sch = new_state->sch;
    const LoopRV& loop_rv = state->tiles[level - 1].back();
    // Enumerate all buffers that are read but not written
    std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv));
    for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) {
      int buffer_ndim = read_buffer_ndims[i];
      if (buffer_ndim == -1) {
        continue;
      }
      // Do cache_read
      BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope, {block_rv});
      // Insert cache_read block to the proper place
      sch->ComputeAt(cache_read_block, loop_rv, true);
      // Fuse the iterators of the cache_read
      Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
      sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim,  //
                              buffer_loops.end()});
      AnnotateCooperativeFetching(&sch, cache_read_block);
      new_state->read_reuse.emplace(i, cache_read_block);
    }
    results.push_back(std::move(new_state));
  }
  return results;
}

void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                       const tir::BlockRV& block) const {
  // Filter out invalid vector lanes according to the data type.
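  // The caps below keep a single vectorized load within 128 bits: 4 x fp32, 8 x fp16, or 16 x int8.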
  const tir::BlockNode* block_node = (*sch)->GetSRef(block)->StmtAs<tir::BlockNode>();
  ICHECK_EQ(block_node->writes.size(), 1);
  const runtime::DataType dtype = block_node->writes[0]->buffer->dtype;
  std::function<bool(int)> f_filter = nullptr;
  if (dtype == runtime::DataType::Float(32)) {
    f_filter = [&](int vector_len) { return vector_len <= 4; };
  } else if (dtype == runtime::DataType::Float(16)) {
    f_filter = [&](int vector_len) {
      return (vector_len == 1 || vector_len % 2 == 0) && vector_len <= 8;
    };
  } else if (dtype == runtime::DataType::Int(8)) {
    f_filter = [&](int vector_len) { return vector_len <= 16; };
  }
  std::vector<int> valid_vector_lens;
  valid_vector_lens.reserve(vector_load_lens.size());
  if (f_filter != nullptr) {
    std::copy_if(vector_load_lens.begin(), vector_load_lens.end(),
                 std::back_inserter(valid_vector_lens), f_filter);
  } else {
    valid_vector_lens = vector_load_lens;
  }

  if (!valid_vector_lens.empty()) {
    int n = valid_vector_lens.size();
    double prob = 1.0 / n;
    tir::ExprRV vector_load_len =
        (*sch)->SampleCategorical(support::AsArray<int, Integer>(valid_vector_lens),
                                  Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
    (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
  }
}

// Constructor

ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
                                            Optional<Integer> max_innermost_factor,
                                            Optional<Array<Integer>> vector_load_lens,
                                            Optional<Map<String, ObjectRef>> reuse_read,
                                            Optional<Map<String, ObjectRef>> reuse_write,
                                            Optional<runtime::PackedFunc> filter_fn) {
  auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
  node->filter_fn_ = filter_fn;
  return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling")
    .set_body_typed(ScheduleRule::MultiLevelTiling);

}  // namespace meta_schedule
}  // namespace tvm