/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "trace_apply.h"

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../tir/schedule/analysis.h"
#include "utils.h"

namespace tvm {
namespace meta_schedule {

using namespace tir;

// Returns true if b1 is b2 itself or an ancestor of b2 (blocks are compared by name).
bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) {
  if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) {
    return true;
  }
  for (auto prod : sch->GetProducers(b2)) {
    if (IsAncestor(b1, prod, sch)) return true;
  }
  return false;
}

// Inline or reverse-inline spatial blocks that come after the anchor block.
void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
  static auto kind_get_block = InstructionKind::Get("GetBlock");
  // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace.
  // We record such block names to avoid inlining them here.
  std::unordered_set<std::string> get_block_names;
  for (const auto& inst : anchor_trace->insts) {
    if (inst->kind.same_as(kind_get_block)) {
      auto block_name = Downcast<String>(inst->attrs[0]);
      ICHECK(block_name.defined());
      get_block_names.insert(block_name);
    }
  }

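  // The anchor block is located by FindAnchorBlock (typically the main reduction block the anchor
  // trace was tuned for); blocks feeding into it are left for the anchor trace to schedule.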
  auto anchor_block = FindAnchorBlock(sch->mod());

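  // Collect the names of spatial blocks to be inlined, remembering the position of the output
  // block (if any), since it needs special treatment below.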
  std::vector<std::string> inline_todos;
  std::optional<int> last_block_idx{std::nullopt};

  for (auto name : GetBlockNames(sch->mod())) {
    auto block = sch->GetBlock(name);
    if (anchor_block) {
      auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint);
      if (IsAncestor(block, anchor_block_rv, sch)) continue;
    }
    // Spatial blocks which are not referenced in the anchor trace will be inlined here.
    auto block_sref = sch->GetSRef(block);
    if (IsSpatial(block_sref) && !get_block_names.count(name)) {
      if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), block_sref, false))) {
        last_block_idx = inline_todos.size();
      }
      inline_todos.push_back(name);
    }
  }

  if (last_block_idx) {
    // The last block can only be reverse compute inlined. We make sure to inline all
    // producer blocks of the last block beforehand so that reverse compute inline can succeed.
    std::swap(inline_todos[*last_block_idx], inline_todos.back());
  }

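  // Apply the target-specific default auto-inline rule to each collected block.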
  auto inline_rule = GetDefaultAutoInline(target->kind->name);

  for (auto name : inline_todos) {
    inline_rule->Apply(sch, sch->GetBlock(name));
  }
}

// Apply the instructions from the anchor trace to the target schedule, and return the blocks
// that remain unscheduled.
std::vector<BlockRV> ApplyAnchorTrace(Schedule sch, Trace anchor_trace) {
  static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks");
  static auto kind_get_block = InstructionKind::Get("GetBlock");
  static auto kind_compute_inline = InstructionKind::Get("ComputeInline");
  static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline");

  const auto block_names_orig = GetBlockNames(sch->mod());
  const auto sch_orig = sch->Copy();

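  // Maps random variables (RVs) defined in the anchor trace to the corresponding RVs created
  // while replaying the trace on this schedule.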
  std::unordered_map<const Object*, const Object*> rv_map;
  // Blocks and loops that appear in the anchor trace but are not part of the target schedule.
  std::unordered_set<BlockRV, ObjectHash, ObjectEqual> foreign_blocks;
  std::unordered_set<LoopRV, ObjectHash, ObjectEqual> foreign_loops;

  // Instructions in the anchor trace can be applied only if all inputs are part of the target
  // schedule.
  auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) {
    for (auto input : inst->inputs) {
      if (!input.defined()) continue;
      if ((input->IsInstance<BlockRVNode>() && foreign_blocks.count(Downcast<BlockRV>(input))) ||
          (input->IsInstance<LoopRVNode>() && foreign_loops.count(Downcast<LoopRV>(input)))) {
        return false;
      }
    }
    return true;
  };

  for (const auto& inst : anchor_trace->insts) {
    if (!is_inst_applicable(inst)) {
      // If we find an instruction that is not applicable, its outputs are recorded as "foreign"
      // to the target schedule.
      for (auto output : inst->outputs) {
        if (output->IsInstance<BlockRVNode>()) {
          foreign_blocks.insert(Downcast<BlockRV>(output));
        } else if (output->IsInstance<LoopRVNode>()) {
          foreign_loops.insert(Downcast<LoopRV>(output));
        }
      }
      continue;
    }

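    // Translate the instruction's inputs from the anchor trace's RVs to this schedule's RVs.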
    Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map);

    if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast<String>(inst->attrs[0]))) {
      // The anchor trace does get_block on a block that is not part of the target schedule.
      auto block = Downcast<BlockRV>(inst->outputs[0]);
      foreign_blocks.insert(block);
      continue;
    } else if (inst->kind.same_as(kind_reverse_compute_inline)) {
      // The anchor trace does reverse_compute_inline on a block, but the block with the same name
      // in the target schedule cannot be reverse-compute-inlined.
      // In such cases, it should be possible to apply compute_inline instead.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      if (!CanReverseComputeInline(sch->state(), block_sref)) {
        ICHECK(CanComputeInline(sch->state(), block_sref));
        sch->ComputeInline(block);
        continue;
      }
    } else if (inst->kind.same_as(kind_compute_inline)) {
      // Similar to the reverse_compute_inline case above.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      auto state = sch->state();
      if (!CanComputeInline(state, block_sref)) {
        ICHECK(IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)))
            << "If a spatial block cannot be inlined, it should be the output block";
        if (CanReverseComputeInline(sch->state(), block_sref)) {
          sch->ReverseComputeInline(block);
        }
        continue;
      }
    }

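    // Replay the instruction on the target schedule, reusing any sampling decision recorded in
    // the anchor trace.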
    Optional<ObjectRef> decision = anchor_trace->GetDecision(inst);
    Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision);

    if (inst->kind.same_as(kind_get_child_blocks)) {
      // We want to allow a trace generated for a single conv2d block to be applied to
      // conv2d -> elemwise blocks, where the two conv2d blocks are the same workload.
      // GetChildBlocks returns a different number of blocks for the two cases above, which
      // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() ==
      // new_outputs.size(). We work around this problem by assuming that the prefix of the "new"
      // outputs matches the "old" outputs, and truncating the new outputs accordingly.
      ICHECK(inst->outputs.size() <= outputs.size());
      TranslateAddOutputRVs(
          inst->outputs, Array<ObjectRef>(outputs.begin(), outputs.begin() + inst->outputs.size()),
          &rv_map);
    } else {
      TranslateAddOutputRVs(inst->outputs, outputs, &rv_map);
    }
  }

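  // A block is considered scheduled if replaying the anchor trace changed its loop nest, i.e. the
  // number of loops or the kind of any loop differs from the original schedule.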
  auto is_scheduled = [=](const std::string& block_name) {
    auto loops = sch->GetLoops(sch->GetBlock(block_name));
    auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name));
    if (loops.size() != loops_orig.size()) {
      return true;
    }
    for (size_t i = 0; i < loops.size(); ++i) {
      auto loop = sch->Get(loops[i]);
      auto loop_orig = sch_orig->Get(loops_orig[i]);
      if (loop->kind != loop_orig->kind) {
        return true;
      }
    }
    return false;
  };

  const auto block_names_now = GetBlockNames(sch->mod());
  std::vector<BlockRV> unscheduled_blocks;

  for (auto name : block_names_orig) {
    if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) {
      unscheduled_blocks.push_back(sch->GetBlock(name));
    }
  }

  return unscheduled_blocks;
}

void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) {
  InlinePostBlocks(sch, anchor_trace, target);

  auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace);
  ICHECK(unscheduled_blocks.size() <= 1)
      << "All blocks should have been scheduled, or at most one (fused) spatial block may remain "
         "unscheduled at this point.";

  if (unscheduled_blocks.empty()) {
    // All blocks have already been scheduled.
    return;
  }

  auto last_block = unscheduled_blocks[0];
  auto last_block_producers = sch->GetProducers(last_block);

  if (last_block_producers.size() == 1 && IsSpatial(sch->GetSRef(last_block_producers[0]))) {
    // Inline into the cache write stage.
    sch->ReverseComputeInline(last_block);
  } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") {
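    // On CPU targets, fuse the remaining spatial loops and parallelize the result.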
    sch->Parallel(sch->Fuse(sch->GetLoops(last_block)));
  } else if (IsGPUTarget(target->kind->name)) {
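    // On GPU targets, bind the fused spatial loops to thread and block indices via the AutoBind
    // schedule rule.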
    auto max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    ICHECK(max_threads_per_block.defined())
        << "ValueError: missing attribute `max_threads_per_block` in the target";

    auto auto_bind_rule =
        ScheduleRule::AutoBind(/*max_threadblocks=*/256,
                               /*thread_extents=*/Array<Integer>{32, 64, 128, 256, 512, 1024},
                               max_threads_per_block.value()->value);
    auto_bind_rule->Apply(sch, last_block);
  }
}

TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace")
    .set_body_typed(ScheduleUsingAnchorTrace);

}  // namespace meta_schedule
}  // namespace tvm