1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /******** Utility functions ********/ |
25 | |
26 | template <class K, class V> |
27 | using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>; |
28 | |
29 | /*! |
30 | * \brief Add a dependency relation. |
31 | * \param src The source of the dependency |
32 | * \param dst The destination of the dependecy |
33 | * \param kind Type of the dependency |
34 | * \note This method is effectively NOP on self-loops |
35 | */ |
36 | void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) { |
37 | if (!src.same_as(dst)) { |
38 | Dependency dep(src, dst, kind); |
39 | self->src2deps[src].push_back(dep); |
40 | self->dst2deps[dst].push_back(dep); |
41 | } |
42 | } |
43 | |
44 | /******** Constructors ********/ |
45 | |
46 | StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) { |
47 | ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>(); |
48 | n->stmt = stmt; |
49 | n->parent = parent; |
50 | n->seq_index = seq_index; |
51 | data_ = std::move(n); |
52 | } |
53 | |
54 | StmtSRef StmtSRef::InlineMark() { |
55 | static StmtSRef result(nullptr, nullptr, -1); |
56 | return result; |
57 | } |
58 | |
59 | StmtSRef StmtSRef::RootMark() { |
60 | static StmtSRef result(nullptr, nullptr, -1); |
61 | return result; |
62 | } |
63 | |
64 | Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { |
65 | ObjectPtr<DependencyNode> node = make_object<DependencyNode>(); |
66 | node->src = std::move(src); |
67 | node->dst = std::move(dst); |
68 | node->kind = kind; |
69 | data_ = std::move(node); |
70 | } |
71 | |
72 | BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); } |
73 | |
74 | BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) { |
75 | ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>(); |
76 | SMap<Buffer, Array<StmtSRef>> buffer_readers; |
77 | SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers; |
78 | for (const StmtSRef& child_block_sref : child_block_srefs) { |
79 | const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); |
80 | // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer |
81 | for (const BufferRegion& region : child_block->reads) { |
82 | buffer_readers[region->buffer].push_back(child_block_sref); |
83 | } |
84 | for (const BufferRegion& region : child_block->writes) { |
85 | buffer_writers[region->buffer].push_back(child_block_sref); |
86 | } |
87 | // Step 2. Update RAW dependency |
88 | for (const BufferRegion& region : child_block->reads) { |
89 | auto it = buffer_writers.find(region->buffer); |
90 | if (it != buffer_writers.end()) { |
91 | for (const StmtSRef& from : it->second) { |
92 | AddDependency(n.get(), from, child_block_sref, DepKind::kRAW); |
93 | } |
94 | } |
95 | } |
96 | // Step 3. Update WAW dependency |
97 | for (const BufferRegion& region : child_block->writes) { |
98 | auto it = buffer_writers.find(region->buffer); |
99 | if (it != buffer_writers.end()) { |
100 | for (const StmtSRef& from : it->second) { |
101 | AddDependency(n.get(), from, child_block_sref, DepKind::kWAW); |
102 | } |
103 | } |
104 | } |
105 | // Step 4. Update WAR dependency |
106 | for (const BufferRegion& region : child_block->writes) { |
107 | auto it = buffer_readers.find(region->buffer); |
108 | if (it != buffer_readers.end()) { |
109 | for (const StmtSRef& from : it->second) { |
110 | AddDependency(n.get(), from, child_block_sref, DepKind::kWAR); |
111 | } |
112 | } |
113 | } |
114 | } |
115 | data_ = std::move(n); |
116 | } |
117 | |
118 | /******** Dependency ********/ |
119 | |
120 | Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { |
121 | auto iter = this->src2deps.find(block_sref); |
122 | if (iter != this->src2deps.end()) { |
123 | return iter->second; |
124 | } else { |
125 | return {}; |
126 | } |
127 | } |
128 | |
129 | Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { |
130 | auto iter = this->dst2deps.find(block_sref); |
131 | if (iter != this->dst2deps.end()) { |
132 | return iter->second; |
133 | } else { |
134 | return {}; |
135 | } |
136 | } |
137 | |
138 | /******** FFI ********/ |
139 | |
140 | TVM_REGISTER_NODE_TYPE(StmtSRefNode); |
141 | TVM_REGISTER_NODE_TYPE(DependencyNode); |
142 | TVM_REGISTER_NODE_TYPE(BlockScopeNode); |
143 | |
144 | TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt" ) |
145 | .set_body_typed([](StmtSRef sref) -> Optional<Stmt> { |
146 | return GetRef<Optional<Stmt>>(sref->stmt); |
147 | }); |
148 | TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent" ) |
149 | .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> { |
150 | return GetRef<Optional<StmtSRef>>(sref->parent); |
151 | }); |
152 | TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark" ) // |
153 | .set_body_typed(StmtSRef::RootMark); |
154 | TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark" ) // |
155 | .set_body_typed(StmtSRef::InlineMark); |
156 | TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc" ) |
157 | .set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc); |
158 | TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst" ) |
159 | .set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst); |
160 | |
161 | } // namespace tir |
162 | } // namespace tvm |
163 | |