1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/tir/schedule/state.h
21 * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
22 */
23#ifndef TVM_TIR_SCHEDULE_STATE_H_
24#define TVM_TIR_SCHEDULE_STATE_H_
25
26#include <tvm/ir/module.h>
27#include <tvm/tir/function.h>
28#include <tvm/tir/schedule/block_scope.h>
29
30#include <unordered_map>
31#include <utility>
32
33namespace tvm {
34namespace tir {
35
36/*!
37 * \brief The information about a TensorIR block, it contains two categories of information
38 * 1) Info on the block scope rooted at a specific block, including dependency tracking,
39 * flags indicating if the scope is a stage pipeline, etc.
40 * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
41 * reads are completely covered by their producers, etc.
42 */
43struct BlockInfo {
44 /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
45 BlockScope scope{nullptr};
46 // The properties below are information about the current block realization under its parent scope
47 /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
48 bool affine_binding{false};
49 /*!
50 * \brief Property of a block, indicating each of the block's read regions is fully
51 * produced by its producers
52 */
53 bool region_cover{false};
54
55 BlockInfo() = default;
56
57 explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
58 : scope(std::move(scope)), //
59 affine_binding(affine_binding), //
60 region_cover(region_cover) {}
61};
62
/*!
 * \brief Bit flags selecting the extra verification performed by ScheduleStateNode.
 * Values are combined into the `debug_mask` bitmask of ScheduleStateNode.
 * \sa ScheduleStateNode
 */
enum ScheduleDebugMask : uint32_t {
  /*! \brief Check that the sref tree is correct */
  kVerifySRefTree = 1U << 0U,
  /*! \brief Check that the cached affine_binding, region_cover and stage_pipeline flags are correct */
  kVerifyCachedFlags = 1U << 1U,
};
73
74/*!
75 * \brief The state of scheduling, which exposes a `Replace` method as
76 * the primary interface for all the scheduling primitives to manipulate the TensorIR.
77 *
78 * The data structure contains the following information
79 * 1) The AST being scheduled (mod)
80 * 2) The sref tree of schedulable statements (indicated by the srefs)
81 * 3) The dependency information of each block scope (block_info)
82 * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
83 * 5) A debug flag, if set, extra checking is enabled (debug_mask)
84 */
85class ScheduleStateNode : public Object {
86 public:
87 /*! \brief The AST of the module being scheduled */
88 IRModule mod;
89 /*!
90 * \brief Mapping from a block sref to its correpsonding BlockInfo,
91 * tracking the dependency inside the block scope,
92 * and storing necessary information flags for scheduling
93 */
94 std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info;
95 /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
96 std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
97 /*!
98 * \brief Do extra correctness checking after the class creation
99 * and each time after calling the Replace method.
100 * \sa ScheduleDebugMask
101 */
102 int debug_mask;
103
104 void VisitAttrs(AttrVisitor* v) {
105 v->Visit("mod", &mod);
106 // `block_info` is not visited
107 // `stmt2ref` is not visited
108 v->Visit("debug_mask", &debug_mask);
109 }
110 /*!
111 * \brief Replace the part of the AST, as being pointed to by `src_sref`,
112 * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly.
113 * Replace will try to perform copy on write as much as possible when the ScheduleState holds
114 * the only copy to the IRModule and IR nodes.
115 *
116 * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`.
117 * 1) Block -> Block
118 * 2) Loop -> Loop
119 * 3) Loop -> BlockRealize
120 *
121 * \param src_sref The sref to the statement to be replaced
122 * \param tgt_stmt The statement to be replaced in
123 * \param block_sref_reuse Maps an old block (to be replaced in the subtree under
124 * `src_sref->stmt`) to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces
125 * reuse of srefs between them (rather than create new srefs) i.e. after being replaced, the sref
126 * that points to the old block will point to the new one
127 * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars.
128 */
129 TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
130 const Map<Block, Block>& block_sref_reuse);
131 /*!
132 * \brief Trigger the verification according to the `debug_mask` bitmask.
133 * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
134 * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
135 * `region_cover` and `stage_pipeline`
136 */
137 TVM_DLL void DebugVerify() const;
138
139 static constexpr const char* _type_key = "tir.ScheduleState";
140 TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object);
141
142 /******** Property of blocks ********/
143 /*! \brief Returns the BlockInfo correpsonding to the block sref */
144 TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
145 /*!
146 * \brief Recalculate the BlockInfo recursively under stmt.
147 * If stmt is a Block itself, we will not reset its affine binding flag unless it doesn't
148 * have block vars, since the affine flag depends on the outer scope of stmt.
149 */
150 TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
151 /*!
152 * \brief Get the BlockScope correpsonding to the sref of scope root block
153 * \param scope_root The block sref to be retrieved
154 * \return The corresponding BlockScope
155 */
156 BlockScope GetBlockScope(const StmtSRef& scope_root) const {
157 return GetBlockInfo(scope_root).scope;
158 }
159 /*!
160 * \brief Check a cached flag indicating if the specific block has quasi-affine bindings
161 * \param block_sref The block sref to be checked
162 * \return A boolean flag indicating if the block has quasi-affine bindings
163 */
164 bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
165 return GetBlockInfo(block_sref).affine_binding;
166 }
167 /*!
168 * \brief Check a cached flag indicating if each of the specific consumer block's read region
169 * is fully produced by its producers
170 * \param consumer_block_sref The specific consumer block
171 * \return A boolean flag indicating if the block has quasi-affine bindings
172 */
173 bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const {
174 return GetBlockInfo(consumer_block_sref).region_cover;
175 }
176 /*!
177 * \brief Check a cached flag indicating if a block scope is an equivalence of a stage pipeline
178 * \param scope_root The block sref to be retrieved
179 * \return The corresponding BlockScope
180 */
181 bool IsStagePipeline(const StmtSRef& scope_root) const {
182 return GetBlockScope(scope_root)->stage_pipeline;
183 }
184};
185
186/*!
187 * \brief Managed reference to ScheduleStateNode
188 * \sa ScheduleStateNode
189 */
190class ScheduleState : public ObjectRef {
191 public:
192 /*!
193 * \brief Construct a schedule state from an IRModule
194 * \param mod The IRModule to be scheduled
195 * \param debug_mask Do extra correctness checking after the class creation
196 * and each time after calling the Replace method.
197 */
198 TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
199
200 /*! \return The mutable pointer to the ScheduleStateNode */
201 ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
202
203 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode);
204};
205
206} // namespace tir
207} // namespace tvm
208
209#endif // TVM_TIR_SCHEDULE_STATE_H_
210