1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/tir/schedule/state.h |
21 | * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling. |
22 | */ |
23 | #ifndef TVM_TIR_SCHEDULE_STATE_H_ |
24 | #define TVM_TIR_SCHEDULE_STATE_H_ |
25 | |
26 | #include <tvm/ir/module.h> |
27 | #include <tvm/tir/function.h> |
28 | #include <tvm/tir/schedule/block_scope.h> |
29 | |
30 | #include <unordered_map> |
31 | #include <utility> |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | |
36 | /*! |
37 | * \brief The information about a TensorIR block, it contains two categories of information |
38 | * 1) Info on the block scope rooted at a specific block, including dependency tracking, |
39 | * flags indicating if the scope is a stage pipeline, etc. |
40 | * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it |
41 | * reads are completely covered by their producers, etc. |
42 | */ |
43 | struct BlockInfo { |
44 | /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ |
45 | BlockScope scope{nullptr}; |
46 | // The properties below are information about the current block realization under its parent scope |
47 | /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ |
48 | bool affine_binding{false}; |
49 | /*! |
50 | * \brief Property of a block, indicating each of the block's read regions is fully |
51 | * produced by its producers |
52 | */ |
53 | bool region_cover{false}; |
54 | |
55 | BlockInfo() = default; |
56 | |
57 | explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false) |
58 | : scope(std::move(scope)), // |
59 | affine_binding(affine_binding), // |
60 | region_cover(region_cover) {} |
61 | }; |
62 | |
/*!
 * \brief Bit flags that select which debug checks ScheduleStateNode performs.
 *  Flags may be OR-ed together into ScheduleStateNode::debug_mask.
 * \sa ScheduleStateNode
 */
