1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ |
20 | #define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ |
21 | |
22 | #include <tvm/meta_schedule/schedule_rule.h> |
23 | #include <tvm/tir/schedule/schedule.h> |
24 | |
25 | #include <unordered_map> |
26 | #include <utility> |
27 | #include <vector> |
28 | |
29 | #include "../../support/array.h" |
30 | |
31 | namespace tvm { |
32 | namespace meta_schedule { |
33 | |
/*!
 * \brief Configuration of data reuse type:
 * 0) kNoReuse: no reuse is allowed, so no cache_read/write is performed.
 * 1) kMayReuse: reuse is allowed, but the no-reuse alternative is explored as well.
 * 2) kMustReuse: reuse is allowed, and the no-reuse alternative is not explored.
 */
enum class ReuseType : int32_t {
  kNoReuse = 0,
  kMayReuse = 1,
  kMustReuse = 2,
};
45 | |
46 | /*! |
47 | * \brief Converts a string to ReuseType. |
48 | * \param str The string to be converted. |
49 | * \return The converted ReuseType. |
50 | */ |
51 | inline ReuseType Str2ReuseType(const String& str) { |
52 | if (str == "no" ) { |
53 | return ReuseType::kNoReuse; |
54 | } else if (str == "may" ) { |
55 | return ReuseType::kMayReuse; |
56 | } else if (str == "must" ) { |
57 | return ReuseType::kMustReuse; |
58 | } else { |
59 | LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; |
60 | throw; |
61 | } |
62 | } |
63 | |
64 | /*! \brief Configuration of data reuse patterns */ |
65 | struct ReuseConfig { |
66 | /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ |
67 | ReuseType req; |
68 | /*! \brief Which levels are caching stage inserted at */ |
69 | std::vector<int> levels; |
70 | /*! \brief The storage scope */ |
71 | String scope; |
72 | |
73 | /*! \brief Default constructor: no data reuse */ |
74 | ReuseConfig() : req(ReuseType::kNoReuse) {} |
75 | |
76 | /*! \brief Construct from a configuration dictionary */ |
77 | explicit ReuseConfig(const Map<String, ObjectRef>& config) |
78 | : req(Str2ReuseType(Downcast<String>(config.at("req" )))), |
79 | levels(support::AsVector<Integer, int>(Downcast<Array<Integer>>(config.at("levels" )))), |
80 | scope(Downcast<String>(config.at("scope" ))) { |
81 | ICHECK_EQ(config.size(), 3); |
82 | } |
83 | }; |
84 | |
// Forward declaration: StateNode::Copy() returns a State before the State class is defined.
class State;
87 | |
/*!
 * \brief The state of auto scheduling for the multi-level tiling rule.
 *
 * Holds the schedule built so far together with the block being tiled, the loop
 * tiles, and the cache blocks recorded per buffer index for read and write reuse.
 */
class StateNode : public Object {
 public:
  /*! \brief The schedule to date */
  tir::Schedule sch;
  /*! \brief The block to be tiled */
  tir::BlockRV block_rv;
  /*! \brief The loop tiles */
  Array<Array<tir::LoopRV>> tiles;
  /*! \brief The mapping from buffer index to read cache block. */
  std::unordered_map<int, tir::BlockRV> read_reuse;
  /*! \brief The mapping from buffer index to write cache block. */
  std::unordered_map<int, tir::BlockRV> write_reuse;

  /*!
   * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
   * produce multiple states should use this method to create new states.
   * \return The copied state.
   * \note Virtual, so subclasses of StateNode may override how the copy is made.
   */
  virtual State Copy() const;

  static constexpr const char* _type_key = "meta_schedule.State";
  TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object);
};
111 | |
/*! \brief Managed reference to StateNode */
class State : public ObjectRef {
 public:
  /*!
   * \brief Constructor (explicit; not a default constructor, since `sch` and `block_rv`
   *  are required).
   * \param sch The schedule.
   * \param block_rv The block.
   * \param tiles The loop tiles; defaults to an empty array.
   * \note The definition is not in this header; presumably the arguments populate the
   *  StateNode fields of the same names — confirm against the implementation.
   */
  explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
};
119 | |
120 | /*! |
121 | * \brief Helper to apply a sub-rule to a list of auto scheduling states |
122 | * \tparam FLambda The type of the sub-rule functor |
123 | * \param states The list of states to be applied |
124 | * \return The list of states after applying the sub-rule |
125 | */ |
126 | template <class FLambda> |
127 | std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) { |
128 | std::vector<State> results; |
129 | for (auto&& state : states) { |
130 | std::vector<State> next = sub_rule(std::move(state)); |
131 | results.insert(results.end(), // |
132 | std::make_move_iterator(next.begin()), // |
133 | std::make_move_iterator(next.end())); |
134 | } |
135 | return results; |
136 | } |
137 | |
/*!
 * \brief The mega rule: multi-level tiling with data reuse
 *
 * The rule is organized as three sub-rules (add write cache, tile the loop nest,
 * add read cache) declared below; `Apply` is the entry point.
 */
