/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/meta_schedule/schedule_rule.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../utils.h"
#include "./multi_level_tiling.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;
using tir::Schedule;

struct TensorCoreIntrinGroup {
  String init_intrin;
  String load_a_intrin;
  String load_b_intrin;
  String compute_intrin;
  String store_intrin;

  /*! \brief Create a TensorCoreIntrinGroup from a config map. The map should contain the
   * following keys:
   *  - init
   *  - load_a
   *  - load_b
   *  - compute
   *  - store
   * The value of each key is the name of the corresponding intrinsic, which must be
   * registered via TensorIntrin.Register beforehand.
   */
  static TensorCoreIntrinGroup FromConfig(const Map<String, String>& config);
};

TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map<String, String>& config) {
  auto f_initialize_intrin = [&config](String key_name, String* intrin_name) {
    CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set.";
    *intrin_name = config.at(key_name);
    // Check the existence of the intrin (an error is raised if it is not registered).
    tir::TensorIntrin::Get(*intrin_name);
  };
  TensorCoreIntrinGroup intrin_group;
  f_initialize_intrin("init", &intrin_group.init_intrin);
  f_initialize_intrin("load_a", &intrin_group.load_a_intrin);
  f_initialize_intrin("load_b", &intrin_group.load_b_intrin);
  f_initialize_intrin("compute", &intrin_group.compute_intrin);
  f_initialize_intrin("store", &intrin_group.store_intrin);
  return intrin_group;
}
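
// Illustrative only: a config map for an fp16 WMMA intrin group might look like the sketch
// below. The intrinsic names here are placeholders; whatever names are used must already be
// registered via TensorIntrin.Register before this rule runs.
//
//   Map<String, String> wmma_f16_group{{"init", "wmma_fill_16x16x16_f32"},
//                                      {"load_a", "wmma_load_16x16x16_f16_a"},
//                                      {"load_b", "wmma_load_16x16x16_f16_b"},
//                                      {"compute", "wmma_sync_16x16x16_f16f16f32"},
//                                      {"store", "wmma_store_16x16x16_f32_shared"}};
//   TensorCoreIntrinGroup group = TensorCoreIntrinGroup::FromConfig(wmma_f16_group);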

class TensorCoreStateNode : public StateNode {
 public:
  /*! \brief The tensor core intrinsic group. */
  TensorCoreIntrinGroup intrin_group;
  /*! \brief The auto tensorization mapping info. */
  tir::AutoTensorizeMappingInfo mapping_info{nullptr};
  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_A;
  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_B;
  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
  tir::BlockRV tensor_core_reindex_store;

  State Copy() const final;

  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
};

class TensorCoreState : public State {
 public:
  explicit TensorCoreState(TensorCoreIntrinGroup intrin_group,
                           tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
                           BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};

TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);

TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group,
                                 tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
                                 BlockRV block_rv, Array<Array<LoopRV>> tiles) {
  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
  node->intrin_group = intrin_group;
  node->mapping_info = mapping_info;
  node->sch = std::move(sch);
  node->block_rv = std::move(block_rv);
  node->tiles = std::move(tiles);
  data_ = std::move(node);
}

State TensorCoreStateNode::Copy() const {
  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>(*this);
  node->sch = sch->Copy();
  return State(node);
}

/*!
 * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
 * intrinsics.
 */
class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
 private:
  // SubRule: Add tensorization-related transformations
  inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
  // SubRule: Add tensorized load
  inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
  // SubRule: Add tensorized store
  inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
  // SubRule: Add software pipeline
  inline std::vector<State> AddSoftwarePipeline(TensorCoreState state) const;

  // Override ApplySubRules to apply tensorization-specific sub-rules
  std::vector<State> ApplySubRules(std::vector<State> states) final;

  // Override Apply to apply tensorization-specific analysis before applying sub-rules
  Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<MultiLevelTilingTensorCoreNode> n =
        make_object<MultiLevelTilingTensorCoreNode>(*this);
    return ScheduleRule(n);
  }

  /*!
   * \brief Transform and tensorize with the given tensor intrin
   * \param state The state of the meta schedule rule
   * \param intrin_name The name of the tensor intrin
   * \return The loop to be tensorized. NullOpt if the workload can't be tensorized.
   */
  Optional<LoopRV> TransformWithTensorIntrin(TensorCoreStateNode* state,
                                             const String& intrin_name) const;

  /*!
   * \brief Tile, blockize and annotate for tensorization with the given intrin
   * \param sch The schedule to be updated
   * \param block_rv The block to be tensorized
   * \param intrin_name The name of the tensor intrin
   */
  void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv,
                                const String& intrin_name) const;

 public:
  /*! \brief The candidate tensor core intrin groups to apply */
  std::vector<TensorCoreIntrinGroup> intrin_groups;
  /*! \brief Whether to use software pipeline */
  bool use_software_pipeline = false;
  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
};

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
                                                      const BlockRV& block_rv) {
  if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
    return {sch};
  }

  std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info;
  for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
    TensorCoreIntrinGroup intrin_group = intrin_groups[i];
    Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo(
        sch->state(), sch->GetSRef(block_rv),
        tir::TensorIntrin::Get(intrin_group.compute_intrin).value()->desc);
    if (mapping_info.defined()) {
      intrin_group_to_mapping_info.emplace(i, mapping_info.value());
    }
  }

  if (intrin_group_to_mapping_info.empty()) {
    // No tensor intrinsics can be applied.
    return {sch};
  }

  // Save the original schedule so that we can roll back transformations if tensorization
  // fails.
  Schedule original_sch = sch;

  std::vector<State> initial_states;
  for (const auto& kv : intrin_group_to_mapping_info) {
    const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
    const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
    Schedule new_sch = sch->Copy();
    new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
    initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
  }
  Array<Schedule> results;
  for (auto&& state : ApplySubRules(initial_states)) {
    TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with "
                             << state.as<TensorCoreStateNode>()->intrin_group.compute_intrin;
    results.push_back(std::move(state->sch));
  }
  if (results.empty()) {
    TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
    return {original_sch};
  }
  return results;
}

std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<State> states) {
  states = SubRule(std::move(states), [&](State state) {
    return TransformForTensorization(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
  states = SubRule(std::move(states), [&](State state) {
    return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
  states = SubRule(std::move(states), [&](State state) {
    return AddReadReuseTensorCore(Downcast<TensorCoreState>(state));
  });
  states = SubRule(std::move(states), [&](State state) {
    return AddSoftwarePipeline(Downcast<TensorCoreState>(state));
  });
  return states;
}

void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
                                                              const BlockRV& block_rv,
                                                              const String& intrin_name) const {
  Optional<LoopRV> loop = TileWithTensorIntrin(*sch, block_rv, intrin_name);
  ICHECK(loop.defined());
  BlockRV blockized_outer = (*sch)->Blockize(loop.value());
  (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
    TensorCoreState state) const {
  // Add the cache write stage for Tensor Core
  int level = r_indices_.front() - 1;
  const LoopRV& loop = state->tiles[level].back();
  Schedule& sch = state->sch;
  auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
  sch->ReverseComputeAt(cache_write, loop, true);

  if (state->write_reuse.count(0)) {
    // Fuse the iterators of the cache_write
    Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
    ICHECK_GT(buffer_loops.size(), 2);
    sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2,  // The src shmem is always 2D
                            buffer_loops.end()});
    AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
  }
  sch->ReverseComputeInline(state->tensor_core_reindex_store);
  TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
  return {state};
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
    TensorCoreState state) const {
  const Array<LoopRV>& r_tiles = state->tiles[r_indices_[1]];
  Schedule& sch = state->sch;
  ICHECK(!r_tiles.empty()) << "ValueError: Cannot find a suitable reduction loop in the block";

  auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) {
    auto cache_read = sch->CacheRead(state->block_rv, read_index, scope);
    sch->ComputeAt(cache_read, r_tiles.back(), true);
    TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
  };

  f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin);
  f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin);
  sch->ComputeInline(state->tensor_core_reindex_A);
  sch->ComputeInline(state->tensor_core_reindex_B);

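  // StorageAlign(block, buffer_index, axis, factor, offset) constrains the stride of `axis` to
  // satisfy stride % factor == offset. The calls below skew the second-to-last axis of the
  // shared memory staging buffers by 8 fp16 elements (or 16 int8 elements) modulo 32, i.e. a
  // 16-byte offset per row, which is intended to reduce shared memory bank conflicts when the
  // wmma fragments are loaded.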
  for (int i = 0; i < 2; ++i) {
    const tir::BlockRV cache_read = state->read_reuse.at(i);
    const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs<tir::BlockNode>();
    tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer(
        sch->state(), GetRef<tir::Block>(cache_read_block), 0, tir::BufferIndexType::kWrite);
    const DataType& dtype = cache_read_buffer->dtype;
    if (dtype.is_float16()) {
      sch->StorageAlign(cache_read, 0, -2, 32, 8);
    } else if (dtype.is_int() && dtype.bits() == 8) {
      sch->StorageAlign(cache_read, 0, -2, 32, 16);
    } else {
      TVM_PY_LOG(WARNING, logger) << "StorageAlign is not applied for data type " << dtype
                                  << ", shared memory accesses might be inefficient.";
    }
  }
  return {state};
}

std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
    TensorCoreState state) const {
  if (!use_software_pipeline) {
    return {state};
  }
  // The current tiling structure is not suitable for software pipelining: at least two levels
  // of reduction tiles are required.
  if (r_indices_.size() < 2) {
    return {state};
  }

  Schedule& sch = state->sch;
  // Check the reduction length after blockize.
  int64_t reduction_length = 1;
  for (int r_index : r_indices_) {
    const Array<LoopRV>& tiles = state->tiles[r_index];
    for (const LoopRV& tile : tiles) {
      const auto* extent = sch->Get(tile)->extent.as<IntImmNode>();
      ICHECK(extent != nullptr) << "Dynamic extent is not supported.";
      reduction_length *= extent->value;
    }
  }
  if (reduction_length <= 1) {
    return {state};
  }

  // Add local stage and double buffering
  for (int i = 0; i < 2; ++i) {
    const tir::BlockRV cache_read = state->read_reuse.at(i);
    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1));
    sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
  }

  // Add annotations of the software pipeline
  //
  // Before pipelining, the original loop can be expressed as the pseudo code below:
  //
  // for k0 in [0, K0):
  //   load tile k0 to registers
  //   load tile k0 from registers to shared memory
  //
  //   for k1 in [0, K1):
  //     load fragment k1 of tile k0
  //     compute matmul with fragment k1
  //

  // Inner software pipeline: Prefetch the tensor core fragment by one iteration.
  // The following annotation for the inner loop is equivalent to the pseudo code below:
  //
  // Pipelined inner loop:
  //
  // prologue:
  //   load fragment 0
  // body:
  //   for k1 in [0, K1 - 1):
  //     load fragment k1 + 1
  //     compute matmul with fragment k1
  // epilogue:
  //   compute matmul with fragment K1 - 1
  //
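  // In the software_pipeline_stage / software_pipeline_order annotations below, each array
  // element corresponds to one direct child block of the annotated loop's body, in program
  // order: the stage array assigns each block to a pipeline stage (blocks in lower stages run
  // ahead of blocks in higher stages, producing the prologue/body/epilogue shown above), and
  // the order array gives the relative order of the blocks inside the pipelined body. For the
  // inner loop, the three children are the two fragment loads and the compute block, hence
  // arrays of length three. The outer-loop annotations further below follow the same
  // convention.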
  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage,
                Array<Integer>{0, 0, 1});
  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order,
                Array<Integer>{0, 1, 2});
  // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop.
  // The prefetching stage of the inner pipeline is executed one iteration ahead in the outer
  // loop. The following annotation for the outer loop is equivalent to the pseudo code below:
  //
  // Pipelined outer loop with nested inner pipeline:
  //
  // prologue:
  //   load tile 0 to registers
  //   load tile 0 from registers to shared memory
  //
  //   // prologue of the inner pipeline
  //   load fragment 0 of tile 0
  //
  // body:
  //   for k0 in [0, K0 - 1):
  //     load tile k0 + 1 to registers
  //
  //     // body of the inner pipeline
  //     for k1 in [0, K1 - 1):
  //       load fragment k1 + 1 of tile k0
  //       compute matmul with fragment k1 of tile k0
  //
  //     load tile k0 + 1 from registers to shared memory
  //
  //     // prologue of the inner pipeline
  //     load fragment 0 of tile k0 + 1
  //
  //     // epilogue of the inner pipeline
  //     compute matmul with fragment K1 - 1 of tile k0
  //
  // epilogue:
  //
  //   // body of the inner pipeline
  //   for k1 in [0, K1 - 1):
  //     load fragment k1 + 1 of tile K0 - 1
  //     compute matmul with fragment k1 of tile K0 - 1
  //
  //   // epilogue of the inner pipeline
  //   compute matmul with fragment K1 - 1 of tile K0 - 1
  //
  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage,
                Array<Integer>{0, 0, 0, 0, 0, 1, 1});
  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order,
                Array<Integer>{0, 3, 1, 4, 5, 2, 6});

  return {state};
}

Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
    TensorCoreStateNode* state, const String& intrin_name) const {
  BlockRV block_rv = state->block_rv;
  const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info;
  tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);

  // Add reindex stages
  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  // Hold a reference to the block before reindex
  const tir::Block block_before_reindex = GetRef<tir::Block>(block);
  if (block->reads.size() != 2 || block->writes.size() != 1) {
    // Only matmul-like computation is allowed
    return NullOpt;
  }
  state->tensor_core_reindex_store =
      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite);
  state->tensor_core_reindex_A =
      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead);
  state->tensor_core_reindex_B =
      state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead);

  // Transform the layout of the reindex buffers accordingly.
  // The index map defines the mapping for the computation block. We need to extract the sub
  // index maps to transform the load and store blocks.
  ICHECK_EQ(mapping_info->mappings.size(), 1U);  // assume only one mapping is present
  const tir::IndexMap& index_map = mapping_info->mappings[0];

  // Find the correspondence between the block iters and the iters in the index map.
  std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
  std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
  ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
  for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
    lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
  }
  // The number of result iters in the index map is equal to or greater than the number of rhs
  // (the tensor intrin) iters. When there are extra iters, these iters represent unmapped iters
  // from the lhs. They will be skipped during pattern matching for tensorization. An example of
  // such a case is batch matmul: the batch dimension is kept after the layout transformations
  // and will remain as an outer loop after tensorization.
  int offset = static_cast<int>(index_map->final_indices.size()) -
               static_cast<int>(mapping_info->rhs_iters.size());
  ICHECK_GE(offset, 0);
  for (int i = 0; i < offset; ++i) {
    const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
    ICHECK(var_ptr != nullptr);
    unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
  }
  for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
    rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i];
  }
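  // Example (for illustration): for a batch matmul whose lhs iters are (b, i, j, k) and whose
  // intrinsic (rhs) iters are (i, j, k), offset == 1; the leading final index b is recorded in
  // unmapped_index_map_src and stays an outer loop, while the remaining final indices are
  // associated with the intrinsic iters via rhs_to_index_map_tgt.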

  auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
    std::vector<tir::Var> sub_index_map_src;
    std::vector<PrimExpr> sub_index_map_tgt;
    const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer];
    for (const Range& range : lhs_region) {
      ICHECK(tir::is_one(range->extent));
      const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
      ICHECK(var_ptr != nullptr);
      const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef<tir::Var>(var_ptr)];
      sub_index_map_src.push_back(lhs_representer);
      if (unmapped_index_map_src.count(lhs_representer)) {
        sub_index_map_tgt.push_back(lhs_representer);
      }
    }
    for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) {
      const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
      ICHECK(var != nullptr);
      sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
    }
    return tir::IndexMap(sub_index_map_src, sub_index_map_tgt);
  };
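  // Example (for illustration): for the A operand of C[i, j] += A[i, k] * B[k, j] with the
  // identity mapping (i, j, k) -> (i, j, k), the reindexed region of A is [i, k], so the sub
  // index map is (i, k) -> (i, k). For a batch matmul reading A[b, i, k], the unmapped batch
  // iter b is passed through, giving (b, i, k) -> (b, i, k).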

  std::unordered_set<tir::Buffer, ObjectPtrHash, ObjectPtrEqual> visited_buffers;

  // Cache of the sub index map associated with each buffer
  Map<tir::Buffer, tir::IndexMap> buffer_sub_index_map;

  auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) {
    const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer(
        state->sch->state(), block_before_reindex, buffer_index, index_type);
    if (visited_buffers.count(lhs_buffer)) {
      return;
    }
    visited_buffers.insert(lhs_buffer);
    // Refresh block pointer (block sref is not invalidated)
    block = TVM_SREF_TO_BLOCK(block_sref);
    const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
        state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
    auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
    buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
  };

  for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
    f_transform_buffer_layout(tir::BufferIndexType::kRead, i);
  }
  for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) {
    f_transform_buffer_layout(tir::BufferIndexType::kWrite, i);
  }

  // Transform the layout of the current block and the reindex blocks
  auto f_transform_reindex_block_layout = [&](const BlockRV& block_rv,
                                              tir::BufferIndexType buffer_type) {
    tir::Buffer buffer =
        tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type);
    const auto& sub_index_map = buffer_sub_index_map.at(buffer);
    state->sch->TransformBlockLayout(block_rv, sub_index_map);
  };
  f_transform_reindex_block_layout(state->tensor_core_reindex_store, tir::BufferIndexType::kWrite);
  f_transform_reindex_block_layout(state->tensor_core_reindex_A, tir::BufferIndexType::kRead);
  f_transform_reindex_block_layout(state->tensor_core_reindex_B, tir::BufferIndexType::kRead);
  state->sch->TransformBlockLayout(state->block_rv, index_map);
  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
                                   /*allow_padding=*/true);
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
    TensorCoreState state) const {
  // Do reindex and layout transformations.
  Optional<LoopRV> transformed_loop_rv =
      TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
  if (!transformed_loop_rv.defined()) {
    // The workload can't be tensorized.
    return {};
  }

  // Do blockize
  state->block_rv = state->sch->Blockize(transformed_loop_rv.value());

  // Add annotations for post processors.
  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
                       state->intrin_group.compute_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
                       state->intrin_group.init_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
  return {std::move(state)};
}

ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
    Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
    bool use_software_pipeline) {
  if (tile_binds.defined()) {
    for (const String& tile_bind : tile_binds.value()) {
      CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
    }
  }
  auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

  node->intrin_groups.reserve(intrin_groups.size());
  for (const auto& intrin_group_config : intrin_groups) {
    node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
  }
  node->use_software_pipeline = use_software_pipeline;
  return ScheduleRule(node);
}
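
// Illustrative only: constructing this rule directly in C++ might look like the sketch below,
// reusing the placeholder `wmma_f16_group` config from the example after
// TensorCoreIntrinGroup::FromConfig. The argument order follows the signature above; the
// concrete values are arbitrary examples.
//
//   ScheduleRule rule = ScheduleRule::MultiLevelTilingTensorCore(
//       /*intrin_groups=*/{wmma_f16_group}, /*structure=*/"SSSRRSRS",
//       /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
//       /*max_innermost_factor=*/Integer(4), /*vector_load_lens=*/NullOpt,
//       /*reuse_read=*/NullOpt, /*reuse_write=*/NullOpt, /*use_software_pipeline=*/false);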

TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore")
    .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore);

}  // namespace meta_schedule
}  // namespace tvm