/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "trace_apply.h"

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../tir/schedule/analysis.h"
#include "utils.h"

namespace tvm {
namespace meta_schedule {

using namespace tir;

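// This file implements "anchor trace" application: a trace tuned for an anchor
// (representative) subgraph is replayed on a schedule for a related subgraph that
// shares the anchor workload but may differ in its surrounding spatial blocks
// (e.g., extra epilogue ops). Trailing spatial blocks are inlined first, the trace
// is then replayed with inapplicable instructions skipped, and any block left
// unscheduled at the end is handled by a target-specific fallback.
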
// Returns true if b1 is an ancestor of b2 in the producer-consumer graph, or if b1
// and b2 are the same block (blocks are compared by name).
bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) {
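  // Compare by name_hint: BlockRVs returned by separate GetBlock calls are distinct
  // objects even when they refer to the same block.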
  if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) {
    return true;
  }
  for (auto prod : sch->GetProducers(b2)) {
    if (IsAncestor(b1, prod, sch)) return true;
  }
  return false;
}

// Inline or reverse inline spatial blocks after the anchor block.
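// For example (illustrative): given a conv2d -> bias_add -> relu subgraph whose anchor
// trace only references the conv2d block, bias_add and relu are inlined here (relu,
// being the output block, via reverse compute inline), and conv2d is left to be
// scheduled by the anchor trace.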
void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
  static auto kind_get_block = InstructionKind::Get("GetBlock");
  // Blocks whose names are referenced in the anchor trace are scheduled by the anchor
  // trace itself; we record such block names to avoid inlining them here.
  std::unordered_set<std::string> get_block_names;
  for (const auto& inst : anchor_trace->insts) {
    if (inst->kind.same_as(kind_get_block)) {
      auto block_name = Downcast<String>(inst->attrs[0]);
      ICHECK(block_name.defined());
      get_block_names.insert(block_name);
    }
  }

  auto anchor_block = FindAnchorBlock(sch->mod());

  std::vector<std::string> inline_todos;
  std::optional<int> last_block_idx{std::nullopt};
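  // last_block_idx, when set, is the position in inline_todos of the scope's output
  // block, which must be inlined last and via reverse compute inline (see below).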

  for (auto name : GetBlockNames(sch->mod())) {
    auto block = sch->GetBlock(name);
    if (anchor_block) {
      auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint);
      if (IsAncestor(block, anchor_block_rv, sch)) continue;
    }
    // Spatial blocks which are not referenced in the anchor trace will be inlined here.
    auto block_sref = sch->GetSRef(block);
    if (IsSpatial(block_sref) && !get_block_names.count(name)) {
      if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), block_sref, false))) {
        last_block_idx = inline_todos.size();
      }
      inline_todos.push_back(name);
    }
  }

  if (last_block_idx) {
    // The last block can only be reverse compute inlined. We make sure to inline all
    // producer blocks of the last block beforehand so that reverse compute inline can succeed.
    std::swap(inline_todos[*last_block_idx], inline_todos.back());
  }

  auto inline_rule = GetDefaultAutoInline(target->kind->name);

  for (auto name : inline_todos) {
    inline_rule->Apply(sch, sch->GetBlock(name));
  }
}

// Applies instructions from the anchor trace to the target schedule and returns the
// blocks that remain unscheduled.
std::vector<BlockRV> ApplyAnchorTrace(Schedule sch, Trace anchor_trace) {
  static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks");
  static auto kind_get_block = InstructionKind::Get("GetBlock");
  static auto kind_compute_inline = InstructionKind::Get("ComputeInline");
  static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline");

  const auto block_names_orig = GetBlockNames(sch->mod());
  const auto sch_orig = sch->Copy();

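  // Maps random variables (RVs) recorded in the anchor trace to the corresponding RVs
  // created while replaying the instructions on this schedule.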
  std::unordered_map<const Object*, const Object*> rv_map;
  // Blocks and loops that appear in the anchor trace but are not part of the target schedule.
  std::unordered_set<BlockRV, ObjectHash, ObjectEqual> foreign_blocks;
  std::unordered_set<LoopRV, ObjectHash, ObjectEqual> foreign_loops;

  // Instructions in the anchor trace can be applied only if all inputs are part of the target
  // schedule.
  auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) {
    for (auto input : inst->inputs) {
      if (!input.defined()) continue;
      if ((input->IsInstance<BlockRVNode>() && foreign_blocks.count(Downcast<BlockRV>(input))) ||
          (input->IsInstance<LoopRVNode>() && foreign_loops.count(Downcast<LoopRV>(input)))) {
        return false;
      }
    }
    return true;
  };

  for (const auto& inst : anchor_trace->insts) {
    if (!is_inst_applicable(inst)) {
      // If an instruction is not applicable, record its outputs as "foreign" to the
      // target schedule, so that downstream instructions consuming them are skipped
      // as well.
      for (auto output : inst->outputs) {
        if (output->IsInstance<BlockRVNode>()) {
          foreign_blocks.insert(Downcast<BlockRV>(output));
        } else if (output->IsInstance<LoopRVNode>()) {
          foreign_loops.insert(Downcast<LoopRV>(output));
        }
      }
      continue;
    }

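    // Translate the input RVs recorded in the anchor trace into the corresponding RVs
    // of this schedule.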
    Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map);

    if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast<String>(inst->attrs[0]))) {
      // The anchor trace does get_block on a block that is not part of the target schedule.
      auto block = Downcast<BlockRV>(inst->outputs[0]);
      foreign_blocks.insert(block);
      continue;
    } else if (inst->kind.same_as(kind_reverse_compute_inline)) {
      // The anchor trace does reverse_compute_inline on a block, but the block with the same name
      // in the target schedule cannot be reverse compute inlined.
      // In such cases, it should be possible to apply compute_inline instead.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      if (!CanReverseComputeInline(sch->state(), block_sref)) {
        ICHECK(CanComputeInline(sch->state(), block_sref));
        sch->ComputeInline(block);
        continue;
      }
    } else if (inst->kind.same_as(kind_compute_inline)) {
      // Similar to the reverse_compute_inline case above: fall back to
      // reverse_compute_inline when compute_inline is not applicable.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      auto state = sch->state();
      if (!CanComputeInline(state, block_sref)) {
        ICHECK(IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)))
            << "If a spatial block cannot be inlined, it should be the output block";
        if (CanReverseComputeInline(sch->state(), block_sref)) {
          sch->ReverseComputeInline(block);
        }
        continue;
      }
    }

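    // Replay the instruction on this schedule, reusing the decision (e.g., sampled
    // tile sizes) recorded for it in the anchor trace.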
    Optional<ObjectRef> decision = anchor_trace->GetDecision(inst);
    Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision);

    if (inst->kind.same_as(kind_get_child_blocks)) {
      // We want to allow a trace generated for a single conv2d block to be applied to
      // conv2d -> elemwise blocks, where the two conv2d workloads are the same.
      // GetChildBlocks returns a different number of blocks for the two cases above, which
      // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() ==
      // new_outputs.size(). We work around this problem by assuming that the prefix of the "new"
      // outputs matches the "old" outputs, and truncating the new outputs accordingly.
      ICHECK(inst->outputs.size() <= outputs.size());
      TranslateAddOutputRVs(
          inst->outputs, Array<ObjectRef>(outputs.begin(), outputs.begin() + inst->outputs.size()),
          &rv_map);
    } else {
      TranslateAddOutputRVs(inst->outputs, outputs, &rv_map);
    }
  }

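  // A block counts as scheduled if replaying the anchor trace changed its loop
  // structure: either the number of loops or a loop kind (e.g., parallel, vectorized,
  // or thread-bound) differs from the original schedule.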
  auto is_scheduled = [=](const std::string& block_name) {
    auto loops = sch->GetLoops(sch->GetBlock(block_name));
    auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name));
    if (loops.size() != loops_orig.size()) {
      return true;
    }
    for (size_t i = 0; i < loops.size(); ++i) {
      auto loop = sch->Get(loops[i]);
      auto loop_orig = sch_orig->Get(loops_orig[i]);
      if (loop->kind != loop_orig->kind) {
        return true;
      }
    }
    return false;
  };

  const auto block_names_now = GetBlockNames(sch->mod());
  std::vector<BlockRV> unscheduled_blocks;

  for (auto name : block_names_orig) {
    if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) {
      unscheduled_blocks.push_back(sch->GetBlock(name));
    }
  }

  return unscheduled_blocks;
}

void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) {
  InlinePostBlocks(sch, anchor_trace, target);

  auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace);
  ICHECK(unscheduled_blocks.size() <= 1)
      << "All blocks should have been scheduled, or at most one (fused) spatial block may "
         "remain unscheduled at this point.";

  if (unscheduled_blocks.empty()) {
    // All blocks have already been scheduled.
    return;
  }

  auto last_block = unscheduled_blocks[0];
  auto last_block_producers = sch->GetProducers(last_block);

  if (last_block_producers.size() == 1 && IsSpatial(sch->GetSRef(last_block_producers[0]))) {
    // Inline into the cache write stage.
    sch->ReverseComputeInline(last_block);
  } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") {
    // On CPU targets, fuse all loops of the last block and parallelize the result.
    sch->Parallel(sch->Fuse(sch->GetLoops(last_block)));
  } else if (IsGPUTarget(target->kind->name)) {
    auto max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    ICHECK(max_threads_per_block.defined())
        << "ValueError: missing attribute `max_threads_per_block` in the target";

    // Bind the loops of the last block to thread blocks and threads.
    auto auto_bind_rule =
        ScheduleRule::AutoBind(/*max_threadblocks=*/256,
                               /*thread_extents=*/Array<Integer>{32, 64, 128, 256, 512, 1024},
                               max_threads_per_block.value()->value);
    auto_bind_rule->Apply(sch, last_block);
  }
}

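// Exposed to the Python side; retrievable, e.g., via
// tvm.get_global_func("meta_schedule.ScheduleUsingAnchorTrace").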
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace")
    .set_body_typed(ScheduleUsingAnchorTrace);

}  // namespace meta_schedule
}  // namespace tvm