/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

// Return in <out> the result of making the end of <s> a square matrix.
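// For example, an input of shape [2, 3, 5, 5] has batch shape [2, 3] and
// yields <out> = [2, 3, 5, 5], while an input of shape [2, 4, 5] fails the
// Merge of its two trailing dims (shapes here are illustrative).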
Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
                             ShapeHandle* out) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));

  DimensionHandle d;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
  return OkStatus();
}

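// Shape function for ops that take one batched square-matrix input and
// produce an output of the same shape (e.g. Cholesky and MatrixInverse
// below).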
Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
  c->set_output(0, out);
  return OkStatus();
}

// The first input is [...,K,M] and second input is [...,M,N].
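// For example, an lhs of shape [2, 3, 10] holds K = 3 bands of a 10 x 10
// triangular matrix per batch; with an rhs of shape [2, 10, 4] the output
// is [2, 10, 4] (shapes here are illustrative).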
Status BandedTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;

  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Check K > 0.
  DimensionHandle num_bands = c->Dim(lhs, -2);
  DimensionHandle m = c->Dim(lhs, -1);
  if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) {
    return errors::InvalidArgument("Number of bands must be positive, but is ",
                                   c->Value(num_bands));
  }
  if (c->ValueKnown(num_bands) && c->ValueKnown(m) &&
      c->Value(num_bands) > c->Value(m)) {
    return errors::InvalidArgument("Number of bands ", c->Value(num_bands),
                                   " cannot exceed the size of the matrix ",
                                   c->Value(m));
  }

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));

  // lhs and rhs have the same value for M to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m));

  // Build final shape (batch_shape + m + n) in <out>.
  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));

  c->set_output(0, out);
  return OkStatus();
}

// The first input is [...,M,N] and second input is [...,M,K].
// Output is [...,N,K]. If <square>, then input is [...,M,M].
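// For example, with square = true, an lhs of shape [3, 5, 5] and an rhs of
// shape [3, 5, 2] give an output of [3, 5, 2]; note the batch dims are
// Merge()d here, not broadcast (shapes here are illustrative).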
Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  if (square) {
    TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  }
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  // Make sure the batch dimensions match between lhs and rhs.
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  DimensionHandle m;
  // lhs and rhs have the same value for m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
  DimensionHandle n = c->Dim(lhs, -1);
  if (square) {
    TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
  }

  ShapeHandle out;
  // Build final shape (batch_shape + n + k) in <out>.
  TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
  TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return OkStatus();
}

// The first input is [...,M,M] and second input is [...,M,N].
// Output is [...,M,N].
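// For example, an lhs of shape [1, 4, 4] and an rhs of shape [5, 4, 2]
// broadcast their batch dims to [5], giving an output of [5, 4, 2] (shapes
// here are illustrative).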
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
  DimensionHandle m;
  // lhs and rhs have the same value for m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));

  ShapeHandle out;
  // Build final shape (batch_shape + m + n) in <out>.
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return OkStatus();
}

// Input is [...,N,N]. Outputs are:
//   [...,N];[0], if compute_v is false,
//   [...,N];[...,N,N], if compute_v is true.
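// For example, an input of shape [2, 3, 3] with compute_v = true yields
// e = [2, 3] and v = [2, 3, 3]; with compute_v = false, the second output
// is the placeholder [0] (shapes here are illustrative).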
Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
  c->set_output(0, e_shape);
  bool compute_v;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
  if (compute_v) {
    ShapeHandle v_shape;
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    c->set_output(1, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
  }
  return OkStatus();
}

// Input is [...,N,N].
// First and second outputs are:
//   [...,N,N]; [...,N].
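// For example, an input of shape [4, 6, 6] yields lu = [4, 6, 6] and
// permutation indices p = [4, 6] (shapes here are illustrative).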
Status LuShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));

  ShapeHandle lu_shape;
  ShapeHandle p_shape;

  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));

  c->set_output(0, lu_shape);
  c->set_output(1, p_shape);
  return OkStatus();
}

// Input is [...,M,N].
// First and second outputs are:
//   [...,M,M]; [...,M,N], if full_matrices is true,
//   [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
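// For example, an input of shape [2, 6, 4] with full_matrices = false has
// P = min(6, 4) = 4, so q = [2, 6, 4] and r = [2, 4, 4]; with
// full_matrices = true, q = [2, 6, 6] and r = [2, 6, 4] (shapes here are
// illustrative).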
Status QrShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle q_shape;
  ShapeHandle r_shape;
  bool full_matrices;
  TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
  if (full_matrices) {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
  } else {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
  }
  c->set_output(0, q_shape);
  c->set_output(1, r_shape);
  return OkStatus();
}

// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
//   [0]; [0], if compute_uv is false.
//   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
//   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
// where P = min(M,N).
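// For example, an input of shape [2, 6, 4] gives s = [2, 4]; with
// compute_uv = true and full_matrices = false, u = [2, 6, 4] and
// v = [2, 4, 4] (shapes here are illustrative).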
Status SvdShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
  c->set_output(0, e_shape);
  bool compute_uv;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
  if (compute_uv) {
    ShapeHandle u_shape;
    ShapeHandle v_shape;
    bool full_matrices;
    TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    if (full_matrices) {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    } else {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
    }
    c->set_output(1, u_shape);
    c->set_output(2, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
    c->set_output(2, c->Vector(0ll));
  }
  return OkStatus();
}

// Inputs: [...,1,M], [...,1,M], [...,1,M], [...,M,N].
// Output is [...,M,N].
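// For example, superdiag/maindiag/subdiag each of shape [2, 1, 5] and an
// rhs of shape [2, 5, 3] give an output of [2, 5, 3]; all batch dims must
// match exactly (shapes here are illustrative).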
Status TridiagonalMatMulShapeFn(InferenceContext* c) {
  ShapeHandle superdiag;
  ShapeHandle maindiag;
  ShapeHandle subdiag;
  ShapeHandle rhs;

  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle superdiag_batch_shape;
  ShapeHandle maindiag_batch_shape;
  ShapeHandle subdiag_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(superdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));

  // Check that diagonals have the same shape.
  TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag));
  TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag));

  // Check that size of tri-diagonal matrix is the same as height of matrix on
  // the right.
  DimensionHandle m_lhs = c->Dim(maindiag, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that next-to-last dimension of diagonals is 1.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused));

  // The output shape is the same as rhs shape.
  c->set_output(0, rhs);
  return OkStatus();
}

// The first input is [...,3,M] and second input is [...,M,K].
// Output is [...,M,K].
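// For example, diagonals of shape [2, 3, 5] (one row each for the super-,
// main-, and subdiagonal) and an rhs of shape [2, 5, 4] give an output of
// [2, 5, 4] (shapes here are illustrative).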
Status TridiagonalSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  // Check that "M" is the same in both inputs.
  DimensionHandle m_lhs = c->Dim(lhs, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that next-to-last dimension of the first input is 3.
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));

  // The output shape is the same as rhs shape.
  c->set_output(0, rhs);
  return OkStatus();
}

}  // namespace

REGISTER_OP("MatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(0, out);
      return OkStatus();
    });

REGISTER_OP("LogMatrixDeterminant")
    .Input("input: T")
    .Output("sign: T")
    .Output("log_abs_determinant: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      c->set_output(0, s);

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(1, out);
      return OkStatus();
    });

REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("SelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));

      DimensionHandle d = c->Dim(input, -1);
      DimensionHandle d_plus_1;
      TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
      c->set_output(0, s);
      return OkStatus();
    });

REGISTER_OP("Eig")
    .Input("input: T")
    .Output("e: Tout")
    .Output("v: Tout")
    .Attr("compute_v: bool = True")
    .Attr("T: {float, double, complex64, complex128}")
    .Attr("Tout: {complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);

REGISTER_OP("MatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square */);
    });

REGISTER_OP("BandedTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return BandedTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {bfloat16, double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("fast: bool = True")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle l2_regularizer;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
      return MatrixSolveShapeFn(c, false /* square */);
    });

REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

REGISTER_OP("TridiagonalMatMul")
    .Input("superdiag: T")
    .Input("maindiag: T")
    .Input("subdiag: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalMatMulShapeFn);

REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("partial_pivoting: bool = True")
    .Attr("perturb_singular: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

REGISTER_OP("Einsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::EinsumShape);

// Deprecated op registrations:

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow