1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
/*!
 * \file tvm/tir/schedule/block_scope.h
 * \brief Definition of the two pillar data structures for TensorIR scheduling: StmtSRef and
 * BlockScope.
 * \sa StmtSRefNode
 * \sa BlockScopeNode
 */
25 | #ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ |
26 | #define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ |
27 | |
28 | #include <tvm/tir/stmt.h> |
29 | |
30 | #include <unordered_map> |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | /*! |
36 | * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". |
37 | * |
38 | * Glossary |
39 | * - Block sref: A StmtSRef that points to a TensorIR block. |
40 | * - Loop sref: A StmtSRef that points to a TensorIR for loop. |
41 | * - Parent sref: The parent reference of an sref is the block or loop reference to the closest |
42 | schedulable statement. We define closest to be the nearest schedulable statement of an ancestor in |
43 | the AST. |
44 | * schedulable statement of its ancestors on the TensorIR AST. |
45 | * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref. |
46 | * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by |
47 | * the TensorIR AST. |
48 | */ |
49 | class StmtSRefNode : public Object { |
50 | public: |
51 | /*! |
52 | * \brief The block or `for` stmt the object refers to |
53 | * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write |
54 | * optimization on statements when possible. The strong reference is held in the ScheduleState. |
55 | */ |
56 | const StmtNode* stmt; |
57 | /*! \brief The parent sref. */ |
58 | StmtSRefNode* parent; |
59 | /*! |
60 | * \brief If the statement the sref points to is an element of a SeqStmt in the AST, |
61 | * then `seq_index` is set to its index; otherwise `seq_index` is -1 |
62 | */ |
63 | int64_t seq_index; |
64 | |
65 | void VisitAttrs(AttrVisitor* v) { |
66 | // `stmt` is not visited |
67 | // `parent` is not visited |
68 | v->Visit("seq_index" , &seq_index); |
69 | } |
70 | |
71 | static constexpr const char* _type_key = "tir.StmtSRef" ; |
72 | TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object); |
73 | |
74 | /*! \brief Reset the object inplace to the invalid state */ |
75 | void Reset() { |
76 | this->stmt = nullptr; |
77 | this->parent = nullptr; |
78 | this->seq_index = -1; |
79 | } |
80 | |
81 | /*! |
82 | * \brief Get the referenced statement with proper type checking. |
83 | * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt` |
84 | * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably |
85 | * tvm::tir::BlockNode or tvm::tir::ForNode |
86 | * \return nullptr if type check fails, otherwise the casted result for `this->stmt` |
87 | */ |
88 | template <typename StmtType> |
89 | const StmtType* StmtAs() const { |
90 | if (stmt != nullptr && stmt->IsInstance<StmtType>()) { |
91 | return static_cast<const StmtType*>(stmt); |
92 | } else { |
93 | return nullptr; |
94 | } |
95 | } |
96 | }; |
97 | |
98 | /*! |
99 | * \brief Managed reference to StmtSRefNode |
100 | * \sa StmtSRefNode |
101 | */ |
102 | class StmtSRef : public ObjectRef { |
103 | public: |
104 | /*! |
105 | * \brief The constructor |
106 | * \param stmt The corresponding stmt node, can be either block or for loop. |
107 | * \param parent The parent sref. |
108 | * \param seq_index The location in an array if the parent of the stmt contains multiple children. |
109 | * -1 if the parent does not contain multiple children. |
110 | */ |
111 | TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index); |
112 | |
113 | /*! \return The mutable pointer to the StmtSRefNode */ |
114 | StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); } |
115 | |
116 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode); |
117 | |
118 | public: |
119 | /*! |
120 | * \return A special StmtSRef, which doesn't point to any stmt in the AST, |
121 | * only serving as a "mark" to hint compute-at to do the work of compute-inline |
122 | * \note This is only as a faked loop sref for compute-at and reverse-compute-at, |
123 | * i.e. |
124 | * |
125 | * compute-at(block, loop_sref): |
126 | * compute-inline(block) if loop_sref.same_as(InlineMark()) |
127 | * no-op if loop_sref.same_as(RootMark()) |
128 | * compute-at-impl(block, loop_sref) otherwise |
129 | */ |
130 | TVM_DLL static StmtSRef InlineMark(); |
131 | /*! |
132 | * \return A special StmtSRef, which doesn't point to any stmt in the AST, |
133 | * only serving as a "mark" to hint compute-at to do nothing |
134 | * \note This is only as a faked loop sref for compute-at and reverse-compute-at, |
135 | * i.e. |
136 | * |
137 | * compute-at(block, loop_sref): |
138 | * compute-inline(block) if loop_sref.same_as(InlineMark()) |
139 | * no-op if loop_sref.same_as(RootMark()) |
140 | * compute-at-impl(block, loop_sref) otherwise |
141 | */ |
142 | TVM_DLL static StmtSRef RootMark(); |
143 | }; |
144 | |
/*!
 * \brief Type of dependency between two blocks. Four kinds are currently defined; each
 * enumerator carries an explicit value.
 */
