1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
#include <tvm/ir/name_supply.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/target/target.h>

#include <algorithm>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../meta_schedule/module_equality.h"
#include "../../te/operation/create_primfunc.h"
#include "./te_compiler_cache.h"
#include "./utils.h"
32
33namespace tvm {
34namespace relay {
35namespace backend {
36
37class OpCounter : public ExprVisitor {
38 public:
39 static size_t GetOpCount(relay::Function func) {
40 OpCounter counter;
41 counter(func->body);
42 return counter.count;
43 }
44
45 private:
46 void VisitExpr_(const CallNode* call) final {
47 if (call->op->IsInstance<OpNode>()) {
48 ++count;
49 }
50 ExprVisitor::VisitExpr_(call);
51 }
52
53 size_t count{0};
54};
55
56Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
57 Map<String, runtime::NDArray> params,
58 String mod_eq_name) {
59 using meta_schedule::ExtractedTask;
60 using meta_schedule::ModuleEqual;
61 using meta_schedule::ModuleHash;
62 backend::BindParamsInModule(mod, params);
63 // is_vm=true for backward compatibility
64 Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
65 pass_seqs.push_back(transform::FuseOps());
66
67 mod = transform::Sequential(pass_seqs)(std::move(mod));
68
69 std::vector<ExtractedTask> tasks;
70
71 auto mod_eq = meta_schedule::ModuleEquality::Create(mod_eq_name);
72
73 std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual> cache(
74 /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq));
75
76 std::vector<std::tuple<std::string, Function, IRModule>> lower_results;
77
78 NameSupply constant_name_supply("");
79
80 PostOrderVisit(mod->Lookup("main"), [&](const Expr& exp) {
81 if (exp->IsInstance<FunctionNode>()) {
82 Function relay_func = Downcast<Function>(exp);
83 if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
84 return;
85 }
86
87 auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply);
88 if (f) {
89 IRModule tir_mod = PrimFuncToIRModule(f.value());
90 lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod));
91 }
92 }
93 });
94
95 std::vector<int> indices(lower_results.size());
96 std::iota(indices.begin(), indices.end(), 0);
97
98 if (mod_eq_name == "anchor-block") {
99 std::vector<size_t> op_counts(lower_results.size());
100 for (size_t i = 0; i < op_counts.size(); ++i) {
101 op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i]));
102 }
103
104 // When anchor-block based equality is used, tuning tasks "nn_conv2d_add_nn_relu" and
105 // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. Thus, one of
106 // them will be filtered by the cache below.
107 //
108 // To make sure that we tune "nn_conv2d_add_nn_relu" and not "nn_conv2d_add_add_nn_relu",
109 // we sort the TE lowering results based on the number of relay ops. This way,
110 // "nn_conv2d_add_nn_relu" will be added to the cache first, and "nn_conv2d_add_add_nn_relu"
111 // will be filtered.
112 std::sort(indices.begin(), indices.end(),
113 [&op_counts](int i1, int i2) { return op_counts[i1] < op_counts[i2]; });
114 }
115
116 for (auto i : indices) {
117 const auto& [fused_name, relay_func, tir_mod] = lower_results[i];
118 auto it = cache.find(tir_mod);
119 if (it != cache.end()) {
120 it->second->weight += 1;
121 continue;
122 }
123 // Note that the cache is key-ed on the tir mod, rather than the relay mod
124 IRModule relay_mod({{GlobalVar(fused_name), relay_func}});
125 ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1);
126 tasks.push_back(task);
127 cache.emplace(tir_mod, task);
128 }
129
130 // Tasks are extracted via post order visit, return the reversed list.
131 std::reverse(tasks.begin(), tasks.end());
132 NameSupply name_supply = NameSupply("");
133 for (ExtractedTask task : tasks) {
134 task->task_name = name_supply->FreshName(task->task_name);
135 }
136 return tasks;
137}
138
139TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask").set_body_typed(ExtractTask);
140
141} // namespace backend
142} // namespace relay
143} // namespace tvm
144