class MultiLevelTilingNode : public ScheduleRuleNode {
 public:
  virtual ~MultiLevelTilingNode() = default;

  // SubRule 1. add write cache
  std::vector<State> AddWriteReuse(State state) const;
  // SubRule 2. tile the loop nest
  std::vector<State> TileLoopNest(State state) const;
  // SubRule 3. add read cache
  std::vector<State> AddReadReuse(State state) const;

  // Inherited from ScheduleRuleNode; marked `final`, so subclasses cannot override it.
  // NOTE(review): previously described as "do nothing"; the definition is not in this
  // header — confirm against the implementation.
  void InitializeWithTuneContext(const TuneContext& context) final;

  // Entry of the mega rule; Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const override;

 protected:
  // Virtual hook that maps a list of states to a new list of states.
  // Presumably chains the sub-rules above — definition not visible here, confirm.
  virtual std::vector<State> ApplySubRules(std::vector<State> states);

  // Virtual hook returning the loops obtained for `loop` of `block`.
  // Presumably splits `loop` into `n_tiles` loops — definition not visible here, confirm.
  virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
                                       tir::LoopRV loop, int n_tiles) const;

  // Annotate a block to use cooperative fetching
  void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

 public:
  /*!
   * \brief The tiling structure. Recommended:
   * - 'SSRSRS' on CPU
   * - 'SSSRRSRS' on GPU
   */
  String structure;
  /*! \brief For each level of tiles, which thread axis it is bound to */
  Array<String> tile_binds;
  /*! \brief The maximum size of the innermost factor */
  int max_innermost_factor;
  /*! \brief The length of vector lane in vectorized cooperative fetching */
  std::vector<int> vector_load_lens;
  /*! \brief Data reuse configuration for reading */
  ReuseConfig reuse_read_;
  /*! \brief Data reuse configuration for writing */
  ReuseConfig reuse_write_;
  /*! \brief The indices of spatial tiles in `structure` */
  std::vector<int> s_indices_;
  /*! \brief The indices of reduction tiles in `structure` */
  std::vector<int> r_indices_;
  /*! \brief The size of the thread warp */
  int thread_warp_size_;
  /*! \brief The maximum number of threads per block */
  int max_threads_per_block_;
  /*! \brief The logging function */
  PackedFunc logger;
  /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
  Optional<PackedFunc> filter_fn_;

  /*! \brief Visit the reflected attributes; only the three fields below are exposed. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("structure", &structure);
    v->Visit("tile_binds", &tile_binds);
    v->Visit("max_innermost_factor", &max_innermost_factor);
    // `vector_load_lens` is not visited
    // `reuse_read_` is not visited
    // `reuse_write_` is not visited
    // `s_indices_` is not visited
    // `r_indices_` is not visited
    // `thread_warp_size_` is not visited
    // `max_threads_per_block_` is not visited
    // `logger` is not visited
    // `filter_fn_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling";
  TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
};
216 | |
217 | template <typename NodeType> |
218 | ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds, |
219 | Optional<Integer> max_innermost_factor, |
220 | Optional<Array<Integer>> vector_load_lens, |
221 | Optional<Map<String, ObjectRef>> reuse_read, |
222 | Optional<Map<String, ObjectRef>> reuse_write) { |
223 | ObjectPtr<NodeType> n = make_object<NodeType>(); |
224 | n->structure = structure; |
225 | n->tile_binds = tile_binds.value_or({}); |
226 | n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; |
227 | n->vector_load_lens = vector_load_lens.defined() |
228 | ? support::AsVector<Integer, int>(vector_load_lens.value()) |
229 | : std::vector<int>(); |
230 | n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); |
231 | n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); |
232 | for (int i = 0, len = structure.size(); i < len; ++i) { |
233 | char c = structure.data()[i]; |
234 | if (c == 'S') { |
235 | n->s_indices_.push_back(i); |
236 | } else if (c == 'R') { |
237 | n->r_indices_.push_back(i); |
238 | } else { |
239 | LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; |
240 | } |
241 | } |
242 | n->thread_warp_size_ = -1; |
243 | n->max_threads_per_block_ = -1; |
244 | return n; |
245 | } |
246 | |
247 | } // namespace meta_schedule |
248 | } // namespace tvm |
249 | |
250 | #endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ |
251 | |