1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file convert_sparse_dense.cc |
23 | * |
24 | * \brief Mutate dense operator to sparse dense operator |
25 | */ |
#include <tvm/ir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | // Search dense op weight name from Expr |
41 | class DenseOpWeightVisitor : private ExprVisitor { |
42 | public: |
43 | DenseOpWeightVisitor() : dense_op_(Op::Get("nn.dense" )) {} |
44 | |
45 | Array<String> Search(const Expr& expr) { |
46 | VisitExpr(expr); |
47 | return memo_; |
48 | } |
49 | |
50 | private: |
51 | void VisitExpr_(const CallNode* n) final { |
52 | if (n->op == dense_op_) { |
53 | const auto weight = n->args[1].as<VarNode>(); |
54 | if (weight) { |
55 | memo_.push_back(weight->name_hint()); |
56 | } |
57 | } |
58 | for (const auto& arg : n->args) { |
59 | VisitExpr(arg); |
60 | } |
61 | } |
62 | // Cache op |
63 | const Op& dense_op_; |
64 | |
65 | Array<String> memo_; |
66 | }; // SearchDenseOpWeight |
67 | |
68 | Array<String> SearchDenseOpWeight(const Expr& e) { return DenseOpWeightVisitor().Search(e); } |
69 | |
// Expose SearchDenseOpWeight through the global function registry under the
// name "relay.analysis.search_dense_op_weight".
TVM_REGISTER_GLOBAL("relay.analysis.search_dense_op_weight" ).set_body_typed(SearchDenseOpWeight);
71 | |
72 | // Mutate ```nn.dense``` to ```nn.sparse_dense``` |
73 | class DenseToSparseDenseMutator : public ExprRewriter { |
74 | public: |
75 | DenseToSparseDenseMutator(const Array<ObjectRef>& weight_name, |
76 | const Array<Array<PrimExpr>>& weight_shape) |
77 | : dense_op_(Op::Get("nn.dense" )), sparse_dense_op_(Op::Get("nn.sparse_dense" )) { |
78 | ICHECK_EQ(weight_name.size(), weight_shape.size()); |
79 | for (size_t i = 0; i < weight_name.size(); ++i) { |
80 | ICHECK(weight_name[i]->IsInstance<runtime::StringObj>()); |
81 | std::string k = weight_name[i].as<runtime::StringObj>()->data; |
82 | const auto& ws = weight_shape[i]; |
83 | std::vector<int> v(ws.size()); |
84 | for (size_t j = 0; j < ws.size(); ++j) { |
85 | v[j] = ws[j].as<IntImmNode>()->value; |
86 | } |
87 | target_weights_.emplace(k, v); |
88 | } |
89 | } |
90 | |
91 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
92 | if (pre->op == dense_op_) { |
93 | const auto weight = pre->args[1].as<VarNode>(); |
94 | if (weight) { |
95 | if (target_weights_.count(weight->name_hint())) { |
96 | const auto& prefix = weight->name_hint(); |
97 | const auto& ws = target_weights_.at(prefix); |
98 | const auto data = post.as<CallNode>()->args[0]; |
99 | auto ws_data_type = |
100 | relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32)); |
101 | auto ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); |
102 | auto ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32)); |
103 | Var weight_data(prefix + ".data" , ws_data_type); |
104 | Var weight_indices(prefix + ".indices" , ws_indices_type); |
105 | Var weight_indptr(prefix + ".indptr" , ws_indptr_type); |
106 | auto attrs = make_object<SparseDenseAttrs>(); |
107 | |
108 | return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr}, |
109 | Attrs(attrs)); |
110 | } |
111 | } |
112 | } |
113 | return post; |
114 | } |
115 | |
116 | private: |
117 | // Cached op |
118 | const Op& dense_op_; |
119 | const Op& sparse_dense_op_; |
120 | std::unordered_map<std::string, std::vector<int>> target_weights_; |
121 | }; // class DenseToSparseDenseAlter |
122 | |
123 | Expr DenseToSparse(const Expr& e, const Array<ObjectRef>& weight_name, |
124 | const Array<Array<PrimExpr>>& weight_shape) { |
125 | auto rewriter = DenseToSparseDenseMutator(weight_name, weight_shape); |
126 | return PostOrderRewrite(e, &rewriter); |
127 | } |
128 | |
129 | namespace transform { |
130 | |
131 | Pass DenseToSparse(const Array<ObjectRef>& weight_name, |
132 | const Array<Array<PrimExpr>>& weight_shape) { |
133 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
134 | [=](Function f, IRModule m, PassContext pc) { |
135 | // Remove FreeVar warnings |
136 | auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape)); |
137 | Array<Var> sparse_params = FreeVars(f0); |
138 | auto f1 = WithFields(f0, sparse_params); |
139 | Array<Var> params = FreeVars(f1); |
140 | for (const auto& var : sparse_params) { |
141 | params.push_back(var); |
142 | } |
143 | return WithFields(f1, params); |
144 | }; |
145 | return CreateFunctionPass(pass_func, 4, "DenseToSparse" , {"DeadCodeElimination" }); |
146 | } |
147 | |
// Expose the DenseToSparse pass constructor through the global function
// registry under the name "relay._transform.DenseToSparse".
TVM_REGISTER_GLOBAL("relay._transform.DenseToSparse" ).set_body_typed(DenseToSparse);
149 | |
150 | } // namespace transform |
151 | |
152 | } // namespace relay |
153 | } // namespace tvm |
154 | |