1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file extract_fused_functions.cc
22 * \brief Apply fusion and extract fused primitive functions from an IRModule
23 */
24#include <tvm/node/structural_hash.h>
25#include <tvm/relay/analysis.h>
26#include <tvm/relay/expr.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/transform.h>
29
30namespace tvm {
31namespace relay {
32
33class FusedFunctionExtractorWrapper : private ExprVisitor {
34 public:
35 explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}
36
37 IRModule Extract() {
38 VisitExpr(this->mod_->Lookup("main"));
39
40 auto functions = Map<GlobalVar, BaseFunc>();
41 for (auto pair : this->functions) {
42 functions.Set(GlobalVar(pair.first), pair.second);
43 }
44
45 this->mod_->functions = functions;
46 return this->mod_;
47 }
48
49 private:
50 const IRModule mod_;
51 // This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
52 // have the desired equals property
53 Map<String, Function> functions;
54
55 void VisitExpr_(const FunctionNode* n) final {
56 if (n->HasNonzeroAttr(attr::kPrimitive)) {
57 // Add function to functions, keyed by function hash string
58 Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
59 size_t hash_ = tvm::StructuralHash()(func);
60 this->functions.Set(std::to_string(hash_), func);
61 }
62
63 ExprVisitor::VisitExpr_(n);
64 }
65};
66
67namespace transform {
68
69Pass ExtractFusedFunctions() {
70 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
71 [=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
72 auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});
73
74 return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
75 "ExtractFusedFunctions");
76}
77
78TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
79
80} // namespace transform
81
82} // namespace relay
83} // namespace tvm
84