enum class DepKind : int32_t {
  /*! \brief Read-after-write */
  kRAW = 0,
  /*! \brief Write-after-write */
  kWAW = 1,
  /*! \brief Write-after-read */
  kWAR = 2,
  /*! \brief Opaque dependency */
  kOpaque = 3,
};
158 | |
159 | /*! |
160 | * \brief A tuple (src, dst, kind) representing certain types of dependency. |
161 | * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is |
162 | * read-after-write, which means block B reads the result written by block A. |
163 | */ |
164 | class DependencyNode : public Object { |
165 | public: |
166 | /*! \brief The source of the dependency relation */ |
167 | StmtSRef src; |
168 | /*! \brief The destination of the dependency relation */ |
169 | StmtSRef dst; |
170 | /*! \brief The dependency kind */ |
171 | DepKind kind; |
172 | |
173 | void VisitAttrs(AttrVisitor* v) { |
174 | v->Visit("src" , &src); |
175 | v->Visit("dst" , &dst); |
176 | v->Visit("kind" , &kind); |
177 | } |
178 | |
179 | static constexpr const char* _type_key = "tir.Dependency" ; |
180 | TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object); |
181 | }; |
182 | |
183 | /*! |
184 | * \brief Managed reference to DependencyNode |
185 | * \sa DependencyNode |
186 | */ |
187 | class Dependency : public ObjectRef { |
188 | public: |
189 | /*! \brief Constructor */ |
190 | TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind); |
191 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode); |
192 | }; |
193 | |
194 | /*! |
195 | * \brief An object with 1-to-1 correspondence with each block reference in the sref tree. |
196 | * This data structure is used to track the producer-consumer dependencies between blocks. |
197 | * For example even leaf nodes have a scope node, even though they have no dependencies. |
198 | * |
199 | * Glossary: |
200 | * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref, |
201 | * whose components are: |
202 | * - scope root: a block sref |
203 | * - internal srefs: loop srefs |
204 | * - scope leaves: block srefs |
205 | * - Child block: The scope leaf blocks under the scope root or a specific internal sref |
206 | */ |
207 | class BlockScopeNode : public Object { |
208 | public: |
209 | /*! |
210 | * \brief Lookup table for the `src` of dependencies |
211 | * \note We intentionally didn't use tvm::Map as the data structure, because we need the values |
212 | * inside to be mutable so that they could be further maintained properly during transformations. |
213 | */ |
214 | std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps; |
215 | /*! \brief Lookup table for the `dst` of dependencies */ |
216 | std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps; |
217 | /*! \brief The mapping from the buffer to the blocks who write it */ |
218 | std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers; |
219 | /*! |
220 | * \brief This property indicates that the block scope (rooted at its corresponding block) is |
221 | * equivalent to of a stage pipeline. Under the following conditions: |
222 | * |
223 | * 1) The region cover property holds for every of its child blocks |
224 | * 2) No write-after-read dependency or opaque dependency, only read-after-write and |
225 | * write-after-write are allowed |
226 | * 3) All the statements in the scope are schedulable statements, i.e. Block and For |
227 | */ |
228 | bool stage_pipeline{false}; |
229 | |
230 | void VisitAttrs(AttrVisitor* v) {} |
231 | |
232 | static constexpr const char* _type_key = "tir.BlockScope" ; |
233 | TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object); |
234 | |
235 | public: |
236 | /******** Dependency ********/ |
237 | /*! |
238 | * \brief Get all dependencies whose `src` equals `src` |
239 | * \param src The queried block |
240 | * \return The dependencies |
241 | */ |
242 | TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const; |
243 | /*! |
244 | * \brief Get all dependencies whose `dst` equals `dst` |
245 | * \param dst The queried block |
246 | * \return The dependencies |
247 | */ |
248 | TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const; |
249 | }; |
250 | |
251 | /*! |
252 | * \brief Managed reference to BlockScopeNode |
253 | * \sa BlockScopeNode |
254 | */ |
255 | class BlockScope : public ObjectRef { |
256 | public: |
257 | /*! \brief The constructor creating an empty block scope with on dependency information */ |
258 | TVM_DLL BlockScope(); |
259 | /*! |
260 | * \brief Create the object with the specific leaf blocks, and compute the dependency information |
261 | * between the leaf blocks. |
262 | * \param child_block_srefs The srefs to the leaf blocks |
263 | * \note We assume the leaf blocks are given in pre-DFS order |
264 | */ |
265 | TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs); |
266 | |
267 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); |
268 | }; |
269 | |
270 | } // namespace tir |
271 | } // namespace tvm |
272 | |
273 | #endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_ |
274 | |