1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace tir {
23
24/******** Utility functions ********/
25
26template <class K, class V>
27using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
28
29/*!
30 * \brief Add a dependency relation.
31 * \param src The source of the dependency
32 * \param dst The destination of the dependecy
33 * \param kind Type of the dependency
34 * \note This method is effectively NOP on self-loops
35 */
36void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
37 if (!src.same_as(dst)) {
38 Dependency dep(src, dst, kind);
39 self->src2deps[src].push_back(dep);
40 self->dst2deps[dst].push_back(dep);
41 }
42}
43
44/******** Constructors ********/
45
46StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
47 ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
48 n->stmt = stmt;
49 n->parent = parent;
50 n->seq_index = seq_index;
51 data_ = std::move(n);
52}
53
54StmtSRef StmtSRef::InlineMark() {
55 static StmtSRef result(nullptr, nullptr, -1);
56 return result;
57}
58
59StmtSRef StmtSRef::RootMark() {
60 static StmtSRef result(nullptr, nullptr, -1);
61 return result;
62}
63
64Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
65 ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
66 node->src = std::move(src);
67 node->dst = std::move(dst);
68 node->kind = kind;
69 data_ = std::move(node);
70}
71
72BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); }
73
74BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
75 ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>();
76 SMap<Buffer, Array<StmtSRef>> buffer_readers;
77 SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
78 for (const StmtSRef& child_block_sref : child_block_srefs) {
79 const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
80 // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
81 for (const BufferRegion& region : child_block->reads) {
82 buffer_readers[region->buffer].push_back(child_block_sref);
83 }
84 for (const BufferRegion& region : child_block->writes) {
85 buffer_writers[region->buffer].push_back(child_block_sref);
86 }
87 // Step 2. Update RAW dependency
88 for (const BufferRegion& region : child_block->reads) {
89 auto it = buffer_writers.find(region->buffer);
90 if (it != buffer_writers.end()) {
91 for (const StmtSRef& from : it->second) {
92 AddDependency(n.get(), from, child_block_sref, DepKind::kRAW);
93 }
94 }
95 }
96 // Step 3. Update WAW dependency
97 for (const BufferRegion& region : child_block->writes) {
98 auto it = buffer_writers.find(region->buffer);
99 if (it != buffer_writers.end()) {
100 for (const StmtSRef& from : it->second) {
101 AddDependency(n.get(), from, child_block_sref, DepKind::kWAW);
102 }
103 }
104 }
105 // Step 4. Update WAR dependency
106 for (const BufferRegion& region : child_block->writes) {
107 auto it = buffer_readers.find(region->buffer);
108 if (it != buffer_readers.end()) {
109 for (const StmtSRef& from : it->second) {
110 AddDependency(n.get(), from, child_block_sref, DepKind::kWAR);
111 }
112 }
113 }
114 }
115 data_ = std::move(n);
116}
117
118/******** Dependency ********/
119
120Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const {
121 auto iter = this->src2deps.find(block_sref);
122 if (iter != this->src2deps.end()) {
123 return iter->second;
124 } else {
125 return {};
126 }
127}
128
129Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const {
130 auto iter = this->dst2deps.find(block_sref);
131 if (iter != this->dst2deps.end()) {
132 return iter->second;
133 } else {
134 return {};
135 }
136}
137
138/******** FFI ********/
139
140TVM_REGISTER_NODE_TYPE(StmtSRefNode);
141TVM_REGISTER_NODE_TYPE(DependencyNode);
142TVM_REGISTER_NODE_TYPE(BlockScopeNode);
143
144TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
145 .set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
146 return GetRef<Optional<Stmt>>(sref->stmt);
147 });
148TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
149 .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
150 return GetRef<Optional<StmtSRef>>(sref->parent);
151 });
152TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") //
153 .set_body_typed(StmtSRef::RootMark);
154TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") //
155 .set_body_typed(StmtSRef::InlineMark);
156TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
157 .set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
158TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
159 .set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
160
161} // namespace tir
162} // namespace tvm
163