1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/numeric_op.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using shape_inference::DimensionHandle; |
25 | using shape_inference::InferenceContext; |
26 | using shape_inference::ShapeAndType; |
27 | using shape_inference::ShapeHandle; |
28 | |
29 | Status GetVariantInput(InferenceContext* c, int index, |
30 | ShapeAndType* shape_and_type) { |
31 | ShapeHandle variant; |
32 | TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant)); |
33 | auto* shapes_and_types = c->input_handle_shapes_and_types(index); |
34 | if (shapes_and_types == nullptr || shapes_and_types->size() != 1) { |
35 | return errors::InvalidArgument( |
36 | "Unable to access shape and type info from variant input " , index); |
37 | } |
38 | *shape_and_type = shapes_and_types->at(0); |
39 | return OkStatus(); |
40 | } |
41 | |
42 | // Validates that a shape represents a (rank-2) square matrix or a (rank-3) |
43 | // batch of square matrices. |
44 | Status ValidateSquareMatrixShape(InferenceContext* c, |
45 | const ShapeHandle& matrix_shape, |
46 | DimensionHandle* matrix_dimension) { |
47 | ShapeHandle out; |
48 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out)); |
49 | TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out)); |
50 | if (!c->RankKnown(matrix_shape)) { |
51 | return errors::Internal("Sparse matrix has an unknown rank." ); |
52 | } |
53 | |
54 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(matrix_shape, -2), |
55 | c->Dim(matrix_shape, -1), matrix_dimension)); |
56 | return OkStatus(); |
57 | } |
58 | |
59 | REGISTER_OP("SparseTensorToCSRSparseMatrix" ) |
60 | .Input("indices: int64" ) |
61 | .Input("values: T" ) |
62 | .Input("dense_shape: int64" ) |
63 | .Attr("T: {float, double, complex64, complex128}" ) |
64 | .Output("sparse_matrix: variant" ) |
65 | .SetShapeFn([](InferenceContext* c) { |
66 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
67 | c, c->input(0), c->input(1), c->input(2))); |
68 | auto rank = c->Value(c->Dim(c->input(0), 1)); |
69 | ShapeHandle dense_shape; |
70 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &dense_shape)); |
71 | TF_RETURN_IF_ERROR(c->WithRank(dense_shape, rank, &dense_shape)); |
72 | if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || |
73 | c->Rank(dense_shape) > 3) { |
74 | return errors::InvalidArgument( |
75 | "Invalid rank: " , c->Rank(dense_shape), |
76 | ". Expected a known rank of either 2 or 3." ); |
77 | } |
78 | |
79 | DataType dtype; |
80 | TF_RETURN_IF_ERROR(c->GetAttr("T" , &dtype)); |
81 | c->set_output(0, c->Scalar()); |
82 | c->set_output_handle_shapes_and_types(0, |
83 | {ShapeAndType{dense_shape, dtype}}); |
84 | return OkStatus(); |
85 | }); |
86 | |
87 | REGISTER_OP("CSRSparseMatrixToSparseTensor" ) |
88 | .Input("sparse_matrix: variant" ) |
89 | .Output("indices: int64" ) |
90 | .Output("values: type" ) |
91 | .Output("dense_shape: int64" ) |
92 | .Attr("type: {float, double, complex64, complex128}" ) |
93 | .SetShapeFn([](InferenceContext* c) { |
94 | ShapeAndType sparse_matrix_shape_and_type; |
95 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
96 | ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape; |
97 | TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); |
98 | if (!c->RankKnown(sparse_matrix)) { |
99 | return errors::InvalidArgument("sparse_matrix has an unknown rank." ); |
100 | } |
101 | int rank = c->Rank(sparse_matrix); |
102 | ShapeHandle indices = c->Matrix(c->UnknownDim(), rank); |
103 | ShapeHandle values = c->Vector(c->UnknownDim()); |
104 | ShapeHandle dense_shape = c->Vector(rank); |
105 | c->set_output(0, indices); |
106 | c->set_output(1, values); |
107 | c->set_output(2, dense_shape); |
108 | return OkStatus(); |
109 | }); |
110 | |
111 | REGISTER_OP("DenseToCSRSparseMatrix" ) |
112 | .Input("dense_input: T" ) |
113 | .Input("indices: int64" ) |
114 | .Attr("T: {float, double, complex64, complex128}" ) |
115 | .Output("sparse_output: variant" ) |
116 | .SetShapeFn([](InferenceContext* c) { |
117 | ShapeHandle dense_shape = c->input(0); |
118 | if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || |
119 | c->Rank(dense_shape) > 3) { |
120 | return errors::InvalidArgument( |
121 | "Invalid rank of dense: " , c->Rank(dense_shape), |
122 | ". Expected a known rank of either 2 or 3." ); |
123 | } |
124 | auto rank = c->Rank(dense_shape); |
125 | |
126 | ShapeHandle indices = c->input(1); |
127 | if (!c->RankKnown(indices) || c->Rank(indices) != 2) { |
128 | return errors::InvalidArgument( |
129 | "indices must be a matrix; but its rank is not 2: " , |
130 | c->Rank(indices)); |
131 | } |
132 | auto indices_col = c->Dim(indices, 1); |
133 | if (!c->ValueKnown(indices_col) || c->Value(indices_col) != rank) { |
134 | return errors::InvalidArgument( |
135 | "indices.shape[1] must match rank of dense; saw: " , |
136 | c->Value(indices_col), " vs. " , rank); |
137 | } |
138 | ShapeHandle fake_values_vec = c->Vector(c->Dim(indices, 0)); |
139 | ShapeHandle fake_shape_shape = c->Vector(rank); |
140 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
141 | c, indices /*indices_shape*/, fake_values_vec /*values_shape*/, |
142 | fake_shape_shape /*shape_shape*/)); |
143 | DataType dtype; |
144 | TF_RETURN_IF_ERROR(c->GetAttr("T" , &dtype)); |
145 | c->set_output_handle_shapes_and_types(0, |
146 | {ShapeAndType{dense_shape, dtype}}); |
147 | c->set_output(0, c->Scalar()); |
148 | return OkStatus(); |
149 | }); |
150 | |
151 | REGISTER_OP("CSRSparseMatrixToDense" ) |
152 | .Input("sparse_input: variant" ) |
153 | .Output("dense_output: type" ) |
154 | .Attr("type: {float, double, complex64, complex128}" ) |
155 | .SetShapeFn([](InferenceContext* c) { |
156 | ShapeAndType sparse_matrix_shape_and_type; |
157 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
158 | ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape; |
159 | TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); |
160 | if (!c->RankKnown(sparse_matrix)) { |
161 | return errors::InvalidArgument("sparse_matrix has an unknown rank." ); |
162 | } |
163 | c->set_output(0, sparse_matrix); |
164 | return OkStatus(); |
165 | }); |
166 | |
167 | REGISTER_OP("CSRSparseMatrixComponents" ) |
168 | .Input("csr_sparse_matrix: variant" ) |
169 | .Input("index: int32" ) |
170 | .Output("row_ptrs: int32" ) |
171 | .Output("col_inds: int32" ) |
172 | .Output("values: type" ) |
173 | .Attr("type: {float, double, complex64, complex128}" ) |
174 | .SetShapeFn([](InferenceContext* c) { |
175 | ShapeAndType sparse_matrix_shape_and_type; |
176 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
177 | ShapeHandle csr_sparse_matrix = sparse_matrix_shape_and_type.shape; |
178 | TF_RETURN_IF_ERROR( |
179 | c->WithRankAtLeast(csr_sparse_matrix, 2, &csr_sparse_matrix)); |
180 | TF_RETURN_IF_ERROR( |
181 | c->WithRankAtMost(csr_sparse_matrix, 3, &csr_sparse_matrix)); |
182 | ShapeHandle index; |
183 | if (c->Rank(c->input(1)) != 0) { |
184 | return errors::InvalidArgument("index must be a scalar." ); |
185 | } |
186 | if (!c->RankKnown(csr_sparse_matrix)) { |
187 | return errors::InvalidArgument( |
188 | "csr_sparse_matrix has an unknown rank." ); |
189 | } |
190 | auto row_ptrs_dh = c->Dim(csr_sparse_matrix, -2); |
191 | TF_RETURN_IF_ERROR(c->Add(row_ptrs_dh, 1, &row_ptrs_dh)); |
192 | ShapeHandle row_ptrs = c->Vector(row_ptrs_dh); |
193 | c->set_output(0, row_ptrs); |
194 | c->set_output(1, c->Vector(c->UnknownDim())); |
195 | c->set_output(2, c->Vector(c->UnknownDim())); |
196 | return OkStatus(); |
197 | }); |
198 | |
199 | REGISTER_OP("SparseMatrixNNZ" ) |
200 | .Input("sparse_matrix: variant" ) |
201 | .Output("nnz: int32" ) |
202 | .SetShapeFn([](InferenceContext* c) { |
203 | ShapeAndType sparse_matrix_shape_and_type; |
204 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
205 | ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape; |
206 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(sparse_matrix, 2, &sparse_matrix)); |
207 | TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); |
208 | if (!c->RankKnown(sparse_matrix)) { |
209 | return errors::InvalidArgument("sparse_matrix has an unknown rank." ); |
210 | } |
211 | ShapeHandle out; |
212 | if (c->Rank(sparse_matrix) == 3) { |
213 | out = c->Vector(c->Dim(sparse_matrix, 0)); |
214 | } else { |
215 | out = c->Scalar(); |
216 | } |
217 | c->set_output(0, out); |
218 | return OkStatus(); |
219 | }); |
220 | |
221 | REGISTER_OP("SparseMatrixMatMul" ) |
222 | .Input("a: variant" ) |
223 | .Input("b: T" ) |
224 | .Attr("T: type" ) |
225 | .Attr("transpose_a: bool = false" ) |
226 | .Attr("transpose_b: bool = false" ) |
227 | .Attr("adjoint_a: bool = false" ) |
228 | .Attr("adjoint_b: bool = false" ) |
229 | .Attr("transpose_output: bool = false" ) |
230 | .Attr("conjugate_output: bool = false" ) |
231 | .Output("output: T" ) |
232 | .SetShapeFn([](InferenceContext* c) { |
233 | ShapeAndType sparse_matrix_shape_and_type; |
234 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
235 | ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; |
236 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); |
237 | TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); |
238 | if (!c->RankKnown(a_shape)) { |
239 | return errors::Internal("a has an unknown rank." ); |
240 | } |
241 | ShapeHandle b_shape; |
242 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); |
243 | TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape)); |
244 | |
245 | bool transpose_a = false; |
246 | bool transpose_b = false; |
247 | bool transpose_output = false; |
248 | |
249 | // TODO(ebrevdo): Add transpose support. |
250 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_a" , &transpose_a)); |
251 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_b" , &transpose_b)); |
252 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_output" , &transpose_output)); |
253 | |
254 | bool adjoint_a = false; |
255 | bool adjoint_b = false; |
256 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a" , &adjoint_a)); |
257 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b" , &adjoint_b)); |
258 | if (adjoint_a && transpose_a) { |
259 | return errors::InvalidArgument( |
260 | "Only one of adjoint_a and transpose_a may be true." ); |
261 | } |
262 | if (adjoint_b && transpose_b) { |
263 | return errors::InvalidArgument( |
264 | "Only one of adjoint_b and transpose_b may be true." ); |
265 | } |
266 | transpose_a = transpose_a || adjoint_a; |
267 | transpose_b = transpose_b || adjoint_b; |
268 | |
269 | auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2); |
270 | auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1); |
271 | if (transpose_output) { |
272 | std::tie(output_rows, output_cols) = |
273 | std::make_tuple(output_cols, output_rows); |
274 | } |
275 | |
276 | // Batch dims match between inputs. |
277 | ShapeHandle a_batch_dims; |
278 | ShapeHandle b_batch_dims; |
279 | ShapeHandle batch_dims; |
280 | TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); |
281 | TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); |
282 | TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); |
283 | |
284 | // Assert inner dims match. |
285 | shape_inference::DimensionHandle unused; |
286 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1), |
287 | c->Dim(b_shape, transpose_b ? -1 : -2), |
288 | &unused)); |
289 | |
290 | ShapeHandle out; |
291 | TF_RETURN_IF_ERROR(c->Concatenate( |
292 | batch_dims, c->Matrix(output_rows, output_cols), &out)); |
293 | |
294 | c->set_output(0, out); |
295 | return OkStatus(); |
296 | }); |
297 | |
298 | REGISTER_OP("SparseMatrixMul" ) |
299 | .Input("a: variant" ) |
300 | .Input("b: T" ) |
301 | .Attr("T: type" ) |
302 | .Output("output: variant" ) |
303 | .SetShapeFn([](InferenceContext* c) { |
304 | ShapeAndType sparse_matrix_shape_and_type; |
305 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
306 | ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; |
307 | TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); |
308 | if (!c->RankKnown(a_shape)) { |
309 | return errors::Internal("a has an unknown rank." ); |
310 | } |
311 | ShapeHandle b_shape; |
312 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 3, &b_shape)); |
313 | if (!c->RankKnown(b_shape)) { |
314 | return errors::Internal("b has an unknown rank." ); |
315 | } |
316 | ShapeHandle out; |
317 | if (c->Rank(b_shape) == 0) { |
318 | out = a_shape; |
319 | } else if (c->Rank(b_shape) == 3) { |
320 | if (c->Rank(a_shape) != 3) { |
321 | return errors::Unimplemented("rank of b is 3 but rank of a is not." ); |
322 | } |
323 | if (!(c->Value(c->Dim(b_shape, 1)) == 1 && |
324 | c->Value(c->Dim(b_shape, 2)) == 1)) { |
325 | return errors::Unimplemented( |
326 | "b must be a scalar or shaped [batch_size, 1, 1]" ); |
327 | } |
328 | DimensionHandle batch_size = c->Dim(a_shape, 0); |
329 | TF_RETURN_IF_ERROR( |
330 | c->Merge(batch_size, c->Dim(b_shape, 0), &batch_size)); |
331 | TF_RETURN_IF_ERROR(c->ReplaceDim(b_shape, 0, batch_size, &b_shape)); |
332 | TF_RETURN_IF_ERROR(c->ReplaceDim(a_shape, 0, batch_size, &a_shape)); |
333 | out = a_shape; |
334 | } else { |
335 | return errors::Unimplemented( |
336 | "b must be a scalar or shaped [batch_size, 1, 1]" ); |
337 | } |
338 | c->set_output_handle_shapes_and_types( |
339 | 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}}); |
340 | c->set_output(0, c->Scalar()); |
341 | return OkStatus(); |
342 | }); |
343 | |
344 | REGISTER_OP("SparseMatrixAdd" ) |
345 | .Input("a: variant" ) |
346 | .Input("b: variant" ) |
347 | .Input("alpha: T" ) |
348 | .Input("beta: T" ) |
349 | .Attr("T: {float, double, complex64, complex128}" ) |
350 | .Output("c: variant" ) |
351 | .SetShapeFn([](InferenceContext* c) { |
352 | // alpha and beta are scalars. |
353 | ShapeHandle unused_scalar_shape; |
354 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_scalar_shape)); |
355 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_scalar_shape)); |
356 | |
357 | ShapeAndType sparse_matrix_shape_and_type; |
358 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
359 | ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; |
360 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); |
361 | TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); |
362 | if (!c->RankKnown(a_shape)) { |
363 | return errors::InvalidArgument("a has an unknown rank." ); |
364 | } |
365 | |
366 | TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); |
367 | ShapeHandle b_shape = sparse_matrix_shape_and_type.shape; |
368 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape)); |
369 | TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape)); |
370 | if (!c->RankKnown(b_shape)) { |
371 | return errors::InvalidArgument("b has an unknown rank." ); |
372 | } |
373 | ShapeHandle out; |
374 | TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &out)); |
375 | c->set_output_handle_shapes_and_types( |
376 | 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}}); |
377 | c->set_output(0, c->Scalar()); |
378 | return OkStatus(); |
379 | }); |
380 | |
381 | REGISTER_OP("SparseMatrixSparseMatMul" ) |
382 | .Input("a: variant" ) |
383 | .Input("b: variant" ) |
384 | .Attr("type: {float, double, complex64, complex128}" ) |
385 | .Attr("transpose_a: bool = false" ) |
386 | .Attr("transpose_b: bool = false" ) |
387 | .Attr("adjoint_a: bool = false" ) |
388 | .Attr("adjoint_b: bool = false" ) |
389 | .Output("c: variant" ) |
390 | .SetShapeFn([](InferenceContext* c) { |
391 | ShapeAndType sparse_matrix_shape_and_type; |
392 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
393 | ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; |
394 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); |
395 | TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); |
396 | if (!c->RankKnown(a_shape)) { |
397 | return errors::Internal("a has an unknown rank." ); |
398 | } |
399 | |
400 | TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); |
401 | ShapeHandle b_shape = sparse_matrix_shape_and_type.shape; |
402 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape)); |
403 | TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape)); |
404 | if (!c->RankKnown(b_shape)) { |
405 | return errors::Internal("b has an unknown rank." ); |
406 | } |
407 | |
408 | bool transpose_a = false; |
409 | bool transpose_b = false; |
410 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_a" , &transpose_a)); |
411 | TF_RETURN_IF_ERROR(c->GetAttr("transpose_b" , &transpose_b)); |
412 | bool adjoint_a = false; |
413 | bool adjoint_b = false; |
414 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a" , &adjoint_a)); |
415 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b" , &adjoint_b)); |
416 | if (adjoint_a && transpose_a) { |
417 | return errors::InvalidArgument( |
418 | "Only one of adjoint_a and transpose_a may be true." ); |
419 | } else if (adjoint_b && transpose_b) { |
420 | return errors::InvalidArgument( |
421 | "Only one of adjoint_b and transpose_b may be true." ); |
422 | } |
423 | transpose_a = transpose_a || adjoint_a; |
424 | transpose_b = transpose_b || adjoint_b; |
425 | |
426 | auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2); |
427 | auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1); |
428 | |
429 | // Batch dims match between inputs. |
430 | ShapeHandle a_batch_dims; |
431 | ShapeHandle b_batch_dims; |
432 | ShapeHandle batch_dims; |
433 | TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); |
434 | TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); |
435 | TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); |
436 | |
437 | // Assert inner dims match. |
438 | shape_inference::DimensionHandle unused; |
439 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1), |
440 | c->Dim(b_shape, transpose_b ? -1 : -2), |
441 | &unused)); |
442 | |
443 | ShapeHandle out; |
444 | TF_RETURN_IF_ERROR(c->Concatenate( |
445 | batch_dims, c->Matrix(output_rows, output_cols), &out)); |
446 | |
447 | c->set_output_handle_shapes_and_types( |
448 | 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}}); |
449 | c->set_output(0, c->Scalar()); |
450 | return OkStatus(); |
451 | }); |
452 | |
453 | REGISTER_OP("SparseMatrixZeros" ) |
454 | .Input("dense_shape: int64" ) |
455 | .Attr("type: {float, double, complex64, complex128}" ) |
456 | .Output("sparse_matrix: variant" ) |
457 | .SetShapeFn([](InferenceContext* c) { |
458 | auto rank = c->NumElements(c->input(0)); |
459 | ShapeHandle dense_shape; |
460 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &dense_shape)); |
461 | TF_RETURN_IF_ERROR( |
462 | c->WithRank(dense_shape, c->Value(rank), &dense_shape)); |
463 | if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || |
464 | c->Rank(dense_shape) > 3) { |
465 | return errors::InvalidArgument( |
466 | "Invalid rank: " , c->Rank(dense_shape), |
467 | ". Expected a known rank of either 2 or 3." ); |
468 | } |
469 | DataType dtype; |
470 | TF_RETURN_IF_ERROR(c->GetAttr("type" , &dtype)); |
471 | c->set_output_handle_shapes_and_types(0, |
472 | {ShapeAndType{dense_shape, dtype}}); |
473 | c->set_output(0, c->Scalar()); |
474 | return OkStatus(); |
475 | }); |
476 | |
477 | REGISTER_OP("SparseMatrixTranspose" ) |
478 | .Input("input: variant" ) |
479 | .Attr("conjugate: bool = false" ) |
480 | .Attr("type: {float, double, complex64, complex128}" ) |
481 | .Output("output: variant" ) |
482 | .SetShapeFn([](InferenceContext* c) { |
483 | ShapeAndType sparse_matrix_shape_and_type; |
484 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
485 | ShapeHandle input = sparse_matrix_shape_and_type.shape; |
486 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input)); |
487 | TF_RETURN_IF_ERROR(c->WithRankAtMost(input, 3, &input)); |
488 | if (!c->RankKnown(input)) { |
489 | return errors::InvalidArgument("input has an unknown rank." ); |
490 | } |
491 | ShapeHandle output; |
492 | if (c->Rank(input) == 2) { |
493 | output = c->Matrix(c->Dim(input, 1), c->Dim(input, 0)); |
494 | } else { |
495 | output = c->MakeShape( |
496 | {c->Dim(input, 0), c->Dim(input, 2), c->Dim(input, 1)}); |
497 | } |
498 | c->set_output_handle_shapes_and_types( |
499 | 0, {ShapeAndType{output, sparse_matrix_shape_and_type.dtype}}); |
500 | c->set_output(0, c->Scalar()); |
501 | |
502 | return OkStatus(); |
503 | }); |
504 | |
505 | REGISTER_OP("SparseMatrixSoftmax" ) |
506 | .Input("logits: variant" ) |
507 | .Attr("type: {float, double}" ) |
508 | .Output("softmax: variant" ) |
509 | .SetShapeFn([](InferenceContext* c) { |
510 | ShapeAndType sparse_matrix_shape_and_type; |
511 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
512 | ShapeHandle logits = sparse_matrix_shape_and_type.shape; |
513 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(logits, 2, &logits)); |
514 | TF_RETURN_IF_ERROR(c->WithRankAtMost(logits, 3, &logits)); |
515 | if (!c->RankKnown(logits)) { |
516 | return errors::InvalidArgument("logits has an unknown rank." ); |
517 | } |
518 | c->set_output_handle_shapes_and_types( |
519 | 0, {ShapeAndType{logits, sparse_matrix_shape_and_type.dtype}}); |
520 | c->set_output(0, c->Scalar()); |
521 | return OkStatus(); |
522 | }); |
523 | |
524 | REGISTER_OP("SparseMatrixSoftmaxGrad" ) |
525 | .Input("softmax: variant" ) |
526 | .Input("grad_softmax: variant" ) |
527 | .Attr("type: {float, double}" ) |
528 | .Output("gradient: variant" ) |
529 | .SetShapeFn([](InferenceContext* c) { |
530 | ShapeAndType sparse_matrix_shape_and_type; |
531 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
532 | ShapeHandle softmax = sparse_matrix_shape_and_type.shape; |
533 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(softmax, 2, &softmax)); |
534 | TF_RETURN_IF_ERROR(c->WithRankAtMost(softmax, 3, &softmax)); |
535 | if (!c->RankKnown(softmax)) { |
536 | return errors::InvalidArgument("softmax has an unknown rank." ); |
537 | } |
538 | TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); |
539 | ShapeHandle grad_softmax = sparse_matrix_shape_and_type.shape; |
540 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(grad_softmax, 2, &grad_softmax)); |
541 | TF_RETURN_IF_ERROR(c->WithRankAtMost(grad_softmax, 3, &grad_softmax)); |
542 | if (!c->RankKnown(grad_softmax)) { |
543 | return errors::InvalidArgument("grad_softmax has an unknown rank." ); |
544 | } |
545 | TF_RETURN_IF_ERROR(c->Merge(softmax, grad_softmax, &softmax)); |
546 | c->set_output_handle_shapes_and_types( |
547 | 0, {ShapeAndType{softmax, sparse_matrix_shape_and_type.dtype}}); |
548 | c->set_output(0, c->Scalar()); |
549 | return OkStatus(); |
550 | }); |
551 | |
552 | REGISTER_OP("SparseMatrixOrderingAMD" ) |
553 | .Input("input: variant" ) |
554 | .Output("output: int32" ) |
555 | .SetShapeFn([](InferenceContext* c) { |
556 | ShapeAndType sparse_matrix_shape_and_type; |
557 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
558 | ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape; |
559 | DimensionHandle n; |
560 | TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n)); |
561 | |
562 | ShapeHandle output; |
563 | if (c->Rank(matrix_shape) == 2) { |
564 | output = c->Vector(c->Dim(matrix_shape, 0)); |
565 | } else { |
566 | output = c->Matrix(c->Dim(matrix_shape, 0), c->Dim(matrix_shape, 1)); |
567 | } |
568 | c->set_output(0, output); |
569 | return OkStatus(); |
570 | }); |
571 | |
572 | REGISTER_OP("SparseMatrixSparseCholesky" ) |
573 | .Input("input: variant" ) |
574 | .Input("permutation: int32" ) |
575 | .Attr("type: {float, double, complex64, complex128}" ) |
576 | .Output("output: variant" ) |
577 | .SetShapeFn([](InferenceContext* c) { |
578 | ShapeAndType sparse_matrix_shape_and_type; |
579 | TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); |
580 | ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape; |
581 | DimensionHandle n; |
582 | TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n)); |
583 | |
584 | ShapeHandle perm_shape; |
585 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &perm_shape)); |
586 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &perm_shape)); |
587 | if (!c->RankKnown(perm_shape)) { |
588 | return errors::Internal("permutation has an unknown rank." ); |
589 | } |
590 | |
591 | // Each batch component of permutation must have the same number of |
592 | // elements as number of rows of sparse_matrix. |
593 | TF_RETURN_IF_ERROR(c->Merge(n, c->Dim(perm_shape, -1), &n)); |
594 | ShapeHandle matrix_batch_shape; |
595 | ShapeHandle perm_batch_shape; |
596 | |
597 | // Make the common batch subshape. |
598 | TF_RETURN_IF_ERROR(c->Subshape(matrix_shape, 0, -2, &matrix_batch_shape)); |
599 | TF_RETURN_IF_ERROR(c->Subshape(perm_shape, 0, -1, &perm_shape)); |
600 | // Make sure the batch dimensions match between sparse_matrix and |
601 | // permutation. |
602 | TF_RETURN_IF_ERROR( |
603 | c->Merge(matrix_batch_shape, perm_batch_shape, &matrix_batch_shape)); |
604 | |
605 | ShapeHandle out = matrix_shape; |
606 | c->set_output_handle_shapes_and_types( |
607 | 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}}); |
608 | c->set_output(0, c->Scalar()); |
609 | |
610 | return OkStatus(); |
611 | }); |
612 | |
613 | } // namespace tensorflow |
614 | |