1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../analysis.h" |
20 | #include "../utils.h" |
21 | |
22 | namespace tvm { |
23 | namespace tir { |
24 | |
25 | Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) { |
26 | struct Finder : public StmtVisitor { |
27 | explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} |
28 | |
29 | void VisitStmt_(const BlockNode* block) override { |
30 | if (block->name_hint == name_) { |
31 | auto it = self_->stmt2ref.find(block); |
32 | ICHECK(it != self_->stmt2ref.end()); |
33 | results_.push_back(it->second); |
34 | } |
35 | StmtVisitor::VisitStmt_(block); |
36 | } |
37 | |
38 | const ScheduleState& self_; |
39 | const String& name_; |
40 | Array<StmtSRef> results_; |
41 | }; |
42 | |
43 | BaseFunc func = self->mod->Lookup(gv); |
44 | const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode); |
45 | Finder finder(self, name); |
46 | finder(prim_func->body); |
47 | return std::move(finder.results_); |
48 | } |
49 | |
50 | Array<StmtSRef> GetLoops(const StmtSRef& block_sref) { |
51 | std::vector<StmtSRef> result; |
52 | for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance<ForNode>(); |
53 | parent = parent->parent) { |
54 | result.push_back(GetRef<StmtSRef>(parent)); |
55 | } |
56 | return {result.rbegin(), result.rend()}; |
57 | } |
58 | |
59 | Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { |
60 | struct Collector : public StmtVisitor { |
61 | private: |
62 | void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } |
63 | |
64 | public: |
65 | explicit Collector(const ScheduleState& self) : self(self) {} |
66 | |
67 | const ScheduleState& self; |
68 | Array<StmtSRef> result; |
69 | }; |
70 | Collector collector(self); |
71 | if (parent_sref->stmt->IsInstance<ForNode>()) { |
72 | const auto* loop = static_cast<const ForNode*>(parent_sref->stmt); |
73 | collector(loop->body); |
74 | } else if (parent_sref->stmt->IsInstance<BlockNode>()) { |
75 | const auto* block = static_cast<const BlockNode*>(parent_sref->stmt); |
76 | collector(block->body); |
77 | } |
78 | return std::move(collector.result); |
79 | } |
80 | |
81 | Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { |
82 | StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
83 | return tir::GetProducers(block_sref, self->GetBlockScope(scope_root)); |
84 | } |
85 | |
86 | Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { |
87 | StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
88 | return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); |
89 | } |
90 | |
91 | /******** InstructionKind Registration ********/ |
92 | |
93 | struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> { |
94 | static constexpr const char* kName = "GetBlock" ; |
95 | static constexpr bool kIsPure = true; |
96 | |
97 | private: |
98 | static constexpr size_t kNumInputs = 0; |
99 | static constexpr size_t kNumAttrs = 2; |
100 | static constexpr size_t kNumDecisions = 0; |
101 | |
102 | static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) { |
103 | return sch->GetBlock(name, func_name); |
104 | } |
105 | |
106 | static String UnpackedAsPython(Array<String> outputs, String name, String func_name) { |
107 | PythonAPICall py("get_block" ); |
108 | py.Input("name" , name); |
109 | py.Input("func_name" , func_name); |
110 | py.SingleOutput(outputs); |
111 | return py.Str(); |
112 | } |
113 | |
114 | template <typename> |
115 | friend struct ::tvm::tir::UnpackedInstTraits; |
116 | }; |
117 | |
118 | struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> { |
119 | static constexpr const char* kName = "GetLoops" ; |
120 | static constexpr bool kIsPure = true; |
121 | |
122 | private: |
123 | static constexpr size_t kNumInputs = 1; |
124 | static constexpr size_t kNumAttrs = 0; |
125 | static constexpr size_t kNumDecisions = 0; |
126 | |
127 | static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { |
128 | return sch->GetLoops(block_rv); |
129 | } |
130 | |
131 | static String UnpackedAsPython(Array<String> outputs, String block_rv) { |
132 | PythonAPICall py("get_loops" ); |
133 | py.Input("block" , block_rv); |
134 | py.OutputList(outputs); |
135 | return py.Str(); |
136 | } |
137 | |
138 | template <typename> |
139 | friend struct ::tvm::tir::UnpackedInstTraits; |
140 | }; |
141 | |
142 | struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> { |
143 | static constexpr const char* kName = "GetChildBlocks" ; |
144 | static constexpr bool kIsPure = true; |
145 | |
146 | private: |
147 | static constexpr size_t kNumInputs = 1; |
148 | static constexpr size_t kNumAttrs = 0; |
149 | static constexpr size_t kNumDecisions = 0; |
150 | |
151 | static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { |
152 | if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) { |
153 | return sch->GetChildBlocks(GetRef<BlockRV>(block)); |
154 | } |
155 | if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) { |
156 | return sch->GetChildBlocks(GetRef<LoopRV>(loop)); |
157 | } |
158 | LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); |
159 | throw; |
160 | } |
161 | |
162 | static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv) { |
163 | PythonAPICall py("get_child_blocks" ); |
164 | py.Input("" , block_or_loop_rv); |
165 | py.OutputList(outputs); |
166 | return py.Str(); |
167 | } |
168 | |
169 | template <typename> |
170 | friend struct ::tvm::tir::UnpackedInstTraits; |
171 | }; |
172 | |
173 | struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> { |
174 | static constexpr const char* kName = "GetProducers" ; |
175 | static constexpr bool kIsPure = true; |
176 | |
177 | private: |
178 | static constexpr size_t kNumInputs = 1; |
179 | static constexpr size_t kNumAttrs = 0; |
180 | static constexpr size_t kNumDecisions = 0; |
181 | |
182 | static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { |
183 | return sch->GetProducers(block_rv); |
184 | } |
185 | |
186 | static String UnpackedAsPython(Array<String> outputs, String block_rv) { |
187 | PythonAPICall py("get_producers" ); |
188 | py.Input("block" , block_rv); |
189 | py.OutputList(outputs); |
190 | return py.Str(); |
191 | } |
192 | |
193 | template <typename> |
194 | friend struct ::tvm::tir::UnpackedInstTraits; |
195 | }; |
196 | |
197 | struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> { |
198 | static constexpr const char* kName = "GetConsumers" ; |
199 | static constexpr bool kIsPure = true; |
200 | |
201 | private: |
202 | static constexpr size_t kNumInputs = 1; |
203 | static constexpr size_t kNumAttrs = 0; |
204 | static constexpr size_t kNumDecisions = 0; |
205 | |
206 | static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { |
207 | return sch->GetConsumers(block_rv); |
208 | } |
209 | |
210 | static String UnpackedAsPython(Array<String> outputs, String block_rv) { |
211 | PythonAPICall py("get_consumers" ); |
212 | py.Input("block" , block_rv); |
213 | py.OutputList(outputs); |
214 | return py.Str(); |
215 | } |
216 | |
217 | template <typename> |
218 | friend struct ::tvm::tir::UnpackedInstTraits; |
219 | }; |
220 | |
// Register each traits struct defined above as an instruction kind. Each struct supplies its
// instruction name (kName), how to replay it on a Schedule (UnpackedApplyToSchedule), and how
// to render it as Python (UnpackedAsPython).
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);
226 | |
227 | } // namespace tir |
228 | } // namespace tvm |
229 | |