1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../analysis.h"
20#include "../utils.h"
21
22namespace tvm {
23namespace tir {
24
25Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) {
26 struct Finder : public StmtVisitor {
27 explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {}
28
29 void VisitStmt_(const BlockNode* block) override {
30 if (block->name_hint == name_) {
31 auto it = self_->stmt2ref.find(block);
32 ICHECK(it != self_->stmt2ref.end());
33 results_.push_back(it->second);
34 }
35 StmtVisitor::VisitStmt_(block);
36 }
37
38 const ScheduleState& self_;
39 const String& name_;
40 Array<StmtSRef> results_;
41 };
42
43 BaseFunc func = self->mod->Lookup(gv);
44 const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode);
45 Finder finder(self, name);
46 finder(prim_func->body);
47 return std::move(finder.results_);
48}
49
50Array<StmtSRef> GetLoops(const StmtSRef& block_sref) {
51 std::vector<StmtSRef> result;
52 for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance<ForNode>();
53 parent = parent->parent) {
54 result.push_back(GetRef<StmtSRef>(parent));
55 }
56 return {result.rbegin(), result.rend()};
57}
58
59Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) {
60 struct Collector : public StmtVisitor {
61 private:
62 void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); }
63
64 public:
65 explicit Collector(const ScheduleState& self) : self(self) {}
66
67 const ScheduleState& self;
68 Array<StmtSRef> result;
69 };
70 Collector collector(self);
71 if (parent_sref->stmt->IsInstance<ForNode>()) {
72 const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
73 collector(loop->body);
74 } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
75 const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
76 collector(block->body);
77 }
78 return std::move(collector.result);
79}
80
81Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
82 StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
83 return tir::GetProducers(block_sref, self->GetBlockScope(scope_root));
84}
85
86Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
87 StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
88 return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root));
89}
90
91/******** InstructionKind Registration ********/
92
93struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
94 static constexpr const char* kName = "GetBlock";
95 static constexpr bool kIsPure = true;
96
97 private:
98 static constexpr size_t kNumInputs = 0;
99 static constexpr size_t kNumAttrs = 2;
100 static constexpr size_t kNumDecisions = 0;
101
102 static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) {
103 return sch->GetBlock(name, func_name);
104 }
105
106 static String UnpackedAsPython(Array<String> outputs, String name, String func_name) {
107 PythonAPICall py("get_block");
108 py.Input("name", name);
109 py.Input("func_name", func_name);
110 py.SingleOutput(outputs);
111 return py.Str();
112 }
113
114 template <typename>
115 friend struct ::tvm::tir::UnpackedInstTraits;
116};
117
118struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> {
119 static constexpr const char* kName = "GetLoops";
120 static constexpr bool kIsPure = true;
121
122 private:
123 static constexpr size_t kNumInputs = 1;
124 static constexpr size_t kNumAttrs = 0;
125 static constexpr size_t kNumDecisions = 0;
126
127 static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
128 return sch->GetLoops(block_rv);
129 }
130
131 static String UnpackedAsPython(Array<String> outputs, String block_rv) {
132 PythonAPICall py("get_loops");
133 py.Input("block", block_rv);
134 py.OutputList(outputs);
135 return py.Str();
136 }
137
138 template <typename>
139 friend struct ::tvm::tir::UnpackedInstTraits;
140};
141
142struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
143 static constexpr const char* kName = "GetChildBlocks";
144 static constexpr bool kIsPure = true;
145
146 private:
147 static constexpr size_t kNumInputs = 1;
148 static constexpr size_t kNumAttrs = 0;
149 static constexpr size_t kNumDecisions = 0;
150
151 static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) {
152 if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
153 return sch->GetChildBlocks(GetRef<BlockRV>(block));
154 }
155 if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
156 return sch->GetChildBlocks(GetRef<LoopRV>(loop));
157 }
158 LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
159 throw;
160 }
161
162 static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv) {
163 PythonAPICall py("get_child_blocks");
164 py.Input("", block_or_loop_rv);
165 py.OutputList(outputs);
166 return py.Str();
167 }
168
169 template <typename>
170 friend struct ::tvm::tir::UnpackedInstTraits;
171};
172
173struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> {
174 static constexpr const char* kName = "GetProducers";
175 static constexpr bool kIsPure = true;
176
177 private:
178 static constexpr size_t kNumInputs = 1;
179 static constexpr size_t kNumAttrs = 0;
180 static constexpr size_t kNumDecisions = 0;
181
182 static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
183 return sch->GetProducers(block_rv);
184 }
185
186 static String UnpackedAsPython(Array<String> outputs, String block_rv) {
187 PythonAPICall py("get_producers");
188 py.Input("block", block_rv);
189 py.OutputList(outputs);
190 return py.Str();
191 }
192
193 template <typename>
194 friend struct ::tvm::tir::UnpackedInstTraits;
195};
196
197struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
198 static constexpr const char* kName = "GetConsumers";
199 static constexpr bool kIsPure = true;
200
201 private:
202 static constexpr size_t kNumInputs = 1;
203 static constexpr size_t kNumAttrs = 0;
204 static constexpr size_t kNumDecisions = 0;
205
206 static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
207 return sch->GetConsumers(block_rv);
208 }
209
210 static String UnpackedAsPython(Array<String> outputs, String block_rv) {
211 PythonAPICall py("get_consumers");
212 py.Input("block", block_rv);
213 py.OutputList(outputs);
214 return py.Str();
215 }
216
217 template <typename>
218 friend struct ::tvm::tir::UnpackedInstTraits;
219};
220
// Register the instruction kinds whose traits structs are defined in this file
// (GetBlock, GetLoops, GetChildBlocks, GetProducers, GetConsumers).
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);
226
227} // namespace tir
228} // namespace tvm
229