1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file convert_sparse_conv2d.cc |
23 | * |
24 | * \brief Mutate conv2d operator to sparse conv2d operator |
25 | */ |
26 | #include <tvm/ir/expr.h> |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/attrs/nn.h> |
29 | #include <tvm/relay/attrs/transform.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/relay/transform.h> |
33 | |
34 | #include <unordered_map> |
35 | #include <unordered_set> |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | // Search conv2d op weight name from Expr |
41 | class Conv2dOpWeightVisitor : private ExprVisitor { |
42 | public: |
43 | Conv2dOpWeightVisitor() : conv2d_op_(Op::Get("nn.conv2d" )) {} |
44 | |
45 | Array<String> Search(const Expr& expr) { |
46 | VisitExpr(expr); |
47 | return memo_; |
48 | } |
49 | |
50 | private: |
51 | void VisitExpr_(const CallNode* n) final { |
52 | if (n->op == conv2d_op_) { |
53 | const auto weight = n->args[1].as<VarNode>(); |
54 | if (weight) { |
55 | memo_.push_back(weight->name_hint()); |
56 | } |
57 | } |
58 | for (const auto& arg : n->args) { |
59 | VisitExpr(arg); |
60 | } |
61 | } |
62 | // Cache op |
63 | const Op& conv2d_op_; |
64 | |
65 | Array<String> memo_; |
66 | }; // SearchConv2dOpWeight |
67 | |
68 | Array<String> SearchConv2dOpWeight(const Expr& e) { return Conv2dOpWeightVisitor().Search(e); } |
69 | |
70 | TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight" ).set_body_typed(SearchConv2dOpWeight); |
71 | |
72 | // Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` |
73 | class Conv2dToSparseConv2dMutator : public ExprRewriter { |
74 | public: |
75 | Conv2dToSparseConv2dMutator(const Array<ObjectRef>& weight_name, |
76 | const Array<Array<PrimExpr>>& weight_shape, const String& layout, |
77 | int kernel_size) |
78 | : conv2d_op_(Op::Get("nn.conv2d" )), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d" )) { |
79 | ICHECK_EQ(weight_name.size(), weight_shape.size()); |
80 | layout_ = layout; |
81 | kernel_size_ = kernel_size; |
82 | for (size_t i = 0; i < weight_name.size(); ++i) { |
83 | ICHECK(weight_name[i]->IsInstance<runtime::StringObj>()); |
84 | std::string k = weight_name[i].as<runtime::StringObj>()->data; |
85 | const auto& ws = weight_shape[i]; |
86 | std::vector<int> v(ws.size()); |
87 | for (size_t j = 0; j < ws.size(); ++j) { |
88 | v[j] = ws[j].as<IntImmNode>()->value; |
89 | } |
90 | target_weights_.emplace(k, v); |
91 | } |
92 | } |
93 | |
94 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
95 | if (pre->op == conv2d_op_) { |
96 | const auto weight = pre->args[1].as<VarNode>(); |
97 | if (weight) { |
98 | if (target_weights_.count(weight->name_hint())) { |
99 | const auto& prefix = weight->name_hint(); |
100 | const auto& ws = target_weights_.at(prefix); |
101 | const auto data = post.as<CallNode>()->args[0]; |
102 | relay::TensorType ws_data_type, ws_indices_type, ws_indptr_type; |
103 | if (ws.size() == 5) { |
104 | ws_data_type = relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32)); |
105 | ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); |
106 | ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32)); |
107 | } else if (ws.size() == 4) { |
108 | ws_data_type = relay::TensorType({ws.at(0), ws.at(1)}, DataType::Float(32)); |
109 | ws_indices_type = relay::TensorType({ws.at(2)}, DataType::Int(32)); |
110 | ws_indptr_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); |
111 | } |
112 | Var weight_data(prefix + ".data" , ws_data_type); |
113 | Var weight_indices(prefix + ".indices" , ws_indices_type); |
114 | Var weight_indptr(prefix + ".indptr" , ws_indptr_type); |
115 | auto attrs = make_object<SparseConv2DAttrs>(); |
116 | attrs->layout = std::move(layout_); |
117 | attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_}; |
118 | return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, |
119 | Attrs(attrs)); |
120 | } |
121 | } |
122 | } |
123 | return post; |
124 | } |
125 | |
126 | private: |
127 | // Cached op |
128 | const Op& conv2d_op_; |
129 | const Op& sparse_conv2d_op_; |
130 | std::unordered_map<std::string, std::vector<int>> target_weights_; |
131 | String layout_; |
132 | int kernel_size_; |
133 | }; // class Conv2dToSparseConv2dAlter |
134 | |
135 | Expr Conv2dToSparse(const Expr& e, const Array<ObjectRef>& weight_name, |
136 | const Array<Array<PrimExpr>>& weight_shape, const String& layout, |
137 | int kernel_size) { |
138 | auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); |
139 | return PostOrderRewrite(e, &rewriter); |
140 | } |
141 | |
/*!
 * \brief Implementation detail of unpack_to_tuple: reads one array element per
 *        index of the sequence and packs the values into a tuple.
 */
