/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/meta_schedule/schedule_rule.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../utils.h"
#include "./multi_level_tiling.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;
using tir::Schedule;

struct TensorCoreIntrinGroup {
  String init_intrin;
  String load_a_intrin;
  String load_b_intrin;
  String compute_intrin;
  String store_intrin;

  /*! \brief Create TensorCoreIntrinGroup from a config map. The map should contain the
   * following keys:
   *  - init
   *  - load_a
   *  - load_b
   *  - compute
   *  - store
   * The value of each key should be the name of the corresponding intrinsic, which must be
   * registered via TensorIntrin.Register beforehand.
   */
  static TensorCoreIntrinGroup FromConfig(const Map<String, String>& config);
};

TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map<String, String>& config) {
  auto f_initialize_intrin = [&config](String key_name, String* intrin_name) {
    CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set.";
    *intrin_name = config.at(key_name);
    // Check the existence of the intrin
    tir::TensorIntrin::Get(*intrin_name);
  };
  TensorCoreIntrinGroup intrin_group;
  f_initialize_intrin("init", &intrin_group.init_intrin);
  f_initialize_intrin("load_a", &intrin_group.load_a_intrin);
  f_initialize_intrin("load_b", &intrin_group.load_b_intrin);
  f_initialize_intrin("compute", &intrin_group.compute_intrin);
  f_initialize_intrin("store", &intrin_group.store_intrin);
  return intrin_group;
}
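
// Illustrative only: a config map for an fp16 WMMA intrin group might look like the following.
// The intrinsic names here are assumptions for the sake of the example; the actual names must
// match intrinsics registered via TensorIntrin.Register.
//
//   Map<String, String> config{{"init", "wmma_fill_16x16x16_f32"},
//                              {"load_a", "wmma_load_16x16x16_f16_a"},
//                              {"load_b", "wmma_load_16x16x16_f16_b"},
//                              {"compute", "wmma_sync_16x16x16_f16f16f32"},
//                              {"store", "wmma_store_16x16x16_f32_shared"}};
//   TensorCoreIntrinGroup group = TensorCoreIntrinGroup::FromConfig(config);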

class TensorCoreStateNode : public StateNode {
 public:
  /*! \brief The tensor core intrinsic group. */
  TensorCoreIntrinGroup intrin_group;
  /*! \brief The auto tensorization mapping info. */
  tir::AutoTensorizeMappingInfo mapping_info{nullptr};
  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_A;
  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_B;
  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_store;

  State Copy() const final;

  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
};

class TensorCoreState : public State {
 public:
  explicit TensorCoreState(TensorCoreIntrinGroup intrin_group,
                           tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
                           BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};

TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);

TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group,
                                 tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
                                 BlockRV block_rv, Array<Array<LoopRV>> tiles) {
  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
  node->intrin_group = intrin_group;
  node->mapping_info = mapping_info;
  node->sch = std::move(sch);
  node->block_rv = std::move(block_rv);
  node->tiles = std::move(tiles);
  data_ = std::move(node);
}

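// Copy() deep-copies the underlying schedule so that each branched state can be transformed
// independently by the sub-rules below.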
State TensorCoreStateNode::Copy() const {
  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>(*this);
  node->sch = sch->Copy();
  return State(node);
}

/*!
 * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
 * intrinsics.
 */
class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
 private:
  // SubRule: Add tensorization-related transformations
  inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
  // SubRule: Add tensorized load
  inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
  // SubRule: Add tensorized store
  inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
  // SubRule: Add software pipeline
  inline std::vector<State> AddSoftwarePipeline(TensorCoreState state) const;

  // Override ApplySubRules to apply tensorization-specific sub-rules
  std::vector<State> ApplySubRules(std::vector<State> states) final;

  // Override Apply to apply tensorization-specific analysis before applying sub-rules
  Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<MultiLevelTilingTensorCoreNode> n =
        make_object<MultiLevelTilingTensorCoreNode>(*this);
    return ScheduleRule(n);
  }

  /*!
   * \brief Transform and tensorize with the given tensor intrin
   * \param state The state of the meta schedule rule
   * \param intrin_name The name of the tensor intrin
   * \return The loop to be tensorized. NullOpt if the workload can't be tensorized.
   */
  Optional<LoopRV> TransformWithTensorIntrin(TensorCoreStateNode* state,
                                             const String& intrin_name) const;

  /*!
   * \brief Tile, blockize and annotate for tensorization with the given intrin
   * \param sch The schedule to be transformed
   * \param block_rv The block to be tensorized
   * \param intrin_name The name of the tensor intrin
   */
  void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv,
                                const String& intrin_name) const;

 public:
  /*! \brief The candidate tensor core intrin groups to apply */
  std::vector<TensorCoreIntrinGroup> intrin_groups;
  /*! \brief Whether to use software pipeline */
  bool use_software_pipeline = false;
  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
};

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
                                                      const BlockRV& block_rv) {
  if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
    return {sch};
  }

  std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info;
  for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
    TensorCoreIntrinGroup intrin_group = intrin_groups[i];
    Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo(
        sch->state(), sch->GetSRef(block_rv),
        tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc);
    if (mapping_info.defined()) {
      intrin_group_to_mapping_info.emplace(i, mapping_info.value());
    }
  }

  if (intrin_group_to_mapping_info.empty()) {
    // No tensor intrinsics can be applied.
    return {sch};
  }

  // Save the original schedule so that we can roll back transformations if tensorization
  // fails.
  Schedule original_sch = sch;

  std::vector<State> initial_states;
  for (const auto& kv : intrin_group_to_mapping_info) {
    const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
    const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
    Schedule new_sch = sch->Copy();
    new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
    initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
  }
  Array<Schedule> results;
  for (auto&& state : ApplySubRules(initial_states)) {
    TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with "
                             << state.as<TensorCoreStateNode>()->intrin_group.compute_intrin;
    results.push_back(std::move(state->sch));
  }
  if (results.empty()) {
    TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
    return {original_sch};
  }
  return results;
}

std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<State> states) {
  states = SubRule(std::move(states), [&](State state) {
    return TransformForTensorization(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
  states = SubRule(std::move(states), [&](State state) {
    return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
  states = SubRule(std::move(states), [&](State state) {
    return AddReadReuseTensorCore(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) {
    return AddSoftwarePipeline(Downcast<TensorCoreState>(state));
  });
  return states;
}

void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
                                                              const BlockRV& block_rv,
                                                              const String& intrin_name) const {
  Optional<LoopRV> loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value();
  ICHECK(loop.defined());
  BlockRV blockized_outer = (*sch)->Blockize(loop.value());
  (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
    TensorCoreState state) const {
  // Add the cache write stage for Tensor Core
  int level = r_indices_.front() - 1;
  const LoopRV& loop = state->tiles[level].back();
  Schedule& sch = state->sch;
  auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
  sch->ReverseComputeAt(cache_write, loop, true);

  if (state->write_reuse.count(0)) {
    // Fuse the iterators of the cache_write
    Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
    ICHECK_GT(buffer_loops.size(), 2);
    sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2,  // The src shmem is always 2D
                            buffer_loops.end()});
    AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
  }
  sch->ReverseComputeInline(state->tensor_core_reindex_store);
  TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
  return {state};
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
    TensorCoreState state) const {
  const Array<LoopRV>& r_tiles = state->tiles[r_indices_[1]];
  Schedule& sch = state->sch;
  ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block";

  auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) {
    auto cache_read = sch->CacheRead(state->block_rv, read_index, scope);
    state->sch->ComputeAt(cache_read, r_tiles.back(), true);
    TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
  };

  f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin);
  f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin);
  sch->ComputeInline(state->tensor_core_reindex_A);
  sch->ComputeInline(state->tensor_core_reindex_B);

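  // The StorageAlign calls below pad the second-to-last dimension of the shared memory buffers
  // feeding the fragment loads. The (factor=32, offset=8 or 16) values stagger consecutive rows
  // by 16 bytes (8 fp16 or 16 int8 elements), which is intended to mitigate shared memory bank
  // conflicts.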
  for (int i = 0; i < 2; ++i) {
    const tir::BlockRV cache_read = state->read_reuse.at(i);
    const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs<tir::BlockNode>();
    tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer(
        sch->state(), GetRef<tir::Block>(cache_read_block), 0, tir::BufferIndexType::kWrite);
    const DataType& dtype = cache_read_buffer->dtype;
    if (dtype.is_float16()) {
      sch->StorageAlign(cache_read, 0, -2, 32, 8);
    } else if (dtype.is_int() && dtype.bits() == 8) {
      sch->StorageAlign(cache_read, 0, -2, 32, 16);
    } else {
      TVM_PY_LOG(WARNING, logger) << "StorageAlign is not applied for data type " << dtype
                                  << ", shared memory accesses might be inefficient.";
    }
  }
  return {state};
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
    TensorCoreState state) const {
  if (!use_software_pipeline) {
    return {state};
  }
  // The current config is not suitable for software pipelining.
  if (r_indices_.size() < 2) {
    return {state};
  }

  Schedule& sch = state->sch;
  // Check reduction length after blockize.
  int64_t reduction_length = 1;
  for (int r_index : r_indices_) {
    const Array<LoopRV>& tiles = state->tiles[r_index];
    for (const LoopRV& tile : tiles) {
      const auto* extent = sch->Get(tile)->extent.as<IntImmNode>();
      ICHECK(extent != nullptr) << "Dynamic extent is not supported.";
      reduction_length *= extent->value;
    }
  }
  if (reduction_length <= 1) {
    return {state};
  }

  // Add local stage and double buffering
  for (int i = 0; i < 2; ++i) {
    const tir::BlockRV cache_read = state->read_reuse.at(i);
    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1));
    sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
  }

  // Add annotations of software pipeline
  //
  // Before pipelining, the original loop can be expressed as the pseudocode below:
  //
  // for k0 in [0, K0):
  //   load tile k0 to registers
  //   load tile k0 from registers to shared memory
  //
  //   for k1 in [0, K1):
  //     load fragment k1 of tile k0
  //     compute matmul with fragment k1
  //

  // Inner software pipeline: prefetch the tensor core fragments one iteration ahead.
  // The following annotation for the inner loop is equivalent to the pseudocode below:
  //
  // Pipelined inner loop:
  //
  // prologue:
  //   load fragment 0
  // body:
  //   for k1 in [0, K1 - 1):
  //     load fragment k1 + 1
  //     compute matmul with fragment k1
  // epilogue:
  //   compute matmul with fragment K1 - 1
  //
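  // Each element of software_pipeline_stage / software_pipeline_order corresponds to one
  // statement of the annotated loop body, in program order. Here the three statements are the
  // two fragment loads (A and B) and the compute block: the loads go to stage 0 and the compute
  // to stage 1, which produces the prefetch-by-one schedule sketched above.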
  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage,
                Array<Integer>{0, 0, 1});
  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order,
                Array<Integer>{0, 1, 2});
  // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop.
  // The prefetching stage of the inner pipeline is executed one iteration ahead in the outer
  // loop. The following annotation for the outer loop is equivalent to the pseudocode below:
  //
  // Pipelined outer loop with nested inner pipeline:
  //
  // prologue:
  //   load tile 0 to registers
  //   load tile 0 from registers to shared memory
  //
  //   // prologue of the inner pipeline
  //   load fragment 0 of tile 0
  //
  // body:
  //   for k0 in [0, K0 - 1):
  //     load tile k0 + 1 to registers
  //
  //     // body of the inner pipeline
  //     for k1 in [0, K1 - 1):
  //       load fragment k1 + 1 of tile k0
  //       compute matmul with fragment k1 of tile k0
  //
  //     load tile k0 + 1 from registers to shared memory
  //
  //     // prologue of the inner pipeline
  //     load fragment 0 of tile k0 + 1
  //
  //     // epilogue of the inner pipeline
  //     compute matmul with fragment K1 - 1 of tile k0
  //
  // epilogue:
  //
  //   // body of the inner pipeline
  //   for k1 in [0, K1 - 1):
  //     load fragment k1 + 1 of tile K0 - 1
  //     compute matmul with fragment k1 of tile K0 - 1
  //
  //   // epilogue of the inner pipeline
  //   compute matmul with fragment K1 - 1 of tile K0 - 1
  //
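  // As with the inner loop, each stage/order element below corresponds to one statement of the
  // outer loop body once the local-stage and double-buffering annotations above take effect
  // (register loads, shared memory stores, and the nested inner pipeline); the values encode
  // the interleaving sketched in the pseudocode.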
  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage,
                Array<Integer>{0, 0, 0, 0, 0, 1, 1});
  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order,
                Array<Integer>{0, 3, 1, 4, 5, 2, 6});

  return {state};
}

Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
    TensorCoreStateNode* state, const String& intrin_name) const {
  BlockRV block_rv = state->block_rv;
  const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info;
  tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);

  // Add reindex stages
  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  // Hold the reference of the block before reindex
  const tir::Block block_before_reindex = GetRef<tir::Block>(block);
  if (block->reads.size() != 2 || block->writes.size() != 1) {
    // only matmul-like computation is allowed
    return NullOpt;
  }
  state->tensor_core_reindex_store =
      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite);
  state->tensor_core_reindex_A =
      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead);
  state->tensor_core_reindex_B =
      state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead);

  // Transform the layout of reindex buffers accordingly.
  // The index map defines the mapping for the computation block. We need to extract the sub
  // index map to transform the load and store blocks.
  ICHECK_EQ(mapping_info->mappings.size(), 1U);  // assume only one mapping is present
  const tir::IndexMap& index_map = mapping_info->mappings[0];

  // Find the correspondence between block iters and the iters in the index map.
  std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
  std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
  ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
  for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
    lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
  }
  // The number of result iters in the index map is equal to or greater than the number of rhs
  // (tensor intrin) iters. When there are extra iters, they represent iters from the lhs that
  // are unmapped and will be skipped during pattern matching for tensorization. An example of
  // such a case is batch matmul: the batch dimension is kept after layout transformations and
  // remains as an outer loop after tensorization.
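  // For instance (illustrative only), for a batched matmul C[b, i, j] += A[b, i, k] * B[b, k, j]
  // mapped onto a tensor intrinsic with iters (i, j, k), final_indices has four entries, the
  // offset below is 1, and the leading final index is the unmapped batch iter b.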
  int offset = static_cast<int>(index_map->final_indices.size()) -
               static_cast<int>(mapping_info->rhs_iters.size());
  ICHECK_GE(offset, 0);
  for (int i = 0; i < offset; ++i) {
    const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
    ICHECK(var_ptr != nullptr);
    unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
  }
  for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
    rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i];
  }

  auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
    std::vector<tir::Var> sub_index_map_src;
    std::vector<PrimExpr> sub_index_map_tgt;
    const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer];
    for (const Range& range : lhs_region) {
      ICHECK(tir::is_one(range->extent));
      const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
      ICHECK(var_ptr != nullptr);
      const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef<tir::Var>(var_ptr)];
      sub_index_map_src.push_back(lhs_representer);
      if (unmapped_index_map_src.count(lhs_representer)) {
        sub_index_map_tgt.push_back(lhs_representer);
      }
    }
    for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) {
      const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
      ICHECK(var != nullptr);
      sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
    }
    return tir::IndexMap(sub_index_map_src, sub_index_map_tgt);
  };

  std::unordered_set<tir::Buffer, ObjectPtrHash, ObjectPtrEqual> visited_buffers;

  // Cache of the sub index map associated with each buffer
  Map<tir::Buffer, tir::IndexMap> buffer_sub_index_map;

  auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) {
    const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer(
        state->sch->state(), block_before_reindex, buffer_index, index_type);
    if (visited_buffers.count(lhs_buffer)) {
      return;
    }
    visited_buffers.insert(lhs_buffer);
    // Refresh block pointer (block sref is not invalidated)
    block = TVM_SREF_TO_BLOCK(block_sref);
    const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
        state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
    auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
    buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
  };

  for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
    f_transform_buffer_layout(tir::BufferIndexType::kRead, i);
  }
  for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) {
    f_transform_buffer_layout(tir::BufferIndexType::kWrite, i);
  }

  // Transform the layout of current block and reindex blocks
  auto f_transform_reindex_block_layout = [&](const BlockRV& block_rv,
                                              tir::BufferIndexType buffer_type) {
    tir::Buffer buffer =
        tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type);
    const auto& sub_index_map = buffer_sub_index_map.at(buffer);
    state->sch->TransformBlockLayout(block_rv, sub_index_map);
  };
  f_transform_reindex_block_layout(state->tensor_core_reindex_store, tir::BufferIndexType::kWrite);
  f_transform_reindex_block_layout(state->tensor_core_reindex_A, tir::BufferIndexType::kRead);
  f_transform_reindex_block_layout(state->tensor_core_reindex_B, tir::BufferIndexType::kRead);
  state->sch->TransformBlockLayout(state->block_rv, index_map);
  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
                                   /*allow_padding=*/true);
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
    TensorCoreState state) const {
  // Do reindex and layout transformations.
  Optional<LoopRV> transformed_loop_rv =
      TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
  if (!transformed_loop_rv.defined()) {
    // The workload can't be tensorized.
    return {};
  }

  // Do blockize
  state->block_rv = state->sch->Blockize(transformed_loop_rv.value());

  // Add annotations for post processors.
  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
                       state->intrin_group.compute_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
                       state->intrin_group.init_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
  return {std::move(state)};
}

ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
    Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
    bool use_software_pipeline) {
  if (tile_binds.defined()) {
    for (const String& tile_bind : tile_binds.value()) {
      CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
    }
  }
  auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

  node->intrin_groups.reserve(intrin_groups.size());
  for (const auto& intrin_group_config : intrin_groups) {
    node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
  }
  node->use_software_pipeline = use_software_pipeline;
  return ScheduleRule(node);
}
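
// Illustrative only: this rule is normally constructed from Python through the global function
// registered below. A direct C++ construction might look like the following; the argument
// values are hypothetical examples rather than defaults defined by this file:
//
//   Array<Map<String, String>> groups = {/* intrin-group configs, see FromConfig above */};
//   ScheduleRule rule = ScheduleRule::MultiLevelTilingTensorCore(
//       groups, /*structure=*/"SSSRRSRS",
//       /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
//       /*max_innermost_factor=*/Integer(4), /*vector_load_lens=*/NullOpt,
//       /*reuse_read=*/NullOpt, /*reuse_write=*/NullOpt, /*use_software_pipeline=*/false);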

TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore")
    .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore);

}  // namespace meta_schedule
}  // namespace tvm