1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
/*!
 * \file tvm/tir/schedule/block_scope.h
 * \brief Definition of two pillar data structures for TensorIR scheduling: StmtSRef, BlockScope.
 * \sa StmtSRefNode
 * \sa BlockScopeNode
 */
25#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
26#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
27
28#include <tvm/tir/stmt.h>
29
30#include <unordered_map>
31
32namespace tvm {
33namespace tir {
34
/*!
 * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
 *
 * Glossary
 * - Block sref: A StmtSRef that points to a TensorIR block.
 * - Loop sref: A StmtSRef that points to a TensorIR for loop.
 * - Parent sref: The parent sref of an sref is the block or loop sref that points to the closest
 * schedulable statement among its ancestors on the TensorIR AST, where "closest" means the
 * nearest ancestor in the AST that is itself a schedulable statement.
 * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
 * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
 * the TensorIR AST.
 */
49class StmtSRefNode : public Object {
50 public:
51 /*!
52 * \brief The block or `for` stmt the object refers to
53 * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
54 * optimization on statements when possible. The strong reference is held in the ScheduleState.
55 */
56 const StmtNode* stmt;
57 /*! \brief The parent sref. */
58 StmtSRefNode* parent;
59 /*!
60 * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
61 * then `seq_index` is set to its index; otherwise `seq_index` is -1
62 */
63 int64_t seq_index;
64
65 void VisitAttrs(AttrVisitor* v) {
66 // `stmt` is not visited
67 // `parent` is not visited
68 v->Visit("seq_index", &seq_index);
69 }
70
71 static constexpr const char* _type_key = "tir.StmtSRef";
72 TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
73
74 /*! \brief Reset the object inplace to the invalid state */
75 void Reset() {
76 this->stmt = nullptr;
77 this->parent = nullptr;
78 this->seq_index = -1;
79 }
80
81 /*!
82 * \brief Get the referenced statement with proper type checking.
83 * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
84 * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
85 * tvm::tir::BlockNode or tvm::tir::ForNode
86 * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
87 */
88 template <typename StmtType>
89 const StmtType* StmtAs() const {
90 if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
91 return static_cast<const StmtType*>(stmt);
92 } else {
93 return nullptr;
94 }
95 }
96};
97
98/*!
99 * \brief Managed reference to StmtSRefNode
100 * \sa StmtSRefNode
101 */
102class StmtSRef : public ObjectRef {
103 public:
104 /*!
105 * \brief The constructor
106 * \param stmt The corresponding stmt node, can be either block or for loop.
107 * \param parent The parent sref.
108 * \param seq_index The location in an array if the parent of the stmt contains multiple children.
109 * -1 if the parent does not contain multiple children.
110 */
111 TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
112
113 /*! \return The mutable pointer to the StmtSRefNode */
114 StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
115
116 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
117
118 public:
119 /*!
120 * \return A special StmtSRef, which doesn't point to any stmt in the AST,
121 * only serving as a "mark" to hint compute-at to do the work of compute-inline
122 * \note This is only as a faked loop sref for compute-at and reverse-compute-at,
123 * i.e.
124 *
125 * compute-at(block, loop_sref):
126 * compute-inline(block) if loop_sref.same_as(InlineMark())
127 * no-op if loop_sref.same_as(RootMark())
128 * compute-at-impl(block, loop_sref) otherwise
129 */
130 TVM_DLL static StmtSRef InlineMark();
131 /*!
132 * \return A special StmtSRef, which doesn't point to any stmt in the AST,
133 * only serving as a "mark" to hint compute-at to do nothing
134 * \note This is only as a faked loop sref for compute-at and reverse-compute-at,
135 * i.e.
136 *
137 * compute-at(block, loop_sref):
138 * compute-inline(block) if loop_sref.same_as(InlineMark())
139 * no-op if loop_sref.same_as(RootMark())
140 * compute-at-impl(block, loop_sref) otherwise
141 */
142 TVM_DLL static StmtSRef RootMark();
143};
144
145/*!
146 * \brief Type of dependency. Right now we have 4 types of dependencies
147 * 1) Read-after-write (kRAW)
148 * 2) Write-after-write (kWAW)
149 * 3) Write-after-read (kWAR)
150 * 4) Opaque dependency (kOpaque)
151 */
enum class DepKind : int32_t {
  /*! \brief Read-after-write dependency */
  kRAW = 0,
  /*! \brief Write-after-write dependency */
  kWAW = 1,
  /*! \brief Write-after-read dependency */
  kWAR = 2,
  /*! \brief Opaque dependency */
  kOpaque = 3,
};
158
159/*!
160 * \brief A tuple (src, dst, kind) representing certain types of dependency.
161 * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
162 * read-after-write, which means block B reads the result written by block A.
163 */
164class DependencyNode : public Object {
165 public:
166 /*! \brief The source of the dependency relation */
167 StmtSRef src;
168 /*! \brief The destination of the dependency relation */
169 StmtSRef dst;
170 /*! \brief The dependency kind */
171 DepKind kind;
172
173 void VisitAttrs(AttrVisitor* v) {
174 v->Visit("src", &src);
175 v->Visit("dst", &dst);
176 v->Visit("kind", &kind);
177 }
178
179 static constexpr const char* _type_key = "tir.Dependency";
180 TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
181};
182
183/*!
184 * \brief Managed reference to DependencyNode
185 * \sa DependencyNode
186 */
187class Dependency : public ObjectRef {
188 public:
189 /*! \brief Constructor */
190 TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
191 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
192};
193
/*!
 * \brief An object with 1-to-1 correspondence with each block reference in the sref tree.
 * This data structure is used to track the producer-consumer dependencies between blocks.
 * For example, even leaf nodes have a scope node, although they have no dependencies.
 *
 * Glossary:
 * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
 * whose components are:
 * - scope root: a block sref
 * - internal srefs: loop srefs
 * - scope leaves: block srefs
 * - Child block: The scope leaf blocks under the scope root or a specific internal sref
 */
207class BlockScopeNode : public Object {
208 public:
209 /*!
210 * \brief Lookup table for the `src` of dependencies
211 * \note We intentionally didn't use tvm::Map as the data structure, because we need the values
212 * inside to be mutable so that they could be further maintained properly during transformations.
213 */
214 std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
215 /*! \brief Lookup table for the `dst` of dependencies */
216 std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
217 /*! \brief The mapping from the buffer to the blocks who write it */
218 std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
219 /*!
220 * \brief This property indicates that the block scope (rooted at its corresponding block) is
221 * equivalent to of a stage pipeline. Under the following conditions:
222 *
223 * 1) The region cover property holds for every of its child blocks
224 * 2) No write-after-read dependency or opaque dependency, only read-after-write and
225 * write-after-write are allowed
226 * 3) All the statements in the scope are schedulable statements, i.e. Block and For
227 */
228 bool stage_pipeline{false};
229
230 void VisitAttrs(AttrVisitor* v) {}
231
232 static constexpr const char* _type_key = "tir.BlockScope";
233 TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
234
235 public:
236 /******** Dependency ********/
237 /*!
238 * \brief Get all dependencies whose `src` equals `src`
239 * \param src The queried block
240 * \return The dependencies
241 */
242 TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
243 /*!
244 * \brief Get all dependencies whose `dst` equals `dst`
245 * \param dst The queried block
246 * \return The dependencies
247 */
248 TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
249};
250
251/*!
252 * \brief Managed reference to BlockScopeNode
253 * \sa BlockScopeNode
254 */
255class BlockScope : public ObjectRef {
256 public:
257 /*! \brief The constructor creating an empty block scope with on dependency information */
258 TVM_DLL BlockScope();
259 /*!
260 * \brief Create the object with the specific leaf blocks, and compute the dependency information
261 * between the leaf blocks.
262 * \param child_block_srefs The srefs to the leaf blocks
263 * \note We assume the leaf blocks are given in pre-DFS order
264 */
265 TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);
266
267 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
268};
269
270} // namespace tir
271} // namespace tvm
272
273#endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
274