enum ScheduleDebugMask : uint32_t {
  /*! \brief Check that the sref tree is correct */
  kVerifySRefTree = 1U << 0U,
  /*! \brief Check that the cached affine_binding, region_cover and stage_pipeline flags are correct */
  kVerifyCachedFlags = 1U << 1U,
};
73 | |
74 | /*! |
75 | * \brief The state of scheduling, which exposes a `Replace` method as |
76 | * the primary interface for all the scheduling primitives to manipulate the TensorIR. |
77 | * |
78 | * The data structure contains the following information |
79 | * 1) The AST being scheduled (mod) |
80 | * 2) The sref tree of schedulable statements (indicated by the srefs) |
81 | * 3) The dependency information of each block scope (block_info) |
82 | * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref) |
83 | * 5) A debug flag, if set, extra checking is enabled (debug_mask) |
84 | */ |
85 | class ScheduleStateNode : public Object { |
86 | public: |
87 | /*! \brief The AST of the module being scheduled */ |
88 | IRModule mod; |
89 | /*! |
90 | * \brief Mapping from a block sref to its correpsonding BlockInfo, |
91 | * tracking the dependency inside the block scope, |
92 | * and storing necessary information flags for scheduling |
93 | */ |
94 | std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info; |
95 | /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ |
96 | std::unordered_map<const StmtNode*, StmtSRef> stmt2ref; |
97 | /*! |
98 | * \brief Do extra correctness checking after the class creation |
99 | * and each time after calling the Replace method. |
100 | * \sa ScheduleDebugMask |
101 | */ |
102 | int debug_mask; |
103 | |
104 | void VisitAttrs(AttrVisitor* v) { |
105 | v->Visit("mod" , &mod); |
106 | // `block_info` is not visited |
107 | // `stmt2ref` is not visited |
108 | v->Visit("debug_mask" , &debug_mask); |
109 | } |
110 | /*! |
111 | * \brief Replace the part of the AST, as being pointed to by `src_sref`, |
112 | * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly. |
113 | * Replace will try to perform copy on write as much as possible when the ScheduleState holds |
114 | * the only copy to the IRModule and IR nodes. |
115 | * |
116 | * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`. |
117 | * 1) Block -> Block |
118 | * 2) Loop -> Loop |
119 | * 3) Loop -> BlockRealize |
120 | * |
121 | * \param src_sref The sref to the statement to be replaced |
122 | * \param tgt_stmt The statement to be replaced in |
123 | * \param block_sref_reuse Maps an old block (to be replaced in the subtree under |
124 | * `src_sref->stmt`) to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces |
125 | * reuse of srefs between them (rather than create new srefs) i.e. after being replaced, the sref |
126 | * that points to the old block will point to the new one |
127 | * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. |
128 | */ |
129 | TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, |
130 | const Map<Block, Block>& block_sref_reuse); |
131 | /*! |
132 | * \brief Trigger the verification according to the `debug_mask` bitmask. |
133 | * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. |
134 | * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`, |
135 | * `region_cover` and `stage_pipeline` |
136 | */ |
137 | TVM_DLL void DebugVerify() const; |
138 | |
139 | static constexpr const char* _type_key = "tir.ScheduleState" ; |
140 | TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object); |
141 | |
142 | /******** Property of blocks ********/ |
143 | /*! \brief Returns the BlockInfo correpsonding to the block sref */ |
144 | TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const; |
145 | /*! |
146 | * \brief Recalculate the BlockInfo recursively under stmt. |
147 | * If stmt is a Block itself, we will not reset its affine binding flag unless it doesn't |
148 | * have block vars, since the affine flag depends on the outer scope of stmt. |
149 | */ |
150 | TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt); |
151 | /*! |
152 | * \brief Get the BlockScope correpsonding to the sref of scope root block |
153 | * \param scope_root The block sref to be retrieved |
154 | * \return The corresponding BlockScope |
155 | */ |
156 | BlockScope GetBlockScope(const StmtSRef& scope_root) const { |
157 | return GetBlockInfo(scope_root).scope; |
158 | } |
159 | /*! |
160 | * \brief Check a cached flag indicating if the specific block has quasi-affine bindings |
161 | * \param block_sref The block sref to be checked |
162 | * \return A boolean flag indicating if the block has quasi-affine bindings |
163 | */ |
164 | bool IsAffineBlockBinding(const StmtSRef& block_sref) const { |
165 | return GetBlockInfo(block_sref).affine_binding; |
166 | } |
167 | /*! |
168 | * \brief Check a cached flag indicating if each of the specific consumer block's read region |
169 | * is fully produced by its producers |
170 | * \param consumer_block_sref The specific consumer block |
171 | * \return A boolean flag indicating if the block has quasi-affine bindings |
172 | */ |
173 | bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const { |
174 | return GetBlockInfo(consumer_block_sref).region_cover; |
175 | } |
176 | /*! |
177 | * \brief Check a cached flag indicating if a block scope is an equivalence of a stage pipeline |
178 | * \param scope_root The block sref to be retrieved |
179 | * \return The corresponding BlockScope |
180 | */ |
181 | bool IsStagePipeline(const StmtSRef& scope_root) const { |
182 | return GetBlockScope(scope_root)->stage_pipeline; |
183 | } |
184 | }; |
185 | |
186 | /*! |
187 | * \brief Managed reference to ScheduleStateNode |
188 | * \sa ScheduleStateNode |
189 | */ |
190 | class ScheduleState : public ObjectRef { |
191 | public: |
192 | /*! |
193 | * \brief Construct a schedule state from an IRModule |
194 | * \param mod The IRModule to be scheduled |
195 | * \param debug_mask Do extra correctness checking after the class creation |
196 | * and each time after calling the Replace method. |
197 | */ |
198 | TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0); |
199 | |
200 | /*! \return The mutable pointer to the ScheduleStateNode */ |
201 | ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); } |
202 | |
203 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode); |
204 | }; |
205 | |
206 | } // namespace tir |
207 | } // namespace tvm |
208 | |
209 | #endif // TVM_TIR_SCHEDULE_STATE_H_ |
210 | |