1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file target_hooks.cc
 * \brief Relay pass which applies the custom "RelayToTIR" passes registered on target kinds
 * named by the "Compiler" attributes of functions within the IRModule.
 */
25
26#include <tvm/relay/expr_functor.h>
27#include <tvm/relay/transform.h>
28
29namespace tvm {
30namespace relay {
31namespace transform {
32
33namespace {
34
35/*!
36 * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any
37 * 'external codegen' Target instance with matching kind name which should be current when
38 * the pass is applied.
39 */
40struct CustomPass {
41 std::string target_kind_name;
42 Pass pass;
43 Optional<Target> opt_target;
44
45 CustomPass(std::string target_kind_name, Pass pass, Optional<Target> opt_target)
46 : target_kind_name(std::move(target_kind_name)),
47 pass(std::move(pass)),
48 opt_target(std::move(opt_target)) {}
49};
50
51/*!
52 * \brief Collect all the \p CustomPasses needed according to the "Compiler" attributes on
53 * inlined or global functions.
54 */
55class TargetHookVisitor : public MixedModeVisitor {
56 public:
57 TargetHookVisitor(IRModule mod, CompilationConfig config)
58 : mod_(std::move(mod)),
59 config_(std::move(config)),
60 target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>(tvm::attr::kRelayToTIR)) {}
61
62 std::vector<CustomPass> Visit() {
63 ICHECK(custom_passes_.empty());
64 // To ensure the passes are run in a deterministic order we'll search for functions in
65 // lexicographic order.
66 std::vector<std::pair<std::string, BaseFunc>> functions;
67 for (const auto& kv : mod_->functions) {
68 functions.emplace_back(kv.first->name_hint, kv.second);
69 }
70 std::sort(functions.begin(), functions.end());
71 for (const auto& kv : functions) {
72 if (const auto* function_node = kv.second.as<FunctionNode>()) {
73 // May be a top-level function with a "Compiler" attribute.
74 MaybeAddPassForFunction(function_node);
75 }
76 if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
77 // May have calls to inlined "Compiler" functions in body.
78 VisitExpr(GetRef<Function>(function_node));
79 }
80 }
81 return std::move(custom_passes_);
82 }
83
84 private:
85 using tvm::relay::MixedModeVisitor::VisitExpr_;
86
87 void VisitExpr_(const LetNode* let_node) final {
88 auto pre_visit = [this](const LetNode* inner_let_node) {
89 this->VisitExpr(inner_let_node->var);
90 this->VisitExpr(inner_let_node->value);
91 };
92 auto post_visit = [this](const LetNode* inner_let_node) {
93 this->VisitExpr(inner_let_node->body);
94 this->visit_counter_[inner_let_node] += 1;
95 };
96 ExpandANormalForm(let_node, pre_visit, post_visit);
97 }
98
99 void VisitExpr_(const FunctionNode* function_node) override {
100 ExprVisitor::VisitExpr_(function_node);
101 MaybeAddPassForFunction(function_node);
102 }
103
104 /*!
105 * \brief If \p function_node has a "Compiler" attribute, checks if we should include a
106 * matching custom pass. Otherwise no-op.
107 */
108 void MaybeAddPassForFunction(const FunctionNode* function_node) {
109 Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
110 if (!opt_compiler) {
111 // No external codegen required.
112 return;
113 }
114 // First cross-over: use "Compiler" attribute name as target kind.
115 std::string kind_name = opt_compiler.value();
116 Optional<TargetKind> opt_target_kind = tvm::TargetKind::Get(kind_name);
117 if (!opt_target_kind || !target_attr_map_.count(opt_target_kind.value())) {
118 // Target kind does not exist or have the "RelayToTIR" attribute, no custom pass to consider.
119 return;
120 }
121 if (!seen_kinds_.emplace(kind_name).second) {
122 // Already accounted for custom pass.
123 return;
124 }
125 // Second (optional) cross-over: find unique Target instance in overall available targets with
126 // the same kind so that it can be made available when custom pass is invoked.
127 Optional<Target> opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value());
128 Pass custom_target_pass = target_attr_map_[opt_target_kind.value()];
129 custom_passes_.emplace_back(std::move(kind_name), std::move(custom_target_pass),
130 std::move(opt_target));
131 }
132
133 /*! \brief IRModule we are visiting. */
134 IRModule mod_;
135 /*! \brief All available targets. */
136 CompilationConfig config_;
137 /*! \brief Cached attribute map for all registered targets */
138 TargetKindAttrMap<Pass> target_attr_map_;
139 /*! \brief Which target kind names have already contributed to the custom passes list. */
140 std::unordered_set<std::string> seen_kinds_;
141 /*!
142 * \brief All the custom passes to run, paired with their corresponding target instances, if any.
143 */
144 std::vector<CustomPass> custom_passes_;
145};
146
147} // namespace
148
149Pass RelayToTIRTargetHook(CompilationConfig config) {
150 auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) {
151 VLOG(1) << "RelayToTIRTargetHook before:" << std::endl << PrettyPrint(mod);
152 TargetHookVisitor target_hook_visitor(mod, config);
153 std::vector<CustomPass> custom_passes = target_hook_visitor.Visit();
154 for (const auto& custom_pass : custom_passes) {
155 if (custom_pass.opt_target.defined()) {
156 VLOG(0) << "Invoking custom pass for target "
157 << custom_pass.opt_target.value()->ToDebugString();
158 // Push the target on the stack.
159 With<Target> with_target(custom_pass.opt_target.value());
160 // Invoke the pass with target in scope.
161 mod = custom_pass.pass(mod);
162 } else {
163 // Invoke the pass.
164 // Note that there may be a non-external codegen target in scope. Each custom pass
165 // must be prepared to handle this, eg by creating a default target instance if the
166 // current target is either null or of a generic kind such as 'cuda' or 'llvm'.
167 VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'";
168 mod = custom_pass.pass(mod);
169 }
170 }
171 VLOG(1) << "RelayToTIRTargetHook after:" << std::endl << PrettyPrint(mod);
172 return mod;
173 };
174 return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {});
175}
176
177} // namespace transform
178} // namespace relay
179} // namespace tvm
180