1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ |
20 | #define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ |
21 | |
22 | #include <memory> |
23 | #include <utility> |
24 | #include <vector> |
25 | |
26 | #include "./utils.h" |
27 | |
28 | namespace tvm { |
29 | namespace tir { |
30 | |
31 | class ConcreteScheduleNode : public ScheduleNode { |
32 | friend class Schedule; |
33 | friend class ScheduleCopier; |
34 | |
35 | public: |
36 | using TSymbolTable = Map<ObjectRef, ObjectRef>; |
37 | |
38 | protected: |
39 | /*! \brief The internal state of scheduling */ |
40 | ScheduleState state_; |
41 | /*! \brief The function to be worked on. */ |
42 | Optional<GlobalVar> func_working_on_; |
43 | /*! \brief The level of error rendering */ |
44 | ScheduleErrorRenderLevel error_render_level_; |
45 | /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ |
46 | TSymbolTable symbol_table_; |
47 | /*! \brief A persistent stateless arithmetic analyzer. */ |
48 | std::unique_ptr<arith::Analyzer> analyzer_; |
49 | /*! \brief The value of random state for sampling. */ |
50 | support::LinearCongruentialEngine::TRandState rand_state_; |
51 | |
52 | public: |
53 | void VisitAttrs(tvm::AttrVisitor* v) { |
54 | // `state_` is not visited |
55 | // `func_working_on_` is not visited |
56 | // `error_render_level_` is not visited |
57 | // `symbol_table_` is not visited |
58 | // `analyzer_` is not visited |
59 | // `rand_state_` is not visited |
60 | } |
61 | |
62 | virtual ~ConcreteScheduleNode() = default; |
63 | |
64 | public: |
65 | ScheduleState state() const final { return state_; } |
66 | Optional<Trace> trace() const override { return NullOpt; } |
67 | void WorkOn(const String& func_name) final; |
68 | Schedule Copy() override; |
69 | void Seed(support::LinearCongruentialEngine::TRandState seed) final; |
70 | support::LinearCongruentialEngine::TRandState ForkSeed() final; |
71 | |
72 | public: |
73 | /******** Lookup random variables ********/ |
74 | inline Block Get(const BlockRV& block_rv) const final; |
75 | inline For Get(const LoopRV& loop_rv) const final; |
76 | inline PrimExpr Get(const ExprRV& expr_rv) const final; |
77 | inline StmtSRef GetSRef(const BlockRV& block_rv) const final; |
78 | inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; |
79 | inline bool HasBlock(const BlockRV& block_rv) const final; |
80 | inline Array<StmtSRef> GetSRefs(const Array<BlockRV>& rvs) const; |
81 | inline Array<StmtSRef> GetSRefs(const Array<LoopRV>& rvs) const; |
82 | void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } |
83 | void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } |
84 | void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } |
85 | using ScheduleNode::GetSRef; |
86 | |
87 | public: |
88 | /******** Schedule: Sampling ********/ |
89 | ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, |
90 | Optional<Integer> decision = NullOpt) override; |
91 | Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, |
92 | Optional<Array<Integer>> decision = NullOpt) override; |
93 | LoopRV SampleComputeLocation(const BlockRV& block_rv, |
94 | Optional<Integer> decision = NullOpt) override; |
95 | /******** Schedule: Get blocks & loops ********/ |
96 | BlockRV GetBlock(const String& name, const Optional<String>& func_name) override; |
97 | Array<LoopRV> GetLoops(const BlockRV& block_rv) override; |
98 | Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override; |
99 | Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override; |
100 | Array<BlockRV> GetProducers(const BlockRV& block_rv) override; |
101 | Array<BlockRV> GetConsumers(const BlockRV& block_rv) override; |
102 | /******** Schedule: Transform loops ********/ |
103 | LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override; |
104 | Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors, |
105 | bool preserve_unit_iters) override; |
106 | void Reorder(const Array<LoopRV>& ordered_loop_rvs) override; |
107 | LoopRV AddUnitLoop(const BlockRV& block_rv) override; |
108 | LoopRV AddUnitLoop(const LoopRV& loop_rv) override; |
109 | /******** Schedule: Manipulate ForKind ********/ |
110 | void Parallel(const LoopRV& loop_rv) override; |
111 | void Vectorize(const LoopRV& loop_rv) override; |
112 | void Bind(const LoopRV& loop_rv, const String& thread_axis) override; |
113 | void Unroll(const LoopRV& loop_rv) override; |
114 | /******** Schedule: Insert cache stages ********/ |
115 | BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, |
116 | const Array<BlockRV> consumer_blocks = {}) override; |
117 | BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, |
118 | const Array<BlockRV> consumer_blocks = {}) override; |
119 | Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index, |
120 | const String& storage_scope) override; |
121 | Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope, |
122 | int cse_thresh) override; |
123 | BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, |
124 | BufferIndexType buffer_index_type) override; |
125 | /******** Schedule: Compute location ********/ |
126 | void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, |
127 | int index = -1) override; |
128 | void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, |
129 | int index = -1) override; |
130 | void ComputeInline(const BlockRV& block) override; |
131 | void ReverseComputeInline(const BlockRV& block) override; |
132 | /******** Schedule: Reduction ********/ |
133 | BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; |
134 | BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override; |
135 | void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) override; |
136 | /******** Schedule: Block annotation ********/ |
137 | void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, |
138 | int offset) override; |
139 | void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; |
140 | /******** Schedule: Blockize & Tensorize ********/ |
141 | BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; |
142 | void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; |
143 | void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override; |
144 | /******** Schedule: Annotation ********/ |
145 | void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; |
146 | void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; |
147 | void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override; |
148 | void Unannotate(const BlockRV& block_rv, const String& ann_key) override; |
149 | /******** Schedule: Layout transformation ********/ |
150 | void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, |
151 | const IndexMap& index_map, const Optional<IndexMap>& pad_value) override; |
152 | void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; |
153 | void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, |
154 | BufferIndexType buffer_index_type, |
155 | const Array<IntImm>& axis_separators) override; |
156 | /******** Schedule: Padding decomposition ********/ |
157 | BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; |
158 | /******** Schedule: Buffer transformation ********/ |
159 | void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override; |
160 | /******** Schedule: Misc ********/ |
161 | void EnterPostproc() override {} |
162 | |
163 | protected: |
164 | /******** Utility functions ********/ |
165 | /*! |
166 | * \brief Copy the schedule state, as well as the symbol table |
167 | * \param new_state The ScheduleState copied |
168 | * \param new_symbol_table The symbol table copied |
169 | */ |
170 | void Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const; |
171 | /*! |
172 | * \brief Add srefs as random variables into the symbol table |
173 | * \tparam T The type of the random variables |
174 | * \param srefs The srefs to be added to the symbol table |
175 | * \return The new random variables created |
176 | */ |
177 | template <class T> |
178 | inline Array<T> CreateRV(const Array<StmtSRef>& srefs); |
179 | /*! |
180 | * \brief Add an sref as a random variable into the symbol table |
181 | * \tparam T The type of the random variable |
182 | * \param sref The sref to be added to the symbol table |
183 | * \return The new random variable created |
184 | */ |
185 | template <class T> |
186 | inline T CreateRV(const StmtSRef& sref); |
187 | /*! |
188 | * \brief Add an integer as a random variable into the symbol table |
189 | * \param value The integer to be added to the symbol table |
190 | * \return The new random variable created |
191 | */ |
192 | inline ExprRV CreateRV(int64_t value); |
193 | /*! |
194 | * \brief Add a list of integers as random variables into the symbol table |
195 | * \param value The list of integers to be added to the symbol table |
196 | * \return The new random variables created |
197 | */ |
198 | inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value); |
199 | /*! \brief Remove a random variable from the symbol table */ |
200 | inline void RemoveFromSymbolTable(const ObjectRef& rv); |
201 | /*! |
202 | * \brief Check the annotation value is valid and look up the random variable. Raises an exception |
203 | * if the type of the annotation value is not allowed. |
204 | * \param The annotation value. |
205 | * \return The annotation value with random variables substituted with their values. |
206 | */ |
207 | ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val); |
208 | }; |
209 | |
210 | // implementations |
211 | |
212 | /******** Lookup random variables ********/ |
213 | |
214 | inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { |
215 | StmtSRef sref = this->GetSRef(block_rv); |
216 | const BlockNode* block = TVM_SREF_TO_BLOCK(sref); |
217 | return GetRef<Block>(block); |
218 | } |
219 | |
220 | inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { |
221 | StmtSRef sref = this->GetSRef(loop_rv); |
222 | const ForNode* loop = TVM_SREF_TO_FOR(sref); |
223 | return GetRef<For>(loop); |
224 | } |
225 | |
226 | inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { |
227 | PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional<PrimExpr> { |
228 | auto it = this->symbol_table_.find(var); |
229 | if (it == this->symbol_table_.end()) { |
230 | LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; |
231 | } |
232 | const ObjectRef& obj = (*it).second; |
233 | const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode); |
234 | return Integer(int_imm->value); |
235 | }); |
236 | return this->analyzer_->Simplify(transformed); |
237 | } |
238 | |
239 | inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { |
240 | auto it = this->symbol_table_.find(block_rv); |
241 | if (it == this->symbol_table_.end()) { |
242 | return false; |
243 | } |
244 | const ObjectRef& obj = (*it).second; |
245 | const auto* sref = obj.as<StmtSRefNode>(); |
246 | if (sref == nullptr || sref->stmt == nullptr) { |
247 | return false; |
248 | } |
249 | return true; |
250 | } |
251 | |
252 | inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { |
253 | auto it = this->symbol_table_.find(block_rv); |
254 | if (it == this->symbol_table_.end()) { |
255 | LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; |
256 | } |
257 | const ObjectRef& obj = (*it).second; |
258 | const auto* sref = obj.as<StmtSRefNode>(); |
259 | if (sref == nullptr) { |
260 | LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " |
261 | << (obj.defined() ? obj->GetTypeKey() : "None" ); |
262 | } |
263 | if (sref->stmt == nullptr) { |
264 | LOG(FATAL) << "ValueError: The block no longer exists in the IRModule" ; |
265 | } |
266 | return GetRef<StmtSRef>(sref); |
267 | } |
268 | |
269 | inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { |
270 | static StmtSRef inline_mark = StmtSRef::InlineMark(); |
271 | static StmtSRef root_mark = StmtSRef::RootMark(); |
272 | auto it = this->symbol_table_.find(loop_rv); |
273 | if (it == this->symbol_table_.end()) { |
274 | LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; |
275 | } |
276 | const ObjectRef& obj = (*it).second; |
277 | if (obj.same_as(inline_mark)) { |
278 | return inline_mark; |
279 | } |
280 | if (obj.same_as(root_mark)) { |
281 | return root_mark; |
282 | } |
283 | const auto* sref = obj.as<StmtSRefNode>(); |
284 | if (sref == nullptr) { |
285 | LOG(FATAL) << "ValueError: LoopRV's corresponding type is invalid: " |
286 | << (obj.defined() ? obj->GetTypeKey() : "None" ); |
287 | } |
288 | if (sref->stmt == nullptr) { |
289 | LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule" ; |
290 | } |
291 | return GetRef<StmtSRef>(sref); |
292 | } |
293 | |
294 | template <class T> |
295 | inline Array<StmtSRef> GetSRefsHelper(const ConcreteScheduleNode* sch, const Array<T>& rvs) { |
296 | Array<StmtSRef> result; |
297 | result.reserve(rvs.size()); |
298 | for (const T& rv : rvs) { |
299 | result.push_back(sch->GetSRef(rv)); |
300 | } |
301 | return result; |
302 | } |
303 | |
304 | inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<BlockRV>& rvs) const { |
305 | return GetSRefsHelper(this, rvs); |
306 | } |
307 | |
308 | inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<LoopRV>& rvs) const { |
309 | return GetSRefsHelper(this, rvs); |
310 | } |
311 | |
312 | /******** Adding/Removing elements in the symbol table ********/ |
313 | |
314 | template <class T> |
315 | inline Array<T> ConcreteScheduleNode::CreateRV(const Array<StmtSRef>& srefs) { |
316 | Array<T> result; |
317 | result.reserve(srefs.size()); |
318 | for (const StmtSRef& sref : srefs) { |
319 | T rv; |
320 | this->symbol_table_.Set(rv, sref); |
321 | result.push_back(rv); |
322 | } |
323 | return result; |
324 | } |
325 | |
326 | template <class T> |
327 | inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { |
328 | T rv; |
329 | this->symbol_table_.Set(rv, sref); |
330 | return std::move(rv); |
331 | } |
332 | |
333 | inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { |
334 | Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); |
335 | this->symbol_table_.Set(rv, Integer(static_cast<int32_t>(value))); |
336 | return std::move(rv); |
337 | } |
338 | |
339 | inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value) { |
340 | Array<ExprRV> results; |
341 | results.reserve(value.size()); |
342 | for (int64_t v : value) { |
343 | results.push_back(CreateRV(v)); |
344 | } |
345 | return results; |
346 | } |
347 | |
348 | inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { |
349 | auto it = this->symbol_table_.find(obj); |
350 | if (it != this->symbol_table_.end()) { |
351 | this->symbol_table_.erase(obj); |
352 | } else { |
353 | LOG(FATAL) << "IndexError: Cannot find the object in the symbol table: " << obj; |
354 | throw; |
355 | } |
356 | } |
357 | |
358 | } // namespace tir |
359 | } // namespace tvm |
360 | |
361 | #endif // TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ |
362 | |