1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
20#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
21
22#include <tvm/meta_schedule/schedule_rule.h>
23#include <tvm/tir/schedule/schedule.h>
24
25#include <unordered_map>
26#include <utility>
27#include <vector>
28
29#include "../../support/array.h"
30
31namespace tvm {
32namespace meta_schedule {
33
/*!
 * \brief Configuration of data reuse type:
 * 0) kNoReuse: reuse is not allowed, so no cache_read/cache_write stage is inserted.
 * 1) kMayReuse: reuse is allowed but optional; both the reuse and the no-reuse
 *    variants are explored.
 * 2) kMustReuse: reuse is mandatory; the no-reuse variant is not explored.
 */
enum class ReuseType : int32_t {
  kNoReuse = 0,
  kMayReuse = 1,
  kMustReuse = 2,
};
45
46/*!
47 * \brief Converts a string to ReuseType.
48 * \param str The string to be converted.
49 * \return The converted ReuseType.
50 */
51inline ReuseType Str2ReuseType(const String& str) {
52 if (str == "no") {
53 return ReuseType::kNoReuse;
54 } else if (str == "may") {
55 return ReuseType::kMayReuse;
56 } else if (str == "must") {
57 return ReuseType::kMustReuse;
58 } else {
59 LOG(FATAL) << "ValueError: Unknown ReuseType: " << str;
60 throw;
61 }
62}
63
64/*! \brief Configuration of data reuse patterns */
65struct ReuseConfig {
66 /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */
67 ReuseType req;
68 /*! \brief Which levels are caching stage inserted at */
69 std::vector<int> levels;
70 /*! \brief The storage scope */
71 String scope;
72
73 /*! \brief Default constructor: no data reuse */
74 ReuseConfig() : req(ReuseType::kNoReuse) {}
75
76 /*! \brief Construct from a configuration dictionary */
77 explicit ReuseConfig(const Map<String, ObjectRef>& config)
78 : req(Str2ReuseType(Downcast<String>(config.at("req")))),
79 levels(support::AsVector<Integer, int>(Downcast<Array<Integer>>(config.at("levels")))),
80 scope(Downcast<String>(config.at("scope"))) {
81 ICHECK_EQ(config.size(), 3);
82 }
83};
84
85// Forware declaration
86class State;
87
88/*! \brief The state of auto scheduling for the multi-level tiling rule */
89class StateNode : public Object {
90 public:
91 /*! \brief The schedule to date */
92 tir::Schedule sch;
93 /*! \brief The block to be tiled */
94 tir::BlockRV block_rv;
95 /*! \brief The loop tiles */
96 Array<Array<tir::LoopRV>> tiles;
97 /*! \brief The mapping from buffer index to read cache block. */
98 std::unordered_map<int, tir::BlockRV> read_reuse;
99 /*! \brief The mapping from buffer index to write cache block. */
100 std::unordered_map<int, tir::BlockRV> write_reuse;
101
102 /*!
103 * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
104 * produce multiple states should use this method to create new states.
105 */
106 virtual State Copy() const;
107
108 static constexpr const char* _type_key = "meta_schedule.State";
109 TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object);
110};
111
112/*! \brief Managed reference to StateNode */
113class State : public ObjectRef {
114 public:
115 /*! \brief Default constructor */
116 explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
117 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
118};
119
120/*!
121 * \brief Helper to apply a sub-rule to a list of auto scheduling states
122 * \tparam FLambda The type of the sub-rule functor
123 * \param states The list of states to be applied
124 * \return The list of states after applying the sub-rule
125 */
126template <class FLambda>
127std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) {
128 std::vector<State> results;
129 for (auto&& state : states) {
130 std::vector<State> next = sub_rule(std::move(state));
131 results.insert(results.end(), //
132 std::make_move_iterator(next.begin()), //
133 std::make_move_iterator(next.end()));
134 }
135 return results;
136}
137
138/*!
139 * \brief The mega rule: multi-level tiling with data reuse
140 */
141class MultiLevelTilingNode : public ScheduleRuleNode {
142 public:
143 virtual ~MultiLevelTilingNode() = default;
144
145 // SubRule 1. add write cache
146 std::vector<State> AddWriteReuse(State state) const;
147 // SubRule 2. tile the loop nest
148 std::vector<State> TileLoopNest(State state) const;
149 // SubRule 3. add read cache
150 std::vector<State> AddReadReuse(State state) const;
151
152 // Do nothing; Inherited from ScheduleRuleNode
153 void InitializeWithTuneContext(const TuneContext& context) final;
154
155 // Entry of the mega rule; Inherited from ScheduleRuleNode
156 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;
157
158 // Inherited from ScheduleRuleNode
159 ScheduleRule Clone() const override;
160
161 protected:
162 virtual std::vector<State> ApplySubRules(std::vector<State> states);
163
164 virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
165 tir::LoopRV loop, int n_tiles) const;
166
167 // Annotate a block to use cooperative fetching
168 void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
169
170 public:
171 /*!
172 * \brief The tiling structure. Recommended:
173 * - 'SSRSRS' on CPU
174 * - 'SSSRRSRS' on GPU
175 */
176 String structure;
177 /*! \brief For each level of tiles, which thread axis it is bound to */
178 Array<String> tile_binds;
179 /*! \brief The maximum size of the innermost factor */
180 int max_innermost_factor;
181 /*! \brief The length of vector lane in vectorized cooperative fetching */
182 std::vector<int> vector_load_lens;
183 /*! \brief Data reuse configuration for reading */
184 ReuseConfig reuse_read_;
185 /*! \brief Data reuse configuration for writing */
186 ReuseConfig reuse_write_;
187 /*! \brief The indices of spatial tiles in `structure` */
188 std::vector<int> s_indices_;
189 /*! \brief The indices of reduction tiles in `structure` */
190 std::vector<int> r_indices_;
191 /*! \brief The size of the thread warp */
192 int thread_warp_size_;
193 /*! \brief The maximum number of threads to be used size of a thread warp */
194 int max_threads_per_block_;
195 /*! \brief The logging function */
196 PackedFunc logger;
197 /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
198 Optional<PackedFunc> filter_fn_;
199
200 void VisitAttrs(tvm::AttrVisitor* v) {
201 v->Visit("structure", &structure);
202 v->Visit("tile_binds", &tile_binds);
203 v->Visit("max_innermost_factor", &max_innermost_factor);
204 // `vector_load_lens` is not visited
205 // `reuse_read_` is not visited
206 // `reuse_write_` is not visited
207 // `s_indices_` is not visited
208 // `r_indices_` is not visited
209 // `thread_warp_size_` is not visited
210 // `max_threads_per_block` is not visited
211 }
212
213 static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling";
214 TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
215};
216
217template <typename NodeType>
218ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
219 Optional<Integer> max_innermost_factor,
220 Optional<Array<Integer>> vector_load_lens,
221 Optional<Map<String, ObjectRef>> reuse_read,
222 Optional<Map<String, ObjectRef>> reuse_write) {
223 ObjectPtr<NodeType> n = make_object<NodeType>();
224 n->structure = structure;
225 n->tile_binds = tile_binds.value_or({});
226 n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
227 n->vector_load_lens = vector_load_lens.defined()
228 ? support::AsVector<Integer, int>(vector_load_lens.value())
229 : std::vector<int>();
230 n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
231 n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
232 for (int i = 0, len = structure.size(); i < len; ++i) {
233 char c = structure.data()[i];
234 if (c == 'S') {
235 n->s_indices_.push_back(i);
236 } else if (c == 'R') {
237 n->r_indices_.push_back(i);
238 } else {
239 LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
240 }
241 }
242 n->thread_warp_size_ = -1;
243 n->max_threads_per_block_ = -1;
244 return n;
245}
246
247} // namespace meta_schedule
248} // namespace tvm
249
250#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
251