1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file src/relay/transforms/defuse_ops.cc
 * \brief The inverse of the fusion pass. It transforms a fused program returned
 * by relay::transform::FuseOps back into the program as it was before FuseOps
 * (i.e., x == DefuseOps(FuseOps(x))).
26 */
27
28#include <tvm/relay/attrs/transform.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/transform.h>
31
32#include <string>
33#include <unordered_map>
34
35#include "pattern_utils.h"
36
37namespace tvm {
38namespace relay {
39
40class DefuseOpsMutator : public ExprMutator {
41 public:
42 class FuncBodyMutator : public ExprMutator {
43 public:
44 explicit FuncBodyMutator(std::unordered_map<std::string, Expr> args)
45 : ExprMutator(), name_to_args_(std::move(args)) {}
46
47 Expr VisitExpr_(const VarNode* n) { return name_to_args_[n->name_hint()]; }
48
49 private:
50 std::unordered_map<std::string, Expr> name_to_args_;
51 };
52
53 Expr VisitExpr_(const CallNode* n) {
54 auto new_n = ExprMutator::VisitExpr_(n);
55
56 if (const auto* call = new_n.as<CallNode>()) {
57 if (const auto* func = call->op.as<FunctionNode>()) {
58 std::unordered_map<std::string, Expr> name_to_args;
59 for (size_t i = 0; i < func->params.size(); ++i) {
60 const std::string& pname = func->params[i]->name_hint();
61 ICHECK(name_to_args.cend() == name_to_args.find(pname))
62 << "Found multiple parameters share the same variable name `" << pname
63 << "` which introduces uncertainty in DefuseOps pass";
64 name_to_args[pname] = call->args[i];
65 }
66 return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body);
67 }
68 }
69 return new_n;
70 }
71};
72
73Expr DefuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); }
74
75namespace transform {
76
77Pass DefuseOps() {
78 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
79 [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(DefuseOps(f)); };
80 return CreateFunctionPass(pass_func, 3, "DefuseOps", {"InferType"});
81}
82
83TVM_REGISTER_GLOBAL("relay._transform.DefuseOps").set_body_typed(DefuseOps);
84
85} // namespace transform
86
87} // namespace relay
88} // namespace tvm
89