1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file sparse.cc
22 * \brief Property def of nn.sparse_dense operator.
23 */
24
25#include <tvm/relay/attrs/nn.h>
26#include <tvm/relay/op.h>
27#include <tvm/tir/data_layout.h>
28
29#include <string>
30#include <vector>
31
32#include "../../transforms/infer_layout_utils.h"
33
34namespace tvm {
35namespace relay {
36
37// relay.nn.sparse_dense
38TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);
39
40bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
41 const TypeReporter& reporter) {
42 ICHECK_EQ(types.size(), 5);
43 const auto* param = attrs.as<SparseDenseAttrs>();
44 ICHECK(param != nullptr);
45
46 if (param->sparse_lhs) {
47 const auto* weight = types[0].as<TensorTypeNode>();
48 const auto* data_data = types[1].as<TensorTypeNode>();
49 ICHECK(data_data->shape.size() == 1 || data_data->shape.size() == 3);
50 const auto* data_indptr = types[3].as<TensorTypeNode>();
51 if (weight == nullptr) return false;
52
53 if (data_data->shape.size() == 1) {
54 // CSR case.
55 Array<IndexExpr> oshape({data_indptr->shape[0] - 1, weight->shape[0]});
56 reporter->Assign(types[4], TensorType(oshape, weight->dtype));
57 return true;
58 }
59
60 if (data_data->shape.size() == 3) {
61 // BSR case.
62 Array<IndexExpr> oshape(
63 {(data_indptr->shape[0] - 1) * data_data->shape[1], weight->shape[0]});
64 reporter->Assign(types[4], TensorType(oshape, weight->dtype));
65 return true;
66 }
67 LOG(FATAL) << "Unknown data ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
68
69 } else {
70 const auto* data = types[0].as<TensorTypeNode>();
71 const auto* weight_data = types[1].as<TensorTypeNode>();
72 ICHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);
73 const auto* weight_indptr = types[3].as<TensorTypeNode>();
74 if (data == nullptr) return false;
75
76 if (weight_data->shape.size() == 1) {
77 // CSR case.
78 Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
79 reporter->Assign(types[4], TensorType(oshape, data->dtype));
80 return true;
81 }
82
83 if (weight_data->shape.size() == 3) {
84 // BSR case.
85 Array<IndexExpr> oshape(
86 {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
87 reporter->Assign(types[4], TensorType(oshape, data->dtype));
88 return true;
89 }
90 LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
91 }
92}
93
94// Positional relay function to create dense operator used by frontend FFI.
95Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr,
96 bool sparse_lhs) {
97 auto attrs = make_object<SparseDenseAttrs>();
98 attrs->sparse_lhs = std::move(sparse_lhs);
99 static const Op& op = Op::Get("nn.sparse_dense");
100 return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
101}
102
103TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense").set_body_typed(MakeSparseDense);
104
// Operator registration for nn.sparse_dense: four inputs (a dense operand plus
// the sparse data/indices/indptr triple); output type computed by
// SparseDenseRel above.
RELAY_REGISTER_OP("nn.sparse_dense")
    .describe(
        R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with either X or W sparse.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SparseDenseAttrs>()
    .set_num_inputs(4)
    .add_argument("dense_data", "nD Tensor", "Input dense data.")
    .add_argument("sparse_data", "1D or 3D Tensor", "Sparse data matrix.")  // 1D = CSR, 3D = BSR
    .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
    .add_argument("sparse_indptr", "1D Tensor", "Sparse indptr matrix.")
    .set_support_level(1)
    .add_type_rel("SparseDense", SparseDenseRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
123
124Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
125 auto attrs = make_object<SparseDenseAttrs>();
126 static const Op& op = Op::Get("nn.internal.sparse_dense_padded");
127 return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
128}
129
130TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded").set_body_typed(MakeSparseDensePadded);
131
// Internal GPU-oriented variant of sparse_dense; shares SparseDenseRel for
// type inference since the padding does not change the logical output shape.
RELAY_REGISTER_OP("nn.internal.sparse_dense_padded")
    .describe(
        R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W
sparse. This variation uses a matrix with row lengths padded to a
multiple of 32 for better GPU performance.

This op should not be directly used by a user. Instead, use `sparse_dense`
which will be converted to this op when running on the GPU.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SparseDenseAttrs>()
    .set_num_inputs(4)
    .add_argument("data", "nD Tensor", "Input data.")
    .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
    .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
    .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
    .set_support_level(1)
    .add_type_rel("SparseDense", SparseDenseRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
155
156// relay.nn.sparse_transpose
157TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs);
158
159bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
160 const TypeReporter& reporter) {
161 ICHECK_EQ(types.size(), 4);
162 const auto* sparse_data = types[0].as<TensorTypeNode>();
163 ICHECK_EQ(sparse_data->shape.size(), 1);
164 const auto* sparse_indices = types[1].as<TensorTypeNode>();
165 ICHECK_EQ(sparse_indices->shape.size(), 1);
166 const auto* sparse_indptr = types[2].as<TensorTypeNode>();
167
168 std::vector<Type> output_types;
169 output_types.push_back(TensorType(sparse_data->shape, sparse_data->dtype));
170 output_types.push_back(TensorType(sparse_indices->shape, sparse_indices->dtype));
171 output_types.push_back(TensorType(sparse_indptr->shape, sparse_indptr->dtype));
172
173 reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
174 return true;
175}
176
177Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
178 auto attrs = make_object<SparseTransposeAttrs>();
179 static const Op& op = Op::Get("nn.sparse_transpose");
180 return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
181}
182
183TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose").set_body_typed(MakeSparseTranspose);
184
// Operator registration for nn.sparse_transpose: takes a CSR triple, returns a
// 3-tuple of the transposed CSR buffers (see SparseTransposeRel).
RELAY_REGISTER_OP("nn.sparse_transpose")
    .describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix

