1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file convert_sparse_conv2d.cc
23 *
24 * \brief Mutate conv2d operator to sparse conv2d operator
25 */
#include <tvm/ir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <cstdint>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
36
37namespace tvm {
38namespace relay {
39
40// Search conv2d op weight name from Expr
41class Conv2dOpWeightVisitor : private ExprVisitor {
42 public:
43 Conv2dOpWeightVisitor() : conv2d_op_(Op::Get("nn.conv2d")) {}
44
45 Array<String> Search(const Expr& expr) {
46 VisitExpr(expr);
47 return memo_;
48 }
49
50 private:
51 void VisitExpr_(const CallNode* n) final {
52 if (n->op == conv2d_op_) {
53 const auto weight = n->args[1].as<VarNode>();
54 if (weight) {
55 memo_.push_back(weight->name_hint());
56 }
57 }
58 for (const auto& arg : n->args) {
59 VisitExpr(arg);
60 }
61 }
62 // Cache op
63 const Op& conv2d_op_;
64
65 Array<String> memo_;
66}; // SearchConv2dOpWeight
67
68Array<String> SearchConv2dOpWeight(const Expr& e) { return Conv2dOpWeightVisitor().Search(e); }
69
70TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(SearchConv2dOpWeight);
71
72// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
73class Conv2dToSparseConv2dMutator : public ExprRewriter {
74 public:
75 Conv2dToSparseConv2dMutator(const Array<ObjectRef>& weight_name,
76 const Array<Array<PrimExpr>>& weight_shape, const String& layout,
77 int kernel_size)
78 : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) {
79 ICHECK_EQ(weight_name.size(), weight_shape.size());
80 layout_ = layout;
81 kernel_size_ = kernel_size;
82 for (size_t i = 0; i < weight_name.size(); ++i) {
83 ICHECK(weight_name[i]->IsInstance<runtime::StringObj>());
84 std::string k = weight_name[i].as<runtime::StringObj>()->data;
85 const auto& ws = weight_shape[i];
86 std::vector<int> v(ws.size());
87 for (size_t j = 0; j < ws.size(); ++j) {
88 v[j] = ws[j].as<IntImmNode>()->value;
89 }
90 target_weights_.emplace(k, v);
91 }
92 }
93
94 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
95 if (pre->op == conv2d_op_) {
96 const auto weight = pre->args[1].as<VarNode>();
97 if (weight) {
98 if (target_weights_.count(weight->name_hint())) {
99 const auto& prefix = weight->name_hint();
100 const auto& ws = target_weights_.at(prefix);
101 const auto data = post.as<CallNode>()->args[0];
102 relay::TensorType ws_data_type, ws_indices_type, ws_indptr_type;
103 if (ws.size() == 5) {
104 ws_data_type = relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32));
105 ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32));
106 ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32));
107 } else if (ws.size() == 4) {
108 ws_data_type = relay::TensorType({ws.at(0), ws.at(1)}, DataType::Float(32));
109 ws_indices_type = relay::TensorType({ws.at(2)}, DataType::Int(32));
110 ws_indptr_type = relay::TensorType({ws.at(3)}, DataType::Int(32));
111 }
112 Var weight_data(prefix + ".data", ws_data_type);
113 Var weight_indices(prefix + ".indices", ws_indices_type);
114 Var weight_indptr(prefix + ".indptr", ws_indptr_type);
115 auto attrs = make_object<SparseConv2DAttrs>();
116 attrs->layout = std::move(layout_);
117 attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
118 return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr},
119 Attrs(attrs));
120 }
121 }
122 }
123 return post;
124 }
125
126 private:
127 // Cached op
128 const Op& conv2d_op_;
129 const Op& sparse_conv2d_op_;
130 std::unordered_map<std::string, std::vector<int>> target_weights_;
131 String layout_;
132 int kernel_size_;
133}; // class Conv2dToSparseConv2dAlter
134
135Expr Conv2dToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
136 const Array<Array<PrimExpr>>& weight_shape, const String& layout,
137 int kernel_size) {
138 auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size);
139 return PostOrderRewrite(e, &rewriter);
140}
141
/*!
 * \brief Helper for unpack_to_tuple: copy arr[Is]... into a std::tuple by value.
 */
template <typename elemTy, size_t... Is>
auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence<Is...> /*indices*/) {
  auto result = std::make_tuple(arr[Is]...);
  return result;
}

/*!
 * \brief Copy the first N elements of `arr` into a std::tuple, e.g. for use with std::tie.
 * \tparam N Number of elements to copy; `arr` must point to at least N elements.
 */
template <int N, typename elemTy>
auto unpack_to_tuple(elemTy* arr) {
  using Indices = std::make_index_sequence<N>;
  return unpack_to_tuple_internal(arr, Indices{});
}
151
/*!
 * \brief Counting range [0, dim) whose elements compose into flattened row-major offsets.
 *
 * Iterating Range(n) yields iterpoint(v, n) for v = 0..n-1. For points a = (va, la) and
 * b = (vb, lb), a / b == (va * lb + vb, la * lb), so for i in Range(R) and j in Range(C),
 * *(i / j) is the row-major offset i * C + j inside an R x C matrix.
 */
struct Range {
  size_t dim;
  explicit Range(size_t d) : dim(d) {}

  // A value together with the extent of the range it came from.
  struct iterpoint {
    size_t val, lim;
    iterpoint(size_t v, size_t l) : val(v), lim(l) {}

