1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file src/relay/transforms/remove_standalone_reshapes.cc
21 * \brief This file contains the Relay pass for removing unfused reshapes from lowered graph.
22 */
23
24#include <tvm/relay/expr_functor.h>
25#include <tvm/relay/transform.h>
26
27#include "../op/call/call.h"
28#include "../op/memory/on_device.h"
29
30namespace tvm {
31namespace relay {
32
// Registers the boolean PassContext option "relay.remove_standalone_reshapes.enable".
// NOTE(review): RemoveStandaloneReshapes() in this file never reads this option itself;
// presumably the code that schedules the pass consults it before running it — confirm.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.remove_standalone_reshapes.enable", Bool);
34/*! Removes reshapes right after LowerTE. Removes preceding on_device calls
35 * while removing reshapes.
36 */
37class RemoveStandaloneReshapesMutator : public MixedModeMutator {
38 public:
39 explicit RemoveStandaloneReshapesMutator(IRModule& mod) {} // NOLINT(runtime/references)
40
41 using MixedModeMutator::VisitExpr_;
42
43 /*! * \brief Generated map of let variables to preceding CallLowered */
44 Expr VisitExpr_(const LetNode* let) final {
45 Let ret_let;
46 Var var = Downcast<Var>(this->Mutate(let->var));
47 auto value = this->Mutate(let->value);
48 if (auto* on_device_call = value.as<CallNode>()) {
49 OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call);
50 if (on_device_props.body.defined() && on_device_props.body->IsInstance<CallNode>()) {
51 const Call call_lowered = Downcast<Call>(on_device_props.body);
52 if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) {
53 let_var_to_call_lowered_.Set(var, call_lowered);
54 }
55 }
56 }
57 auto body = this->Mutate(let->body);
58 return WithFields(GetRef<Let>(let), var, value, body);
59 }
60
61 /*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */
62 Expr Rewrite_(const CallNode* call, const Expr& post) final {
63 /*
64 %1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...);
65 let %x: = on_device(%1, ...);
66 %2 = (%x,);
67 %3 = call_lowered(@tvmgen_default_fused_reshape, %2, ...,
68 "relay_attrs"=__dict__="relay.reshape_only"=1, ...);
69 */
70 const CallNode* post_call = post.as<CallNode>();
71 CallLoweredProps call_lowered_props = GetCallLoweredProps(post_call);
72 if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) {
73 if (!call_lowered_props.arguments.empty() &&
74 call_lowered_props.arguments[0]->IsInstance<VarNode>()) {
75 Var var = Downcast<Var>(call_lowered_props.arguments[0]);
76 if (var.defined() && let_var_to_call_lowered_.find(var) != let_var_to_call_lowered_.end()) {
77 return let_var_to_call_lowered_[var];
78 }
79 }
80 }
81
82 return post;
83 }
84
85 private:
86 /*! \brief Map of LetNode's var to previous call_lowered. */
87 Map<Var, Call> let_var_to_call_lowered_;
88};
89
90namespace transform {
91
92Pass RemoveStandaloneReshapes() {
93 auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
94 VLOG(1) << "RemoveStandaloneReshapes before:" << std::endl << PrettyPrint(mod);
95 RemoveStandaloneReshapesMutator remove_reshapes_mutator(mod);
96 Function main_func = Downcast<Function>(mod->Lookup("main"));
97 Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body);
98 if (!new_main_body.same_as(main_func->body)) {
99 auto main_var = mod->GetGlobalVar("main");
100 auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
101 main_func->type_params, main_func->attrs);
102 mod->Update(main_var, new_main_func);
103 }
104 Array<runtime::String> entry_functions{"main"};
105 mod = RemoveUnusedFunctions(entry_functions)(mod);
106
107 VLOG(1) << "RemoveStandaloneReshapes after:" << std::endl << PrettyPrint(mod);
108 return mod;
109 };
110 return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {});
111}
112
// Exposes transform::RemoveStandaloneReshapes() in the global function registry under the
// name "relay._transform.RemoveStandaloneReshapes".
// NOTE(review): presumably looked up by the Python FFI bindings — confirm.
TVM_REGISTER_GLOBAL("relay._transform.RemoveStandaloneReshapes")
    .set_body_typed(RemoveStandaloneReshapes);
115
116} // namespace transform
117} // namespace relay
118} // namespace tvm
119