/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
/*!
 * \file extract_fused_functions.cc
 * \brief Apply operator fusion to an IRModule and extract the resulting fused
 *        primitive functions into a new IRModule.
 */
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <string>
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | |
33 | class : private ExprVisitor { |
34 | public: |
35 | explicit (const IRModule& mod) : mod_(mod) {} |
36 | |
37 | IRModule () { |
38 | VisitExpr(this->mod_->Lookup("main" )); |
39 | |
40 | auto functions = Map<GlobalVar, BaseFunc>(); |
41 | for (auto pair : this->functions) { |
42 | functions.Set(GlobalVar(pair.first), pair.second); |
43 | } |
44 | |
45 | this->mod_->functions = functions; |
46 | return this->mod_; |
47 | } |
48 | |
49 | private: |
50 | const IRModule ; |
51 | // This is not simply Map<GlobalVar, Function> because GlobalVar doesn't |
52 | // have the desired equals property |
53 | Map<String, Function> ; |
54 | |
55 | void (const FunctionNode* n) final { |
56 | if (n->HasNonzeroAttr(attr::kPrimitive)) { |
57 | // Add function to functions, keyed by function hash string |
58 | Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); |
59 | size_t hash_ = tvm::StructuralHash()(func); |
60 | this->functions.Set(std::to_string(hash_), func); |
61 | } |
62 | |
63 | ExprVisitor::VisitExpr_(n); |
64 | } |
65 | }; |
66 | |
67 | namespace transform { |
68 | |
69 | Pass () { |
70 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = |
71 | [=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); }; |
72 | auto = CreateModulePass(pass_func, 1, "ExtractFusedFunctions" , {}); |
73 | |
74 | return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass}, |
75 | "ExtractFusedFunctions" ); |
76 | } |
77 | |
78 | TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions" ).set_body_typed(ExtractFusedFunctions); |
79 | |
80 | } // namespace transform |
81 | |
82 | } // namespace relay |
83 | } // namespace tvm |
84 | |