/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file convert_sparse_dense.cc
 *
 * \brief Mutate nn.dense operators into nn.sparse_dense operators.
 */
#include <tvm/ir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace relay {
// Search the expression for the names of nn.dense weight variables.
class DenseOpWeightVisitor : private ExprVisitor {
 public:
  DenseOpWeightVisitor() : dense_op_(Op::Get("nn.dense")) {}

  // Visit the expression and return the names of all variables used as
  // nn.dense weights, in the order they are encountered.
  Array<String> Search(const Expr& expr) {
    VisitExpr(expr);
    return memo_;
  }

 private:
  void VisitExpr_(const CallNode* n) final {
    if (n->op == dense_op_) {
      const auto* weight = n->args[1].as<VarNode>();
      if (weight) {
        memo_.push_back(weight->name_hint());
      }
    }
    for (const auto& arg : n->args) {
      VisitExpr(arg);
    }
  }
  // Cached op
  const Op& dense_op_;
  // Collected weight variable names
  Array<String> memo_;
};  // class DenseOpWeightVisitor

// Return the names of all nn.dense weight variables in the expression.
Array<String> SearchDenseOpWeight(const Expr& e) { return DenseOpWeightVisitor().Search(e); }

TVM_REGISTER_GLOBAL("relay.analysis.search_dense_op_weight").set_body_typed(SearchDenseOpWeight);
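
// A minimal sketch (illustrative only, not part of this file's API) of how the
// global registered above could be looked up and invoked from C++; `expr`
// stands in for any Relay Expr:
//
//   const runtime::PackedFunc* f =
//       runtime::Registry::Get("relay.analysis.search_dense_op_weight");
//   ICHECK(f != nullptr);
//   Array<String> weight_names = (*f)(expr);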

// Mutate `nn.dense` to `nn.sparse_dense`
class DenseToSparseDenseMutator : public ExprRewriter {
 public:
  DenseToSparseDenseMutator(const Array<ObjectRef>& weight_name,
                            const Array<Array<PrimExpr>>& weight_shape)
      : dense_op_(Op::Get("nn.dense")), sparse_dense_op_(Op::Get("nn.sparse_dense")) {
    ICHECK_EQ(weight_name.size(), weight_shape.size());
    for (size_t i = 0; i < weight_name.size(); ++i) {
      ICHECK(weight_name[i]->IsInstance<runtime::StringObj>());
      std::string k = weight_name[i].as<runtime::StringObj>()->data;
      const auto& ws = weight_shape[i];
      std::vector<int> v(ws.size());
      for (size_t j = 0; j < ws.size(); ++j) {
        v[j] = ws[j].as<IntImmNode>()->value;
      }
      target_weights_.emplace(k, v);
    }
  }

  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
    if (pre->op == dense_op_) {
      const auto* weight = pre->args[1].as<VarNode>();
      if (weight && target_weights_.count(weight->name_hint())) {
        // Replace the dense weight with three new free variables carrying the
        // sparse representation (data, indices, indptr) and call
        // nn.sparse_dense on them instead.
        const auto& prefix = weight->name_hint();
        const auto& ws = target_weights_.at(prefix);
        const auto data = post.as<CallNode>()->args[0];
        auto ws_data_type =
            relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32));
        auto ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32));
        auto ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32));
        Var weight_data(prefix + ".data", ws_data_type);
        Var weight_indices(prefix + ".indices", ws_indices_type);
        Var weight_indptr(prefix + ".indptr", ws_indptr_type);
        auto attrs = make_object<SparseDenseAttrs>();

        return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr},
                    Attrs(attrs));
      }
    }
    return post;
  }

 private:
  // Cached ops
  const Op& dense_op_;
  const Op& sparse_dense_op_;
  // Maps a dense weight variable name to the shapes of its sparse components:
  // {data dim 0, data dim 1, data dim 2, indices length, indptr length}.
  std::unordered_map<std::string, std::vector<int>> target_weights_;
};  // class DenseToSparseDenseMutator

// Rewrite nn.dense calls whose weight variable appears in weight_name into
// nn.sparse_dense calls, using the sparse shapes given in weight_shape.
Expr DenseToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
                   const Array<Array<PrimExpr>>& weight_shape) {
  auto rewriter = DenseToSparseDenseMutator(weight_name, weight_shape);
  return PostOrderRewrite(e, &rewriter);
}
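
// Example (hypothetical values, for illustration): for a dense weight named
// "fc1_weight" whose block-sparse (BSR) form has 32 nonzero blocks of shape
// 16x1, 32 column indices and 17 row pointers, a caller would pass
//   weight_name  = {"fc1_weight"}
//   weight_shape = {{32, 16, 1, 32, 17}}
// and the rewritten expression would contain an nn.sparse_dense call reading
// the new free vars fc1_weight.data, fc1_weight.indices and fc1_weight.indptr.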

namespace transform {

Pass DenseToSparse(const Array<ObjectRef>& weight_name,
                   const Array<Array<PrimExpr>>& weight_shape) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape));
        // The rewrite introduced the sparse weight vars as free variables.
        // Rebind the function parameters so that the original params come
        // first, followed by the new sparse params, leaving no free vars
        // (and hence no FreeVar warnings).
        Array<Var> sparse_params = FreeVars(f0);
        auto f1 = WithFields(f0, sparse_params);
        Array<Var> params = FreeVars(f1);
        for (const auto& var : sparse_params) {
          params.push_back(var);
        }
        return WithFields(f1, params);
      };
  return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"});
}

TVM_REGISTER_GLOBAL("relay._transform.DenseToSparse").set_body_typed(DenseToSparse);
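
// A rough usage sketch (illustrative only): constructing the pass in C++ and
// applying it to a module; `mod`, `names` and `shapes` are placeholders.
//
//   IRModule mod = ...;
//   Array<ObjectRef> names = ...;          // dense weight variable names
//   Array<Array<PrimExpr>> shapes = ...;   // matching sparse component shapes
//   mod = relay::transform::DenseToSparse(names, shapes)(mod);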

}  // namespace transform

}  // namespace relay
}  // namespace tvm