1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file simplify_fc_transpose.cc
23 *
 * \brief Mutate ```y = nn.dense(x, transpose(w, [1, 0]))``` to
25 * ```y = nn.dense(x, wt)```
26 */
27#include <tvm/ir/expr.h>
28#include <tvm/relay/analysis.h>
29#include <tvm/relay/attrs/nn.h>
30#include <tvm/relay/attrs/transform.h>
31#include <tvm/relay/expr_functor.h>
32#include <tvm/relay/op_attr_types.h>
33#include <tvm/relay/transform.h>
34
35#include <unordered_map>
36#include <unordered_set>
37
38namespace tvm {
39namespace relay {
40
// Find name of weight in ```y = nn.dense(x, transpose(w, [1, 0]))```
42class FCTransposeVisitor : private ExprVisitor {
43 public:
44 FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {}
45
46 Array<String> Search(const Expr& expr) {
47 VisitExpr(expr);
48 return memo_;
49 }
50
51 private:
52 void VisitExpr_(const CallNode* n) final {
53 if (n->op == dense_op_) {
54 const auto weight = n->args[1].as<CallNode>();
55 if (weight) {
56 if (weight->op == transpose_op_) {
57 if (weight->args[0].as<VarNode>()) {
58 const auto arg = weight->args[0].as<VarNode>();
59 memo_.push_back(arg->name_hint());
60 }
61 }
62 }
63 }
64 for (const auto& arg : n->args) {
65 VisitExpr(arg);
66 }
67 }
68
69 const Op& dense_op_;
70 const Op& transpose_op_;
71 Array<String> memo_;
72}; // SearchDenseOpWeight
73
74Array<String> SearchFCTranspose(const Expr& e) { return FCTransposeVisitor().Search(e); }
75
76TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchFCTranspose);
77
// Mutate ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
79class FCTransposeMutator : public ExprRewriter {
80 public:
81 explicit FCTransposeMutator(const Array<ObjectRef>& target_weights)
82 : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {
83 for (size_t i = 0; i < target_weights.size(); ++i) {
84 ICHECK(target_weights[i]->IsInstance<runtime::StringObj>());
85 std::string k = target_weights[i].as<runtime::StringObj>()->data;
86 target_weights_.emplace(k);
87 }
88 }
89
90 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
91 if (pre->op == dense_op_) {
92 const auto data = post.as<CallNode>()->args[0];
93 const auto weight = pre->args[1].as<CallNode>();
94 if (weight) {
95 if (weight->op == transpose_op_) {
96 const auto arg = weight->args[0];
97 if (arg.as<VarNode>()) {
98 const auto& arg_node = arg.as<VarNode>();
99 ICHECK_GT(target_weights_.count(arg_node->name_hint()), 0);
100 const auto& tt = arg_node->type_annotation.as<TensorTypeNode>();
101 auto wt_type = TensorType({tt->shape[1], tt->shape[0]}, tt->dtype);
102 Var wt(arg_node->name_hint() + ".T", wt_type);
103 return Call(dense_op_, {data, wt}, pre->attrs, pre->type_args);
104 }
105 }
106 }
107 }
108 return post;
109 }
110
111 private:
112 // Cached op
113 const Op& dense_op_;
114 const Op& transpose_op_;
115 std::unordered_set<std::string> target_weights_;
116}; // class DenseToSparseDenseAlter
117
118Expr SimplifyFCTranspose(const Expr& e, const Array<ObjectRef>& target_weights) {
119 auto rewriter = FCTransposeMutator(target_weights);
120 return PostOrderRewrite(e, &rewriter);
121}
122
123namespace transform {
124
125Pass SimplifyFCTranspose(const Array<ObjectRef>& target_weights) {
126 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
127 [=](Function f, IRModule m, PassContext pc) {
128 // Remove FreeVar warning
129 auto f0 = Downcast<Function>(SimplifyFCTranspose(f, target_weights));
130 Array<Var> wt_params = FreeVars(f0);
131 auto f1 = WithFields(f0, wt_params);
132 Array<Var> params = FreeVars(f1);
133 for (const auto& var : wt_params) {
134 params.push_back(var);
135 }
136 return WithFields(f1, params);
137 };
138 return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"});
139}
140
141TVM_REGISTER_GLOBAL("relay._transform.SimplifyFCTranspose").set_body_typed(SimplifyFCTranspose);
142
143} // namespace transform
144
145} // namespace relay
146} // namespace tvm
147