1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "../../tir/schedule/analysis.h"
21#include "../../tir/schedule/transform.h"
22#include "../utils.h"
23#include "multi_level_tiling.h"
24
25namespace tvm {
26namespace meta_schedule {
27
28/*!
29 * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate
30 * the tiled block for tensorization by postproc rewrite.
31 */
32Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block,
33 const std::string& intrin_name) {
34 Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name);
35 if (!tiled_loop_rv) {
36 return NullOpt;
37 }
38 ICHECK(tiled_loop_rv.defined());
39 tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
40 sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
41 return outer_block;
42}
43
44/*!
45 * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
46 */
47class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
48 protected:
49 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
50 auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc;
51 if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) {
52 TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
53 return {sch};
54 }
55
56 auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv);
57
58 if (res.empty()) {
59 TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
60 return {sch};
61 }
62 TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name;
63 return res;
64 }
65
66 // Inherited from ScheduleRuleNode
67 ScheduleRule Clone() const final {
68 ObjectPtr<MultiLevelTilingWithIntrinNode> n =
69 make_object<MultiLevelTilingWithIntrinNode>(*this);
70 return ScheduleRule(n);
71 }
72
73 // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
74 // tile the outerloops.
75 virtual std::vector<State> ApplySubRules(std::vector<State> states) {
76 states = SubRule(std::move(states), [&](State state) {
77 if (auto block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name)) {
78 state->block_rv = block_rv.value();
79 return std::vector<State>(1, state);
80 }
81 return std::vector<State>();
82 });
83 return MultiLevelTilingNode::ApplySubRules(states);
84 }
85
86 public:
87 /*! \brief The name of a tensor intrinsic. */
88 String intrin_name;
89
90 static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
91 TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);
92};
93
94ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
95 String intrin_name, String structure, Optional<Array<String>> tile_binds,
96 Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
97 Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
98 ICHECK(tir::TensorIntrin::Get(intrin_name).defined())
99 << "Provided tensor intrinsic " << intrin_name << " is not registered.";
100 auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>(
101 structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
102 node->intrin_name = intrin_name;
103 return ScheduleRule(node);
104}
105
// Register MultiLevelTilingWithIntrinNode with TVM's object system, and register the factory
// above as the global function "meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin".
TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin")
    .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin);
109
110} // namespace meta_schedule
111} // namespace tvm
112