    size_t operator*() const { return val; }

    // Treat `rhs` as the faster-varying (inner) coordinate.
    iterpoint operator/(const iterpoint& rhs) const {
      const size_t combined_val = val * rhs.lim + rhs.val;
      const size_t combined_lim = lim * rhs.lim;
      return {combined_val, combined_lim};
    }
  };

  // Minimal forward iterator, sufficient for range-based for loops.
  struct iterator {
    size_t val, lim;
    iterator(size_t v, size_t l) : val(v), lim(l) {}

    bool operator!=(const iterator& other) const { return !(val == other.val); }

    void operator++() { val += 1; }

    iterpoint operator*() const { return iterpoint{val, lim}; }
  };

  iterator begin() { return {0, dim}; }

  iterator end() { return {dim, dim}; }
};
182
183// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
184class Conv2dToSparseConv2dMutator2 : public ExprRewriter {
185 public:
186 Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW,
187 double sparse_thresh)
188 : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")),
189 dev_cpu0_{DLDeviceType::kDLCPU, 0},
190 layout_(layout),
191 kernel_size_(kernel_size),
192 blockH_(blockH),
193 blockW_(blockW),
194 sparse_thresh_(sparse_thresh) {}
195
196 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
197 // check op type & attrs
198 const auto pre_attrs = pre->attrs.as<Conv2DAttrs>();
199 if (!pre_attrs || pre_attrs->data_layout != layout_ ||
200 pre_attrs->strides[0].as<IntImmNode>()->value != 1 ||
201 pre_attrs->kernel_size[0].as<IntImmNode>()->value != kernel_size_)
202 return post;
203 // check constant weight
204 const auto pre_weight_node = pre->args[1].as<ConstantNode>();
205 if (!pre_weight_node) return post;
206
207 // check weight dtype & shape
208 auto&& pre_weight = pre_weight_node->data;
209 auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32);
210 ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only
211 auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data());
212 int O, I, H, W;
213 if (layout_ == "NCHW") {
214 std::tie(O, I, H, W) = pre_weight_shape;
215 } else { // NHWC
216 std::tie(H, W, I, O) = pre_weight_shape;
217 }
218 int CO = O, CI = H * W * I;
219
220 // copy to vector
221 std::vector<float> pre_weight_data(CO * CI);
222 pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float));
223 if (layout_ == "NHWC") {
224 std::vector<float> tmp(pre_weight_data.size());
225 for (auto i : Range(CO))
226 for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)];
227 std::swap(tmp, pre_weight_data);
228 }
229 // convert to BSR
230 std::vector<float> wdata, block(blockH_ * blockW_);
231 std::vector<int32_t> windices, windptr;
232 for (auto bh : Range(CO / blockH_)) {
233 windptr.push_back(windices.size());
234 for (auto bw : Range(CI / blockW_)) {
235 int cntnnz = 0;
236 for (auto i : Range(blockH_))
237 for (auto j : Range(blockW_)) {
238 auto tmp = pre_weight_data[*(bh / i / bw / j)];
239 if (tmp) cntnnz++;
240 block[*(i / j)] = tmp;
241 }
242 if (cntnnz) {
243 wdata.insert(wdata.end(), block.begin(), block.end());
244 windices.push_back(*bw);
245 }
246 }
247 }
248 windptr.push_back(windices.size());
249 double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size();
250 if (sprate < sparse_thresh_) return post;
251
252 // constrct return data
253 int nnz = windices.size();
254 auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_);
255 auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_);
256 auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_);
257 weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float));
258 weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t));
259 weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t));
260
261 // construct return call
262 auto args = runtime::Array<relay::Expr>{post.as<CallNode>()->args[0], Constant(weight_data),
263 Constant(weight_indices), Constant(weight_indptr)};
264 auto attrs = make_object<SparseConv2DAttrs>();
265 attrs->layout = layout_;
266 attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
267 return Call(sparse_conv2d_op_, args, Attrs(attrs));
268 }
269
270 private:
271 const Op& sparse_conv2d_op_;
272 DLDevice dev_cpu0_;
273 String layout_;
274 int kernel_size_, blockH_, blockW_;
275 double sparse_thresh_;
276}; // class Conv2dToSparseConv2dMutator2
277
278Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW,
279 double sparse_thresh) {
280 auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh);
281 return PostOrderRewrite(e, &rewriter);
282}
283
284namespace transform {
285
286// Convert a model with separate weight info (already sparsified).
287Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimExpr>>& weight_shape,
288 const String& layout, int kernel_size) {
289 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
290 [=](Function f, IRModule m, PassContext pc) {
291 // Remove FreeVar warnings
292 auto f0 =
293 Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
294 Array<Var> sparse_params = FreeVars(f0);
295 auto f1 = WithFields(f0, sparse_params);
296 Array<Var> params = FreeVars(f1);
297 for (const auto& var : sparse_params) {
298 params.push_back(var);
299 }
300 return WithFields(f1, params);
301 };
302 return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"});
303}
304
305TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse);
306
307// Convert a model with freezed params (sparsified in the pass).
308Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW,
309 double sparse_thresh) {
310 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
311 [=](Function f, IRModule m, PassContext pc) {
312 auto f0 = Downcast<Function>(
313 Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh));
314 return f0;
315 };
316 return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"});
317}
318
319TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2);
320
321} // namespace transform
322
323} // namespace relay
324} // namespace tvm
325