template <typename T, size_t... Idx>
auto unpack_to_tuple_internal(T* src, std::index_sequence<Idx...>) {
  return std::make_tuple(src[Idx]...);
}

/*!
 * \brief Copy the first N elements of an array into a std::tuple.
 * \tparam N Number of leading elements to read.
 * \param arr Pointer to at least N readable elements.
 * \return A tuple holding copies of arr[0] .. arr[N-1].
 */
template <int N, typename T>
auto unpack_to_tuple(T* arr) {
  using Indices = std::make_index_sequence<N>;
  return unpack_to_tuple_internal(arr, Indices{});
}
151 | |
/*!
 * \brief Index range [0, dim) for range-based for loops, with support for
 *        composing several loop indices into one flattened index.
 *
 * Iterating yields `iterpoint` values that remember both the current index and the
 * extent of the range. `a / b` combines two points into the row-major index
 * `a.val * b.lim + b.val` with extent `a.lim * b.lim`, so `*(i / j / k)` is the
 * flattened offset of (i, j, k).
 */
struct Range {
  size_t dim;
  explicit Range(size_t d) : dim(d) {}

  /*! \brief A loop index together with the extent of the loop it came from. */
  struct iterpoint {
    size_t val, lim;
    iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {}

    /*! \brief The (possibly flattened) index value. */
    size_t operator*() const { return val; }

