1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "../../tir/schedule/analysis.h" |
21 | #include "../../tir/schedule/transform.h" |
22 | #include "../utils.h" |
23 | #include "multi_level_tiling.h" |
24 | |
25 | namespace tvm { |
26 | namespace meta_schedule { |
27 | |
28 | /*! |
29 | * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate |
30 | * the tiled block for tensorization by postproc rewrite. |
31 | */ |
32 | Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block, |
33 | const std::string& intrin_name) { |
34 | Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); |
35 | if (!tiled_loop_rv) { |
36 | return NullOpt; |
37 | } |
38 | ICHECK(tiled_loop_rv.defined()); |
39 | tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); |
40 | sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); |
41 | return outer_block; |
42 | } |
43 | |
44 | /*! |
45 | * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. |
46 | */ |
47 | class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { |
48 | protected: |
49 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { |
50 | auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; |
51 | if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { |
52 | TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized." ; |
53 | return {sch}; |
54 | } |
55 | |
56 | auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv); |
57 | |
58 | if (res.empty()) { |
59 | TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized." ; |
60 | return {sch}; |
61 | } |
62 | TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name; |
63 | return res; |
64 | } |
65 | |
66 | // Inherited from ScheduleRuleNode |
67 | ScheduleRule Clone() const final { |
68 | ObjectPtr<MultiLevelTilingWithIntrinNode> n = |
69 | make_object<MultiLevelTilingWithIntrinNode>(*this); |
70 | return ScheduleRule(n); |
71 | } |
72 | |
73 | // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then |
74 | // tile the outerloops. |
75 | virtual std::vector<State> ApplySubRules(std::vector<State> states) { |
76 | states = SubRule(std::move(states), [&](State state) { |
77 | if (auto block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name)) { |
78 | state->block_rv = block_rv.value(); |
79 | return std::vector<State>(1, state); |
80 | } |
81 | return std::vector<State>(); |
82 | }); |
83 | return MultiLevelTilingNode::ApplySubRules(states); |
84 | } |
85 | |
86 | public: |
87 | /*! \brief The name of a tensor intrinsic. */ |
88 | String intrin_name; |
89 | |
90 | static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin" ; |
91 | TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); |
92 | }; |
93 | |
94 | ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( |
95 | String intrin_name, String structure, Optional<Array<String>> tile_binds, |
96 | Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens, |
97 | Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) { |
98 | ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) |
99 | << "Provided tensor intrinsic " << intrin_name << " is not registered." ; |
100 | auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>( |
101 | structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); |
102 | node->intrin_name = intrin_name; |
103 | return ScheduleRule(node); |
104 | } |
105 | |
106 | TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode); |
107 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin" ) |
108 | .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin); |
109 | |
110 | } // namespace meta_schedule |
111 | } // namespace tvm |
112 | |