1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <tvm/ir/name_supply.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/target/target.h>

#include <algorithm>
#include <numeric>

#include "../../meta_schedule/module_equality.h"
#include "../../te/operation/create_primfunc.h"
#include "./te_compiler_cache.h"
#include "./utils.h"
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | namespace backend { |
36 | |
37 | class OpCounter : public ExprVisitor { |
38 | public: |
39 | static size_t GetOpCount(relay::Function func) { |
40 | OpCounter counter; |
41 | counter(func->body); |
42 | return counter.count; |
43 | } |
44 | |
45 | private: |
46 | void VisitExpr_(const CallNode* call) final { |
47 | if (call->op->IsInstance<OpNode>()) { |
48 | ++count; |
49 | } |
50 | ExprVisitor::VisitExpr_(call); |
51 | } |
52 | |
53 | size_t count{0}; |
54 | }; |
55 | |
56 | Array<meta_schedule::ExtractedTask> (IRModule mod, Target target, |
57 | Map<String, runtime::NDArray> params, |
58 | String mod_eq_name) { |
59 | using meta_schedule::ExtractedTask; |
60 | using meta_schedule::ModuleEqual; |
61 | using meta_schedule::ModuleHash; |
62 | backend::BindParamsInModule(mod, params); |
63 | // is_vm=true for backward compatibility |
64 | Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); |
65 | pass_seqs.push_back(transform::FuseOps()); |
66 | |
67 | mod = transform::Sequential(pass_seqs)(std::move(mod)); |
68 | |
69 | std::vector<ExtractedTask> tasks; |
70 | |
71 | auto mod_eq = meta_schedule::ModuleEquality::Create(mod_eq_name); |
72 | |
73 | std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual> cache( |
74 | /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq)); |
75 | |
76 | std::vector<std::tuple<std::string, Function, IRModule>> lower_results; |
77 | |
78 | NameSupply constant_name_supply("" ); |
79 | |
80 | PostOrderVisit(mod->Lookup("main" ), [&](const Expr& exp) { |
81 | if (exp->IsInstance<FunctionNode>()) { |
82 | Function relay_func = Downcast<Function>(exp); |
83 | if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { |
84 | return; |
85 | } |
86 | |
87 | auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply); |
88 | if (f) { |
89 | IRModule tir_mod = PrimFuncToIRModule(f.value()); |
90 | lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod)); |
91 | } |
92 | } |
93 | }); |
94 | |
95 | std::vector<int> indices(lower_results.size()); |
96 | std::iota(indices.begin(), indices.end(), 0); |
97 | |
98 | if (mod_eq_name == "anchor-block" ) { |
99 | std::vector<size_t> op_counts(lower_results.size()); |
100 | for (size_t i = 0; i < op_counts.size(); ++i) { |
101 | op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i])); |
102 | } |
103 | |
104 | // When anchor-block based equality is used, tuning tasks "nn_conv2d_add_nn_relu" and |
105 | // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. Thus, one of |
106 | // them will be filtered by the cache below. |
107 | // |
108 | // To make sure that we tune "nn_conv2d_add_nn_relu" and not "nn_conv2d_add_add_nn_relu", |
109 | // we sort the TE lowering results based on the number of relay ops. This way, |
110 | // "nn_conv2d_add_nn_relu" will be added to the cache first, and "nn_conv2d_add_add_nn_relu" |
111 | // will be filtered. |
112 | std::sort(indices.begin(), indices.end(), |
113 | [&op_counts](int i1, int i2) { return op_counts[i1] < op_counts[i2]; }); |
114 | } |
115 | |
116 | for (auto i : indices) { |
117 | const auto& [fused_name, relay_func, tir_mod] = lower_results[i]; |
118 | auto it = cache.find(tir_mod); |
119 | if (it != cache.end()) { |
120 | it->second->weight += 1; |
121 | continue; |
122 | } |
123 | // Note that the cache is key-ed on the tir mod, rather than the relay mod |
124 | IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); |
125 | ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1); |
126 | tasks.push_back(task); |
127 | cache.emplace(tir_mod, task); |
128 | } |
129 | |
130 | // Tasks are extracted via post order visit, return the reversed list. |
131 | std::reverse(tasks.begin(), tasks.end()); |
132 | NameSupply name_supply = NameSupply("" ); |
133 | for (ExtractedTask task : tasks) { |
134 | task->task_name = name_supply->FreshName(task->task_name); |
135 | } |
136 | return tasks; |
137 | } |
138 | |
// Register ExtractTask as a typed global function under the name
// "relay.backend.MetaScheduleExtractTask".
TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask" ).set_body_typed(ExtractTask);
140 | |
141 | } // namespace backend |
142 | } // namespace relay |
143 | } // namespace tvm |
144 | |