1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/numeric_op.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/shape_inference.h"
20#include "tensorflow/core/lib/core/errors.h"
21
22namespace tensorflow {
23
24using shape_inference::DimensionHandle;
25using shape_inference::InferenceContext;
26using shape_inference::ShapeAndType;
27using shape_inference::ShapeHandle;
28
29Status GetVariantInput(InferenceContext* c, int index,
30 ShapeAndType* shape_and_type) {
31 ShapeHandle variant;
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant));
33 auto* shapes_and_types = c->input_handle_shapes_and_types(index);
34 if (shapes_and_types == nullptr || shapes_and_types->size() != 1) {
35 return errors::InvalidArgument(
36 "Unable to access shape and type info from variant input ", index);
37 }
38 *shape_and_type = shapes_and_types->at(0);
39 return OkStatus();
40}
41
42// Validates that a shape represents a (rank-2) square matrix or a (rank-3)
43// batch of square matrices.
44Status ValidateSquareMatrixShape(InferenceContext* c,
45 const ShapeHandle& matrix_shape,
46 DimensionHandle* matrix_dimension) {
47 ShapeHandle out;
48 TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out));
49 TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out));
50 if (!c->RankKnown(matrix_shape)) {
51 return errors::Internal("Sparse matrix has an unknown rank.");
52 }
53
54 TF_RETURN_IF_ERROR(c->Merge(c->Dim(matrix_shape, -2),
55 c->Dim(matrix_shape, -1), matrix_dimension));
56 return OkStatus();
57}
58
59REGISTER_OP("SparseTensorToCSRSparseMatrix")
60 .Input("indices: int64")
61 .Input("values: T")
62 .Input("dense_shape: int64")
63 .Attr("T: {float, double, complex64, complex128}")
64 .Output("sparse_matrix: variant")
65 .SetShapeFn([](InferenceContext* c) {
66 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
67 c, c->input(0), c->input(1), c->input(2)));
68 auto rank = c->Value(c->Dim(c->input(0), 1));
69 ShapeHandle dense_shape;
70 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &dense_shape));
71 TF_RETURN_IF_ERROR(c->WithRank(dense_shape, rank, &dense_shape));
72 if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
73 c->Rank(dense_shape) > 3) {
74 return errors::InvalidArgument(
75 "Invalid rank: ", c->Rank(dense_shape),
76 ". Expected a known rank of either 2 or 3.");
77 }
78
79 DataType dtype;
80 TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
81 c->set_output(0, c->Scalar());
82 c->set_output_handle_shapes_and_types(0,
83 {ShapeAndType{dense_shape, dtype}});
84 return OkStatus();
85 });
86
87REGISTER_OP("CSRSparseMatrixToSparseTensor")
88 .Input("sparse_matrix: variant")
89 .Output("indices: int64")
90 .Output("values: type")
91 .Output("dense_shape: int64")
92 .Attr("type: {float, double, complex64, complex128}")
93 .SetShapeFn([](InferenceContext* c) {
94 ShapeAndType sparse_matrix_shape_and_type;
95 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
96 ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
97 TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
98 if (!c->RankKnown(sparse_matrix)) {
99 return errors::InvalidArgument("sparse_matrix has an unknown rank.");
100 }
101 int rank = c->Rank(sparse_matrix);
102 ShapeHandle indices = c->Matrix(c->UnknownDim(), rank);
103 ShapeHandle values = c->Vector(c->UnknownDim());
104 ShapeHandle dense_shape = c->Vector(rank);
105 c->set_output(0, indices);
106 c->set_output(1, values);
107 c->set_output(2, dense_shape);
108 return OkStatus();
109 });
110
111REGISTER_OP("DenseToCSRSparseMatrix")
112 .Input("dense_input: T")
113 .Input("indices: int64")
114 .Attr("T: {float, double, complex64, complex128}")
115 .Output("sparse_output: variant")
116 .SetShapeFn([](InferenceContext* c) {
117 ShapeHandle dense_shape = c->input(0);
118 if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
119 c->Rank(dense_shape) > 3) {
120 return errors::InvalidArgument(
121 "Invalid rank of dense: ", c->Rank(dense_shape),
122 ". Expected a known rank of either 2 or 3.");
123 }
124 auto rank = c->Rank(dense_shape);
125
126 ShapeHandle indices = c->input(1);
127 if (!c->RankKnown(indices) || c->Rank(indices) != 2) {
128 return errors::InvalidArgument(
129 "indices must be a matrix; but its rank is not 2: ",
130 c->Rank(indices));
131 }
132 auto indices_col = c->Dim(indices, 1);
133 if (!c->ValueKnown(indices_col) || c->Value(indices_col) != rank) {
134 return errors::InvalidArgument(
135 "indices.shape[1] must match rank of dense; saw: ",
136 c->Value(indices_col), " vs. ", rank);
137 }
138 ShapeHandle fake_values_vec = c->Vector(c->Dim(indices, 0));
139 ShapeHandle fake_shape_shape = c->Vector(rank);
140 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
141 c, indices /*indices_shape*/, fake_values_vec /*values_shape*/,
142 fake_shape_shape /*shape_shape*/));
143 DataType dtype;
144 TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
145 c->set_output_handle_shapes_and_types(0,
146 {ShapeAndType{dense_shape, dtype}});
147 c->set_output(0, c->Scalar());
148 return OkStatus();
149 });
150
151REGISTER_OP("CSRSparseMatrixToDense")
152 .Input("sparse_input: variant")
153 .Output("dense_output: type")
154 .Attr("type: {float, double, complex64, complex128}")
155 .SetShapeFn([](InferenceContext* c) {
156 ShapeAndType sparse_matrix_shape_and_type;
157 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
158 ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
159 TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
160 if (!c->RankKnown(sparse_matrix)) {
161 return errors::InvalidArgument("sparse_matrix has an unknown rank.");
162 }
163 c->set_output(0, sparse_matrix);
164 return OkStatus();
165 });
166
167REGISTER_OP("CSRSparseMatrixComponents")
168 .Input("csr_sparse_matrix: variant")
169 .Input("index: int32")
170 .Output("row_ptrs: int32")
171 .Output("col_inds: int32")
172 .Output("values: type")
173 .Attr("type: {float, double, complex64, complex128}")
174 .SetShapeFn([](InferenceContext* c) {
175 ShapeAndType sparse_matrix_shape_and_type;
176 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
177 ShapeHandle csr_sparse_matrix = sparse_matrix_shape_and_type.shape;
178 TF_RETURN_IF_ERROR(
179 c->WithRankAtLeast(csr_sparse_matrix, 2, &csr_sparse_matrix));
180 TF_RETURN_IF_ERROR(
181 c->WithRankAtMost(csr_sparse_matrix, 3, &csr_sparse_matrix));
182 ShapeHandle index;
183 if (c->Rank(c->input(1)) != 0) {
184 return errors::InvalidArgument("index must be a scalar.");
185 }
186 if (!c->RankKnown(csr_sparse_matrix)) {
187 return errors::InvalidArgument(
188 "csr_sparse_matrix has an unknown rank.");
189 }
190 auto row_ptrs_dh = c->Dim(csr_sparse_matrix, -2);
191 TF_RETURN_IF_ERROR(c->Add(row_ptrs_dh, 1, &row_ptrs_dh));
192 ShapeHandle row_ptrs = c->Vector(row_ptrs_dh);
193 c->set_output(0, row_ptrs);
194 c->set_output(1, c->Vector(c->UnknownDim()));
195 c->set_output(2, c->Vector(c->UnknownDim()));
196 return OkStatus();
197 });
198
199REGISTER_OP("SparseMatrixNNZ")
200 .Input("sparse_matrix: variant")
201 .Output("nnz: int32")
202 .SetShapeFn([](InferenceContext* c) {
203 ShapeAndType sparse_matrix_shape_and_type;
204 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
205 ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
206 TF_RETURN_IF_ERROR(c->WithRankAtLeast(sparse_matrix, 2, &sparse_matrix));
207 TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
208 if (!c->RankKnown(sparse_matrix)) {
209 return errors::InvalidArgument("sparse_matrix has an unknown rank.");
210 }
211 ShapeHandle out;
212 if (c->Rank(sparse_matrix) == 3) {
213 out = c->Vector(c->Dim(sparse_matrix, 0));
214 } else {
215 out = c->Scalar();
216 }
217 c->set_output(0, out);
218 return OkStatus();
219 });
220
221REGISTER_OP("SparseMatrixMatMul")
222 .Input("a: variant")
223 .Input("b: T")
224 .Attr("T: type")
225 .Attr("transpose_a: bool = false")
226 .Attr("transpose_b: bool = false")
227 .Attr("adjoint_a: bool = false")
228 .Attr("adjoint_b: bool = false")
229 .Attr("transpose_output: bool = false")
230 .Attr("conjugate_output: bool = false")
231 .Output("output: T")
232 .SetShapeFn([](InferenceContext* c) {
233 ShapeAndType sparse_matrix_shape_and_type;
234 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
235 ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
236 TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
237 TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
238 if (!c->RankKnown(a_shape)) {
239 return errors::Internal("a has an unknown rank.");
240 }
241 ShapeHandle b_shape;
242 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
243 TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
244
245 bool transpose_a = false;
246 bool transpose_b = false;
247 bool transpose_output = false;
248
249 // TODO(ebrevdo): Add transpose support.
250 TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
251 TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
252 TF_RETURN_IF_ERROR(c->GetAttr("transpose_output", &transpose_output));
253
254 bool adjoint_a = false;
255 bool adjoint_b = false;
256 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
257 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
258 if (adjoint_a && transpose_a) {
259 return errors::InvalidArgument(
260 "Only one of adjoint_a and transpose_a may be true.");
261 }
262 if (adjoint_b && transpose_b) {
263 return errors::InvalidArgument(
264 "Only one of adjoint_b and transpose_b may be true.");
265 }
266 transpose_a = transpose_a || adjoint_a;
267 transpose_b = transpose_b || adjoint_b;
268
269 auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
270 auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
271 if (transpose_output) {
272 std::tie(output_rows, output_cols) =
273 std::make_tuple(output_cols, output_rows);
274 }
275
276 // Batch dims match between inputs.
277 ShapeHandle a_batch_dims;
278 ShapeHandle b_batch_dims;
279 ShapeHandle batch_dims;
280 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
281 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
282 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
283
284 // Assert inner dims match.
285 shape_inference::DimensionHandle unused;
286 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
287 c->Dim(b_shape, transpose_b ? -1 : -2),
288 &unused));
289
290 ShapeHandle out;
291 TF_RETURN_IF_ERROR(c->Concatenate(
292 batch_dims, c->Matrix(output_rows, output_cols), &out));
293
294 c->set_output(0, out);
295 return OkStatus();
296 });
297
298REGISTER_OP("SparseMatrixMul")
299 .Input("a: variant")
300 .Input("b: T")
301 .Attr("T: type")
302 .Output("output: variant")
303 .SetShapeFn([](InferenceContext* c) {
304 ShapeAndType sparse_matrix_shape_and_type;
305 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
306 ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
307 TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
308 if (!c->RankKnown(a_shape)) {
309 return errors::Internal("a has an unknown rank.");
310 }
311 ShapeHandle b_shape;
312 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 3, &b_shape));
313 if (!c->RankKnown(b_shape)) {
314 return errors::Internal("b has an unknown rank.");
315 }
316 ShapeHandle out;
317 if (c->Rank(b_shape) == 0) {
318 out = a_shape;
319 } else if (c->Rank(b_shape) == 3) {
320 if (c->Rank(a_shape) != 3) {
321 return errors::Unimplemented("rank of b is 3 but rank of a is not.");
322 }
323 if (!(c->Value(c->Dim(b_shape, 1)) == 1 &&
324 c->Value(c->Dim(b_shape, 2)) == 1)) {
325 return errors::Unimplemented(
326 "b must be a scalar or shaped [batch_size, 1, 1]");
327 }
328 DimensionHandle batch_size = c->Dim(a_shape, 0);
329 TF_RETURN_IF_ERROR(
330 c->Merge(batch_size, c->Dim(b_shape, 0), &batch_size));
331 TF_RETURN_IF_ERROR(c->ReplaceDim(b_shape, 0, batch_size, &b_shape));
332 TF_RETURN_IF_ERROR(c->ReplaceDim(a_shape, 0, batch_size, &a_shape));
333 out = a_shape;
334 } else {
335 return errors::Unimplemented(
336 "b must be a scalar or shaped [batch_size, 1, 1]");
337 }
338 c->set_output_handle_shapes_and_types(
339 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
340 c->set_output(0, c->Scalar());
341 return OkStatus();
342 });
343
344REGISTER_OP("SparseMatrixAdd")
345 .Input("a: variant")
346 .Input("b: variant")
347 .Input("alpha: T")
348 .Input("beta: T")
349 .Attr("T: {float, double, complex64, complex128}")
350 .Output("c: variant")
351 .SetShapeFn([](InferenceContext* c) {
352 // alpha and beta are scalars.
353 ShapeHandle unused_scalar_shape;
354 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_scalar_shape));
355 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_scalar_shape));
356
357 ShapeAndType sparse_matrix_shape_and_type;
358 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
359 ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
360 TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
361 TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
362 if (!c->RankKnown(a_shape)) {
363 return errors::InvalidArgument("a has an unknown rank.");
364 }
365
366 TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
367 ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
368 TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
369 TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
370 if (!c->RankKnown(b_shape)) {
371 return errors::InvalidArgument("b has an unknown rank.");
372 }
373 ShapeHandle out;
374 TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &out));
375 c->set_output_handle_shapes_and_types(
376 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
377 c->set_output(0, c->Scalar());
378 return OkStatus();
379 });
380
381REGISTER_OP("SparseMatrixSparseMatMul")
382 .Input("a: variant")
383 .Input("b: variant")
384 .Attr("type: {float, double, complex64, complex128}")
385 .Attr("transpose_a: bool = false")
386 .Attr("transpose_b: bool = false")
387 .Attr("adjoint_a: bool = false")
388 .Attr("adjoint_b: bool = false")
389 .Output("c: variant")
390 .SetShapeFn([](InferenceContext* c) {
391 ShapeAndType sparse_matrix_shape_and_type;
392 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
393 ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
394 TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
395 TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
396 if (!c->RankKnown(a_shape)) {
397 return errors::Internal("a has an unknown rank.");
398 }
399
400 TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
401 ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
402 TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
403 TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
404 if (!c->RankKnown(b_shape)) {
405 return errors::Internal("b has an unknown rank.");
406 }
407
408 bool transpose_a = false;
409 bool transpose_b = false;
410 TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
411 TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
412 bool adjoint_a = false;
413 bool adjoint_b = false;
414 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
415 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
416 if (adjoint_a && transpose_a) {
417 return errors::InvalidArgument(
418 "Only one of adjoint_a and transpose_a may be true.");
419 } else if (adjoint_b && transpose_b) {
420 return errors::InvalidArgument(
421 "Only one of adjoint_b and transpose_b may be true.");
422 }
423 transpose_a = transpose_a || adjoint_a;
424 transpose_b = transpose_b || adjoint_b;
425
426 auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
427 auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
428
429 // Batch dims match between inputs.
430 ShapeHandle a_batch_dims;
431 ShapeHandle b_batch_dims;
432 ShapeHandle batch_dims;
433 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
434 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
435 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
436
437 // Assert inner dims match.
438 shape_inference::DimensionHandle unused;
439 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
440 c->Dim(b_shape, transpose_b ? -1 : -2),
441 &unused));
442
443 ShapeHandle out;
444 TF_RETURN_IF_ERROR(c->Concatenate(
445 batch_dims, c->Matrix(output_rows, output_cols), &out));
446
447 c->set_output_handle_shapes_and_types(
448 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
449 c->set_output(0, c->Scalar());
450 return OkStatus();
451 });
452
453REGISTER_OP("SparseMatrixZeros")
454 .Input("dense_shape: int64")
455 .Attr("type: {float, double, complex64, complex128}")
456 .Output("sparse_matrix: variant")
457 .SetShapeFn([](InferenceContext* c) {
458 auto rank = c->NumElements(c->input(0));
459 ShapeHandle dense_shape;
460 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &dense_shape));
461 TF_RETURN_IF_ERROR(
462 c->WithRank(dense_shape, c->Value(rank), &dense_shape));
463 if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
464 c->Rank(dense_shape) > 3) {
465 return errors::InvalidArgument(
466 "Invalid rank: ", c->Rank(dense_shape),
467 ". Expected a known rank of either 2 or 3.");
468 }
469 DataType dtype;
470 TF_RETURN_IF_ERROR(c->GetAttr("type", &dtype));
471 c->set_output_handle_shapes_and_types(0,
472 {ShapeAndType{dense_shape, dtype}});
473 c->set_output(0, c->Scalar());
474 return OkStatus();
475 });
476
477REGISTER_OP("SparseMatrixTranspose")
478 .Input("input: variant")
479 .Attr("conjugate: bool = false")
480 .Attr("type: {float, double, complex64, complex128}")
481 .Output("output: variant")
482 .SetShapeFn([](InferenceContext* c) {
483 ShapeAndType sparse_matrix_shape_and_type;
484 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
485 ShapeHandle input = sparse_matrix_shape_and_type.shape;
486 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
487 TF_RETURN_IF_ERROR(c->WithRankAtMost(input, 3, &input));
488 if (!c->RankKnown(input)) {
489 return errors::InvalidArgument("input has an unknown rank.");
490 }
491 ShapeHandle output;
492 if (c->Rank(input) == 2) {
493 output = c->Matrix(c->Dim(input, 1), c->Dim(input, 0));
494 } else {
495 output = c->MakeShape(
496 {c->Dim(input, 0), c->Dim(input, 2), c->Dim(input, 1)});
497 }
498 c->set_output_handle_shapes_and_types(
499 0, {ShapeAndType{output, sparse_matrix_shape_and_type.dtype}});
500 c->set_output(0, c->Scalar());
501
502 return OkStatus();
503 });
504
505REGISTER_OP("SparseMatrixSoftmax")
506 .Input("logits: variant")
507 .Attr("type: {float, double}")
508 .Output("softmax: variant")
509 .SetShapeFn([](InferenceContext* c) {
510 ShapeAndType sparse_matrix_shape_and_type;
511 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
512 ShapeHandle logits = sparse_matrix_shape_and_type.shape;
513 TF_RETURN_IF_ERROR(c->WithRankAtLeast(logits, 2, &logits));
514 TF_RETURN_IF_ERROR(c->WithRankAtMost(logits, 3, &logits));
515 if (!c->RankKnown(logits)) {
516 return errors::InvalidArgument("logits has an unknown rank.");
517 }
518 c->set_output_handle_shapes_and_types(
519 0, {ShapeAndType{logits, sparse_matrix_shape_and_type.dtype}});
520 c->set_output(0, c->Scalar());
521 return OkStatus();
522 });
523
524REGISTER_OP("SparseMatrixSoftmaxGrad")
525 .Input("softmax: variant")
526 .Input("grad_softmax: variant")
527 .Attr("type: {float, double}")
528 .Output("gradient: variant")
529 .SetShapeFn([](InferenceContext* c) {
530 ShapeAndType sparse_matrix_shape_and_type;
531 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
532 ShapeHandle softmax = sparse_matrix_shape_and_type.shape;
533 TF_RETURN_IF_ERROR(c->WithRankAtLeast(softmax, 2, &softmax));
534 TF_RETURN_IF_ERROR(c->WithRankAtMost(softmax, 3, &softmax));
535 if (!c->RankKnown(softmax)) {
536 return errors::InvalidArgument("softmax has an unknown rank.");
537 }
538 TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
539 ShapeHandle grad_softmax = sparse_matrix_shape_and_type.shape;
540 TF_RETURN_IF_ERROR(c->WithRankAtLeast(grad_softmax, 2, &grad_softmax));
541 TF_RETURN_IF_ERROR(c->WithRankAtMost(grad_softmax, 3, &grad_softmax));
542 if (!c->RankKnown(grad_softmax)) {
543 return errors::InvalidArgument("grad_softmax has an unknown rank.");
544 }
545 TF_RETURN_IF_ERROR(c->Merge(softmax, grad_softmax, &softmax));
546 c->set_output_handle_shapes_and_types(
547 0, {ShapeAndType{softmax, sparse_matrix_shape_and_type.dtype}});
548 c->set_output(0, c->Scalar());
549 return OkStatus();
550 });
551
552REGISTER_OP("SparseMatrixOrderingAMD")
553 .Input("input: variant")
554 .Output("output: int32")
555 .SetShapeFn([](InferenceContext* c) {
556 ShapeAndType sparse_matrix_shape_and_type;
557 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
558 ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
559 DimensionHandle n;
560 TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
561
562 ShapeHandle output;
563 if (c->Rank(matrix_shape) == 2) {
564 output = c->Vector(c->Dim(matrix_shape, 0));
565 } else {
566 output = c->Matrix(c->Dim(matrix_shape, 0), c->Dim(matrix_shape, 1));
567 }
568 c->set_output(0, output);
569 return OkStatus();
570 });
571
572REGISTER_OP("SparseMatrixSparseCholesky")
573 .Input("input: variant")
574 .Input("permutation: int32")
575 .Attr("type: {float, double, complex64, complex128}")
576 .Output("output: variant")
577 .SetShapeFn([](InferenceContext* c) {
578 ShapeAndType sparse_matrix_shape_and_type;
579 TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
580 ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
581 DimensionHandle n;
582 TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
583
584 ShapeHandle perm_shape;
585 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &perm_shape));
586 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &perm_shape));
587 if (!c->RankKnown(perm_shape)) {
588 return errors::Internal("permutation has an unknown rank.");
589 }
590
591 // Each batch component of permutation must have the same number of
592 // elements as number of rows of sparse_matrix.
593 TF_RETURN_IF_ERROR(c->Merge(n, c->Dim(perm_shape, -1), &n));
594 ShapeHandle matrix_batch_shape;
595 ShapeHandle perm_batch_shape;
596
597 // Make the common batch subshape.
598 TF_RETURN_IF_ERROR(c->Subshape(matrix_shape, 0, -2, &matrix_batch_shape));
599 TF_RETURN_IF_ERROR(c->Subshape(perm_shape, 0, -1, &perm_shape));
600 // Make sure the batch dimensions match between sparse_matrix and
601 // permutation.
602 TF_RETURN_IF_ERROR(
603 c->Merge(matrix_batch_shape, perm_batch_shape, &matrix_batch_shape));
604
605 ShapeHandle out = matrix_shape;
606 c->set_output_handle_shapes_and_types(
607 0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
608 c->set_output(0, c->Scalar());
609
610 return OkStatus();
611 });
612
613} // namespace tensorflow
614