1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file src/relay/transforms/defuse_ops.cc |
 * \brief This is the inverse of the operator fusion pass. It transforms a fused
 * program returned by relay::transform::FuseOps back into the program as it was
 * before FuseOps (i.e., x == DefuseOps(FuseOps(x))).
26 | */ |
27 | |
28 | #include <tvm/relay/attrs/transform.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/transform.h> |
31 | |
32 | #include <string> |
33 | #include <unordered_map> |
34 | |
35 | #include "pattern_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | class DefuseOpsMutator : public ExprMutator { |
41 | public: |
42 | class FuncBodyMutator : public ExprMutator { |
43 | public: |
44 | explicit FuncBodyMutator(std::unordered_map<std::string, Expr> args) |
45 | : ExprMutator(), name_to_args_(std::move(args)) {} |
46 | |
47 | Expr VisitExpr_(const VarNode* n) { return name_to_args_[n->name_hint()]; } |
48 | |
49 | private: |
50 | std::unordered_map<std::string, Expr> name_to_args_; |
51 | }; |
52 | |
53 | Expr VisitExpr_(const CallNode* n) { |
54 | auto new_n = ExprMutator::VisitExpr_(n); |
55 | |
56 | if (const auto* call = new_n.as<CallNode>()) { |
57 | if (const auto* func = call->op.as<FunctionNode>()) { |
58 | std::unordered_map<std::string, Expr> name_to_args; |
59 | for (size_t i = 0; i < func->params.size(); ++i) { |
60 | const std::string& pname = func->params[i]->name_hint(); |
61 | ICHECK(name_to_args.cend() == name_to_args.find(pname)) |
62 | << "Found multiple parameters share the same variable name `" << pname |
63 | << "` which introduces uncertainty in DefuseOps pass" ; |
64 | name_to_args[pname] = call->args[i]; |
65 | } |
66 | return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body); |
67 | } |
68 | } |
69 | return new_n; |
70 | } |
71 | }; |
72 | |
73 | Expr DefuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); } |
74 | |
75 | namespace transform { |
76 | |
77 | Pass DefuseOps() { |
78 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
79 | [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(DefuseOps(f)); }; |
80 | return CreateFunctionPass(pass_func, 3, "DefuseOps" , {"InferType" }); |
81 | } |
82 | |
83 | TVM_REGISTER_GLOBAL("relay._transform.DefuseOps" ).set_body_typed(DefuseOps); |
84 | |
85 | } // namespace transform |
86 | |
87 | } // namespace relay |
88 | } // namespace tvm |
89 | |