1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file target_hooks.cc |
22 | * \brief Relay passes for processing Target Hooks which have been registered on functions within |
23 | * the IRModule |
24 | */ |
25 | |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace transform { |
32 | |
33 | namespace { |
34 | |
35 | /*! |
36 | * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any |
37 | * 'external codegen' Target instance with matching kind name which should be current when |
38 | * the pass is applied. |
39 | */ |
40 | struct CustomPass { |
41 | std::string target_kind_name; |
42 | Pass pass; |
43 | Optional<Target> opt_target; |
44 | |
45 | CustomPass(std::string target_kind_name, Pass pass, Optional<Target> opt_target) |
46 | : target_kind_name(std::move(target_kind_name)), |
47 | pass(std::move(pass)), |
48 | opt_target(std::move(opt_target)) {} |
49 | }; |
50 | |
51 | /*! |
52 | * \brief Collect all the \p CustomPasses needed according to the "Compiler" attributes on |
53 | * inlined or global functions. |
54 | */ |
55 | class TargetHookVisitor : public MixedModeVisitor { |
56 | public: |
57 | TargetHookVisitor(IRModule mod, CompilationConfig config) |
58 | : mod_(std::move(mod)), |
59 | config_(std::move(config)), |
60 | target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>(tvm::attr::kRelayToTIR)) {} |
61 | |
62 | std::vector<CustomPass> Visit() { |
63 | ICHECK(custom_passes_.empty()); |
64 | // To ensure the passes are run in a deterministic order we'll search for functions in |
65 | // lexicographic order. |
66 | std::vector<std::pair<std::string, BaseFunc>> functions; |
67 | for (const auto& kv : mod_->functions) { |
68 | functions.emplace_back(kv.first->name_hint, kv.second); |
69 | } |
70 | std::sort(functions.begin(), functions.end()); |
71 | for (const auto& kv : functions) { |
72 | if (const auto* function_node = kv.second.as<FunctionNode>()) { |
73 | // May be a top-level function with a "Compiler" attribute. |
74 | MaybeAddPassForFunction(function_node); |
75 | } |
76 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
77 | // May have calls to inlined "Compiler" functions in body. |
78 | VisitExpr(GetRef<Function>(function_node)); |
79 | } |
80 | } |
81 | return std::move(custom_passes_); |
82 | } |
83 | |
84 | private: |
85 | using tvm::relay::MixedModeVisitor::VisitExpr_; |
86 | |
87 | void VisitExpr_(const LetNode* let_node) final { |
88 | auto pre_visit = [this](const LetNode* inner_let_node) { |
89 | this->VisitExpr(inner_let_node->var); |
90 | this->VisitExpr(inner_let_node->value); |
91 | }; |
92 | auto post_visit = [this](const LetNode* inner_let_node) { |
93 | this->VisitExpr(inner_let_node->body); |
94 | this->visit_counter_[inner_let_node] += 1; |
95 | }; |
96 | ExpandANormalForm(let_node, pre_visit, post_visit); |
97 | } |
98 | |
99 | void VisitExpr_(const FunctionNode* function_node) override { |
100 | ExprVisitor::VisitExpr_(function_node); |
101 | MaybeAddPassForFunction(function_node); |
102 | } |
103 | |
104 | /*! |
105 | * \brief If \p function_node has a "Compiler" attribute, checks if we should include a |
106 | * matching custom pass. Otherwise no-op. |
107 | */ |
108 | void MaybeAddPassForFunction(const FunctionNode* function_node) { |
109 | Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler); |
110 | if (!opt_compiler) { |
111 | // No external codegen required. |
112 | return; |
113 | } |
114 | // First cross-over: use "Compiler" attribute name as target kind. |
115 | std::string kind_name = opt_compiler.value(); |
116 | Optional<TargetKind> opt_target_kind = tvm::TargetKind::Get(kind_name); |
117 | if (!opt_target_kind || !target_attr_map_.count(opt_target_kind.value())) { |
118 | // Target kind does not exist or have the "RelayToTIR" attribute, no custom pass to consider. |
119 | return; |
120 | } |
121 | if (!seen_kinds_.emplace(kind_name).second) { |
122 | // Already accounted for custom pass. |
123 | return; |
124 | } |
125 | // Second (optional) cross-over: find unique Target instance in overall available targets with |
126 | // the same kind so that it can be made available when custom pass is invoked. |
127 | Optional<Target> opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); |
128 | Pass custom_target_pass = target_attr_map_[opt_target_kind.value()]; |
129 | custom_passes_.emplace_back(std::move(kind_name), std::move(custom_target_pass), |
130 | std::move(opt_target)); |
131 | } |
132 | |
133 | /*! \brief IRModule we are visiting. */ |
134 | IRModule mod_; |
135 | /*! \brief All available targets. */ |
136 | CompilationConfig config_; |
137 | /*! \brief Cached attribute map for all registered targets */ |
138 | TargetKindAttrMap<Pass> target_attr_map_; |
139 | /*! \brief Which target kind names have already contributed to the custom passes list. */ |
140 | std::unordered_set<std::string> seen_kinds_; |
141 | /*! |
142 | * \brief All the custom passes to run, paired with their corresponding target instances, if any. |
143 | */ |
144 | std::vector<CustomPass> custom_passes_; |
145 | }; |
146 | |
147 | } // namespace |
148 | |
149 | Pass RelayToTIRTargetHook(CompilationConfig config) { |
150 | auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { |
151 | VLOG(1) << "RelayToTIRTargetHook before:" << std::endl << PrettyPrint(mod); |
152 | TargetHookVisitor target_hook_visitor(mod, config); |
153 | std::vector<CustomPass> custom_passes = target_hook_visitor.Visit(); |
154 | for (const auto& custom_pass : custom_passes) { |
155 | if (custom_pass.opt_target.defined()) { |
156 | VLOG(0) << "Invoking custom pass for target " |
157 | << custom_pass.opt_target.value()->ToDebugString(); |
158 | // Push the target on the stack. |
159 | With<Target> with_target(custom_pass.opt_target.value()); |
160 | // Invoke the pass with target in scope. |
161 | mod = custom_pass.pass(mod); |
162 | } else { |
163 | // Invoke the pass. |
164 | // Note that there may be a non-external codegen target in scope. Each custom pass |
165 | // must be prepared to handle this, eg by creating a default target instance if the |
166 | // current target is either null or of a generic kind such as 'cuda' or 'llvm'. |
167 | VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'" ; |
168 | mod = custom_pass.pass(mod); |
169 | } |
170 | } |
171 | VLOG(1) << "RelayToTIRTargetHook after:" << std::endl << PrettyPrint(mod); |
172 | return mod; |
173 | }; |
174 | return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook" , {}); |
175 | } |
176 | |
177 | } // namespace transform |
178 | } // namespace relay |
179 | } // namespace tvm |
180 | |