- **input**: `(N, N)`
- **out**: `(N, N)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SparseTransposeAttrs>()
    .set_num_inputs(3)
    .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
    .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
    .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
    .set_support_level(1)
    .add_type_rel("SparseTranspose", SparseTransposeRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
200
201// relay.nn.sparse_add
202bool SparseAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
203 const TypeReporter& reporter) {
204 ICHECK_EQ(types.size(), 5) << "expecting 4 inputs and 1 output.";
205 const auto* dense_data = types[0].as<TensorTypeNode>();
206 const auto* sparse_data = types[1].as<TensorTypeNode>();
207 ICHECK(reporter->Assert(sparse_data->dtype == dense_data->dtype))
208 << "sparse tensor and dense tensor datatype should match.";
209 ICHECK(reporter->Assert(sparse_data->shape.size() == 1)) << "sparse data tensor should be 1D.";
210 const auto* sparse_indices = types[2].as<TensorTypeNode>();
211 ICHECK(reporter->Assert(sparse_indices->shape.size() == 1))
212 << "sparse indices tensor should be 1D.";
213
214 reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype));
215 return true;
216}
217
218Expr MakeSparseAdd(Expr dense_data, Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
219 static const Op& op = Op::Get("nn.sparse_add");
220 return Call(op, {dense_data, sparse_data, sparse_indices, sparse_indptr}, Attrs(), {});
221}
222
223TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_add").set_body_typed(MakeSparseAdd);
224
// Operator registration for nn.sparse_add. Marked kOpaque (not fusable),
// unlike the other sparse ops in this file which are kOutEWiseFusable.
RELAY_REGISTER_OP("nn.sparse_add")
    .describe(R"code(Add a dense matrix X with sparse matrix Y.

- **dense**: `(M, N)`
- **sparse**: `(M, N)`

- **out**: `(M, N)`.

)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("dense_data", "2D Tensor", "Dense data matrix.")
    .add_argument("sparse_data", "1D Tensor", "Sparse data vector.")
    .add_argument("sparse_indices", "1D Tensor", "Sparse indices vector.")
    .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer vector.")
    .set_support_level(1)
    .add_type_rel("SparseAdd", SparseAddRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);
242
243TVM_REGISTER_NODE_TYPE(SparseConv2DAttrs);
244
245bool SparseConv2dRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
246 const TypeReporter& reporter) {
247 ICHECK_EQ(types.size(), 5);
248 const auto* param = attrs.as<SparseConv2DAttrs>();
249 ICHECK(param != nullptr);
250
251 const auto* data = types[0].as<TensorTypeNode>();
252 const auto* weight_data = types[1].as<TensorTypeNode>();
253 ICHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 2 ||
254 weight_data->shape.size() == 3);
255 const auto* weight_indptr = types[3].as<TensorTypeNode>();
256 if (data == nullptr) return false;
257
258 if (weight_data->shape.size() == 2 || weight_data->shape.size() == 3) {
259 // BSR case.
260 if (param->layout == "NHWC") {
261 Array<IndexExpr> oshape({data->shape[0], data->shape[1], data->shape[2],
262 (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
263 reporter->Assign(types[4], TensorType(oshape, data->dtype));
264 return true;
265 } else if (param->layout == "NCHW") {
266 Array<IndexExpr> oshape({data->shape[0],
267 (weight_indptr->shape[0] - 1) * weight_data->shape[1],
268 data->shape[2], data->shape[3]});
269 reporter->Assign(types[4], TensorType(oshape, data->dtype));
270 return true;
271 }
272 }
273 LOG(FATAL) << "Unknown weight ndim " << weight_data->shape.size()
274 << " for nn.sparse_conv2d, should be 2 or 3 (BSR)";
275 return false;
276}
277
278Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr,
279 std::string layout, Array<IndexExpr> kernel_size) {
280 static const Op& op = Op::Get("nn.sparse_conv2d");
281 auto attrs = make_object<SparseConv2DAttrs>();
282 attrs->layout = std::move(layout);
283 attrs->kernel_size = std::move(kernel_size);
284 return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
285}
286
287TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_conv2d").set_body_typed(MakeSparseConv2d);
288
// Operator registration for nn.sparse_conv2d: dense input plus a BSR weight
// triple; output type computed by SparseConv2dRel.
RELAY_REGISTER_OP("nn.sparse_conv2d")
    .describe(
        R"code(Applies a sparse convolution :math:`Y = X*W^T` with W sparse.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SparseConv2DAttrs>()
    .set_num_inputs(4)
    .add_argument("dense_data", "nD Tensor", "Input dense data.")
    .add_argument("sparse_data", "1D or 3D Tensor", "Sparse data matrix.")
    .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
    .add_argument("sparse_indptr", "1D Tensor", "Sparse indptr matrix.")
    .set_support_level(1)
    .add_type_rel("SparseConv2d", SparseConv2dRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
307
308} // namespace relay
309} // namespace tvm
310