1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_
20#define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_
21
22#include <memory>
23#include <utility>
24#include <vector>
25
26#include "./utils.h"
27
28namespace tvm {
29namespace tir {
30
/*!
 * \brief A schedule that applies every primitive directly to its ScheduleState.
 *
 * Random variables (BlockRV / LoopRV / ExprRV) are resolved through `symbol_table_`, which maps
 * each random variable to its concrete value (a StmtSRef, or an integer for ExprRVs).
 * `trace()` returns NullOpt, i.e. this class itself does not record a trace.
 */
class ConcreteScheduleNode : public ScheduleNode {
  friend class Schedule;
  friend class ScheduleCopier;

 public:
  /*! \brief Type of the table mapping random variables to their concrete values. */
  using TSymbolTable = Map<ObjectRef, ObjectRef>;

 protected:
  /*! \brief The internal state of scheduling */
  ScheduleState state_;
  /*! \brief The function to be worked on. */
  Optional<GlobalVar> func_working_on_;
  /*! \brief The level of error rendering */
  ScheduleErrorRenderLevel error_render_level_;
  /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */
  TSymbolTable symbol_table_;
  /*! \brief A persistent stateless arithmetic analyzer. */
  std::unique_ptr<arith::Analyzer> analyzer_;
  /*! \brief The value of random state for sampling. */
  support::LinearCongruentialEngine::TRandState rand_state_;

 public:
  void VisitAttrs(tvm::AttrVisitor* v) {
    // `state_` is not visited
    // `func_working_on_` is not visited
    // `error_render_level_` is not visited
    // `symbol_table_` is not visited
    // `analyzer_` is not visited
    // `rand_state_` is not visited
  }

  virtual ~ConcreteScheduleNode() = default;

 public:
  ScheduleState state() const final { return state_; }
  // A concrete schedule keeps no trace.
  Optional<Trace> trace() const override { return NullOpt; }
  void WorkOn(const String& func_name) final;
  Schedule Copy() override;
  void Seed(support::LinearCongruentialEngine::TRandState seed) final;
  support::LinearCongruentialEngine::TRandState ForkSeed() final;

 public:
  /******** Lookup random variables ********/
  inline Block Get(const BlockRV& block_rv) const final;
  inline For Get(const LoopRV& loop_rv) const final;
  inline PrimExpr Get(const ExprRV& expr_rv) const final;
  inline StmtSRef GetSRef(const BlockRV& block_rv) const final;
  inline StmtSRef GetSRef(const LoopRV& loop_rv) const final;
  inline bool HasBlock(const BlockRV& block_rv) const final;
  /*! \brief Resolve each BlockRV in `rvs` to its StmtSRef, preserving order. */
  inline Array<StmtSRef> GetSRefs(const Array<BlockRV>& rvs) const;
  /*! \brief Resolve each LoopRV in `rvs` to its StmtSRef, preserving order. */
  inline Array<StmtSRef> GetSRefs(const Array<LoopRV>& rvs) const;
  // Each RemoveRV overload drops the random variable's entry from `symbol_table_`.
  void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); }
  void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); }
  void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); }
  // Keep the base-class GetSRef overloads visible alongside the ones declared above.
  using ScheduleNode::GetSRef;

 public:
  /******** Schedule: Sampling ********/
  ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
                           Optional<Integer> decision = NullOpt) override;
  Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
                                  Optional<Array<Integer>> decision = NullOpt) override;
  LoopRV SampleComputeLocation(const BlockRV& block_rv,
                               Optional<Integer> decision = NullOpt) override;
  /******** Schedule: Get blocks & loops ********/
  BlockRV GetBlock(const String& name, const Optional<String>& func_name) override;
  Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
  Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
  Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
  Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
  Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
  /******** Schedule: Transform loops ********/
  LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
  Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
                      bool preserve_unit_iters) override;
  void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
  LoopRV AddUnitLoop(const BlockRV& block_rv) override;
  LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
  /******** Schedule: Manipulate ForKind ********/
  void Parallel(const LoopRV& loop_rv) override;
  void Vectorize(const LoopRV& loop_rv) override;
  void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
  void Unroll(const LoopRV& loop_rv) override;
  /******** Schedule: Insert cache stages ********/
  BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
                    const Array<BlockRV> consumer_blocks = {}) override;
  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                     const Array<BlockRV> consumer_blocks = {}) override;
  Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                              const String& storage_scope) override;
  Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
                            int cse_thresh) override;
  BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                  BufferIndexType buffer_index_type) override;
  /******** Schedule: Compute location ********/
  void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
                 int index = -1) override;
  void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
                        int index = -1) override;
  void ComputeInline(const BlockRV& block) override;
  void ReverseComputeInline(const BlockRV& block) override;
  /******** Schedule: Reduction ********/
  BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
  BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
  void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) override;
  /******** Schedule: Block annotation ********/
  void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                    int offset) override;
  void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
  /******** Schedule: Blockize & Tensorize ********/
  BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
  void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
  void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override;
  /******** Schedule: Annotation ********/
  void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
  void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
  void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override;
  void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
  /******** Schedule: Layout transformation ********/
  void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
                       const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
  void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
  void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                        BufferIndexType buffer_index_type,
                        const Array<IntImm>& axis_separators) override;
  /******** Schedule: Padding decomposition ********/
  BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
  /******** Schedule: Buffer transformation ********/
  void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override;
  /******** Schedule: Misc ********/
  // No-op in the concrete schedule.
  void EnterPostproc() override {}

 protected:
  /******** Utility functions ********/
  /*!
   * \brief Copy the schedule state, as well as the symbol table
   * \param new_state The ScheduleState copied
   * \param new_symbol_table The symbol table copied
   */
  void Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const;
  /*!
   * \brief Add srefs as random variables into the symbol table
   * \tparam T The type of the random variables
   * \param srefs The srefs to be added to the symbol table
   * \return The new random variables created, in the same order as `srefs`
   */
  template <class T>
  inline Array<T> CreateRV(const Array<StmtSRef>& srefs);
  /*!
   * \brief Add an sref as a random variable into the symbol table
   * \tparam T The type of the random variable
   * \param sref The sref to be added to the symbol table
   * \return The new random variable created
   */
  template <class T>
  inline T CreateRV(const StmtSRef& sref);
  /*!
   * \brief Add an integer as a random variable into the symbol table
   * \param value The integer to be added to the symbol table
   * \return The new random variable created
   * \note The current definition creates an int32 Var and stores `value` narrowed to int32.
   */
  inline ExprRV CreateRV(int64_t value);
  /*!
   * \brief Add a list of integers as random variables into the symbol table
   * \param value The list of integers to be added to the symbol table
   * \return The new random variables created, in the same order as `value`
   */
  inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
  /*!
   * \brief Remove a random variable from the symbol table
   * \param rv The random variable to remove; reporting an IndexError via LOG(FATAL) if absent
   */
  inline void RemoveFromSymbolTable(const ObjectRef& rv);
  /*!
   * \brief Check the annotation value is valid and look up the random variable. Raises an exception
   * if the type of the annotation value is not allowed.
   * \param ann_val The annotation value.
   * \return The annotation value with random variables substituted with their values.
   */
  ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val);
};
209
210// implementations
211
212/******** Lookup random variables ********/
213
214inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const {
215 StmtSRef sref = this->GetSRef(block_rv);
216 const BlockNode* block = TVM_SREF_TO_BLOCK(sref);
217 return GetRef<Block>(block);
218}
219
220inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
221 StmtSRef sref = this->GetSRef(loop_rv);
222 const ForNode* loop = TVM_SREF_TO_FOR(sref);
223 return GetRef<For>(loop);
224}
225
226inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
227 PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional<PrimExpr> {
228 auto it = this->symbol_table_.find(var);
229 if (it == this->symbol_table_.end()) {
230 LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var;
231 }
232 const ObjectRef& obj = (*it).second;
233 const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode);
234 return Integer(int_imm->value);
235 });
236 return this->analyzer_->Simplify(transformed);
237}
238
239inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const {
240 auto it = this->symbol_table_.find(block_rv);
241 if (it == this->symbol_table_.end()) {
242 return false;
243 }
244 const ObjectRef& obj = (*it).second;
245 const auto* sref = obj.as<StmtSRefNode>();
246 if (sref == nullptr || sref->stmt == nullptr) {
247 return false;
248 }
249 return true;
250}
251
252inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
253 auto it = this->symbol_table_.find(block_rv);
254 if (it == this->symbol_table_.end()) {
255 LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv;
256 }
257 const ObjectRef& obj = (*it).second;
258 const auto* sref = obj.as<StmtSRefNode>();
259 if (sref == nullptr) {
260 LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: "
261 << (obj.defined() ? obj->GetTypeKey() : "None");
262 }
263 if (sref->stmt == nullptr) {
264 LOG(FATAL) << "ValueError: The block no longer exists in the IRModule";
265 }
266 return GetRef<StmtSRef>(sref);
267}
268
269inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const {
270 static StmtSRef inline_mark = StmtSRef::InlineMark();
271 static StmtSRef root_mark = StmtSRef::RootMark();
272 auto it = this->symbol_table_.find(loop_rv);
273 if (it == this->symbol_table_.end()) {
274 LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv;
275 }
276 const ObjectRef& obj = (*it).second;
277 if (obj.same_as(inline_mark)) {
278 return inline_mark;
279 }
280 if (obj.same_as(root_mark)) {
281 return root_mark;
282 }
283 const auto* sref = obj.as<StmtSRefNode>();
284 if (sref == nullptr) {
285 LOG(FATAL) << "ValueError: LoopRV's corresponding type is invalid: "
286 << (obj.defined() ? obj->GetTypeKey() : "None");
287 }
288 if (sref->stmt == nullptr) {
289 LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule";
290 }
291 return GetRef<StmtSRef>(sref);
292}
293
294template <class T>
295inline Array<StmtSRef> GetSRefsHelper(const ConcreteScheduleNode* sch, const Array<T>& rvs) {
296 Array<StmtSRef> result;
297 result.reserve(rvs.size());
298 for (const T& rv : rvs) {
299 result.push_back(sch->GetSRef(rv));
300 }
301 return result;
302}
303
304inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<BlockRV>& rvs) const {
305 return GetSRefsHelper(this, rvs);
306}
307
308inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<LoopRV>& rvs) const {
309 return GetSRefsHelper(this, rvs);
310}
311
312/******** Adding/Removing elements in the symbol table ********/
313
314template <class T>
315inline Array<T> ConcreteScheduleNode::CreateRV(const Array<StmtSRef>& srefs) {
316 Array<T> result;
317 result.reserve(srefs.size());
318 for (const StmtSRef& sref : srefs) {
319 T rv;
320 this->symbol_table_.Set(rv, sref);
321 result.push_back(rv);
322 }
323 return result;
324}
325
326template <class T>
327inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) {
328 T rv;
329 this->symbol_table_.Set(rv, sref);
330 return std::move(rv);
331}
332
333inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) {
334 Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32));
335 this->symbol_table_.Set(rv, Integer(static_cast<int32_t>(value)));
336 return std::move(rv);
337}
338
339inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value) {
340 Array<ExprRV> results;
341 results.reserve(value.size());
342 for (int64_t v : value) {
343 results.push_back(CreateRV(v));
344 }
345 return results;
346}
347
348inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) {
349 auto it = this->symbol_table_.find(obj);
350 if (it != this->symbol_table_.end()) {
351 this->symbol_table_.erase(obj);
352 } else {
353 LOG(FATAL) << "IndexError: Cannot find the object in the symbol table: " << obj;
354 throw;
355 }
356}
357
358} // namespace tir
359} // namespace tvm
360
361#endif // TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_
362