    /*! \brief Treat `rhs` as the next inner dimension and flatten both indices. */
    iterpoint operator/(const iterpoint& rhs) const {
      return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim);
    }
  };

  /*! \brief Minimal forward iterator over the range; compares by position only. */
  struct iterator {
    size_t val, lim;
    iterator(size_t v1, size_t v2) : val(v1), lim(v2) {}

    bool operator!=(const iterator& rhs) const { return val != rhs.val; }

    bool operator==(const iterator& rhs) const { return val == rhs.val; }

    // Return *this so the iterator follows the usual pre-increment convention.
    iterator& operator++() {
      ++val;
      return *this;
    }

    iterpoint operator*() const { return iterpoint(val, lim); }
  };

  // begin()/end() do not modify the range, so they are const and a const Range
  // can be iterated as well.
  iterator begin() const { return iterator(0, dim); }

  iterator end() const { return iterator(dim, dim); }
};
182 | |
183 | // Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` |
184 | class Conv2dToSparseConv2dMutator2 : public ExprRewriter { |
185 | public: |
186 | Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW, |
187 | double sparse_thresh) |
188 | : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d" )), |
189 | dev_cpu0_{DLDeviceType::kDLCPU, 0}, |
190 | layout_(layout), |
191 | kernel_size_(kernel_size), |
192 | blockH_(blockH), |
193 | blockW_(blockW), |
194 | sparse_thresh_(sparse_thresh) {} |
195 | |
196 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
197 | // check op type & attrs |
198 | const auto pre_attrs = pre->attrs.as<Conv2DAttrs>(); |
199 | if (!pre_attrs || pre_attrs->data_layout != layout_ || |
200 | pre_attrs->strides[0].as<IntImmNode>()->value != 1 || |
201 | pre_attrs->kernel_size[0].as<IntImmNode>()->value != kernel_size_) |
202 | return post; |
203 | // check constant weight |
204 | const auto pre_weight_node = pre->args[1].as<ConstantNode>(); |
205 | if (!pre_weight_node) return post; |
206 | |
207 | // check weight dtype & shape |
208 | auto&& pre_weight = pre_weight_node->data; |
209 | auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); |
210 | ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only |
211 | auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); |
212 | int O, I, H, W; |
213 | if (layout_ == "NCHW" ) { |
214 | std::tie(O, I, H, W) = pre_weight_shape; |
215 | } else { // NHWC |
216 | std::tie(H, W, I, O) = pre_weight_shape; |
217 | } |
218 | int CO = O, CI = H * W * I; |
219 | |
220 | // copy to vector |
221 | std::vector<float> pre_weight_data(CO * CI); |
222 | pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); |
223 | if (layout_ == "NHWC" ) { |
224 | std::vector<float> tmp(pre_weight_data.size()); |
225 | for (auto i : Range(CO)) |
226 | for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; |
227 | std::swap(tmp, pre_weight_data); |
228 | } |
229 | // convert to BSR |
230 | std::vector<float> wdata, block(blockH_ * blockW_); |
231 | std::vector<int32_t> windices, windptr; |
232 | for (auto bh : Range(CO / blockH_)) { |
233 | windptr.push_back(windices.size()); |
234 | for (auto bw : Range(CI / blockW_)) { |
235 | int cntnnz = 0; |
236 | for (auto i : Range(blockH_)) |
237 | for (auto j : Range(blockW_)) { |
238 | auto tmp = pre_weight_data[*(bh / i / bw / j)]; |
239 | if (tmp) cntnnz++; |
240 | block[*(i / j)] = tmp; |
241 | } |
242 | if (cntnnz) { |
243 | wdata.insert(wdata.end(), block.begin(), block.end()); |
244 | windices.push_back(*bw); |
245 | } |
246 | } |
247 | } |
248 | windptr.push_back(windices.size()); |
249 | double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size(); |
250 | if (sprate < sparse_thresh_) return post; |
251 | |
252 | // constrct return data |
253 | int nnz = windices.size(); |
254 | auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_); |
255 | auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_); |
256 | auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_); |
257 | weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float)); |
258 | weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t)); |
259 | weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); |
260 | |
261 | // construct return call |
262 | auto args = runtime::Array<relay::Expr>{post.as<CallNode>()->args[0], Constant(weight_data), |
263 | Constant(weight_indices), Constant(weight_indptr)}; |
264 | auto attrs = make_object<SparseConv2DAttrs>(); |
265 | attrs->layout = layout_; |
266 | attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_}; |
267 | return Call(sparse_conv2d_op_, args, Attrs(attrs)); |
268 | } |
269 | |
270 | private: |
271 | const Op& sparse_conv2d_op_; |
272 | DLDevice dev_cpu0_; |
273 | String layout_; |
274 | int kernel_size_, blockH_, blockW_; |
275 | double sparse_thresh_; |
276 | }; // class Conv2dToSparseConv2dMutator2 |
277 | |
278 | Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, |
279 | double sparse_thresh) { |
280 | auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); |
281 | return PostOrderRewrite(e, &rewriter); |
282 | } |
283 | |
284 | namespace transform { |
285 | |
286 | // Convert a model with separate weight info (already sparsified). |
287 | Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimExpr>>& weight_shape, |
288 | const String& layout, int kernel_size) { |
289 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
290 | [=](Function f, IRModule m, PassContext pc) { |
291 | // Remove FreeVar warnings |
292 | auto f0 = |
293 | Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); |
294 | Array<Var> sparse_params = FreeVars(f0); |
295 | auto f1 = WithFields(f0, sparse_params); |
296 | Array<Var> params = FreeVars(f1); |
297 | for (const auto& var : sparse_params) { |
298 | params.push_back(var); |
299 | } |
300 | return WithFields(f1, params); |
301 | }; |
302 | return CreateFunctionPass(pass_func, 4, "Conv2dToSparse" , {"DeadCodeElimination" }); |
303 | } |
304 | |
305 | TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse" ).set_body_typed(Conv2dToSparse); |
306 | |
307 | // Convert a model with freezed params (sparsified in the pass). |
308 | Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW, |
309 | double sparse_thresh) { |
310 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
311 | [=](Function f, IRModule m, PassContext pc) { |
312 | auto f0 = Downcast<Function>( |
313 | Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); |
314 | return f0; |
315 | }; |
316 | return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2" , {"DeadCodeElimination" }); |
317 | } |
318 | |
319 | TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2" ).set_body_typed(Conv2dToSparse2); |
320 | |
321 | } // namespace transform |
322 | |
323 | } // namespace relay |
324 | } // namespace tvm |
325 | |