1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::DimensionHandle; |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | namespace { |
27 | |
28 | // Return in <out> the result of making the end of <s> a square matrix. |
29 | Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input, |
30 | ShapeHandle* out) { |
31 | ShapeHandle s; |
32 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s)); |
33 | |
34 | DimensionHandle d; |
35 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d)); |
36 | |
37 | ShapeHandle batch_shape; |
38 | TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape)); |
39 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out)); |
40 | return OkStatus(); |
41 | } |
42 | |
43 | Status BatchUnchangedSquareShapeFn(InferenceContext* c) { |
44 | ShapeHandle out; |
45 | TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out)); |
46 | c->set_output(0, out); |
47 | return OkStatus(); |
48 | } |
49 | |
50 | // The first input is [...,K,M] and second input is [...,M,N]. |
51 | Status BandedTriangularSolveShapeFn(InferenceContext* c) { |
52 | ShapeHandle lhs; |
53 | ShapeHandle rhs; |
54 | |
55 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); |
56 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); |
57 | |
58 | // Check K > 0. |
59 | DimensionHandle num_bands = c->Dim(lhs, -2); |
60 | DimensionHandle m = c->Dim(lhs, -1); |
61 | if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) { |
62 | return errors::InvalidArgument("Number of bands must be positive, but is " , |
63 | c->Value(num_bands)); |
64 | } |
65 | if (c->ValueKnown(num_bands) && c->ValueKnown(m) && |
66 | c->Value(num_bands) > c->Value(m)) { |
67 | return errors::InvalidArgument("Number of bands " , c->Value(num_bands), |
68 | " cannot exceed the size of the matrix " , |
69 | c->Value(m)); |
70 | } |
71 | |
72 | ShapeHandle lhs_batch_shape; |
73 | ShapeHandle rhs_batch_shape; |
74 | ShapeHandle output_batch_shape; |
75 | // Make the common batch subshape. |
76 | TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); |
77 | TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); |
78 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
79 | c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape)); |
80 | |
81 | // lhs and rhs have the same value for M to be compatible. |
82 | TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m)); |
83 | |
84 | // Build final shape (batch_shape + m + n) in <out>. |
85 | ShapeHandle out; |
86 | TF_RETURN_IF_ERROR( |
87 | c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out)); |
88 | |
89 | c->set_output(0, out); |
90 | return OkStatus(); |
91 | } |
92 | |
93 | // The first input is [...,M,N] and second input is either [...,M,K] or [...,M]. |
94 | // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M]. |
95 | Status MatrixSolveShapeFn(InferenceContext* c, bool square) { |
96 | ShapeHandle lhs; |
97 | ShapeHandle rhs; |
98 | if (square) { |
99 | TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); |
100 | } else { |
101 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); |
102 | } |
103 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); |
104 | |
105 | ShapeHandle lhs_batch_shape; |
106 | ShapeHandle rhs_batch_shape; |
107 | // Make the common batch subshape. |
108 | TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); |
109 | TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); |
110 | // Make sure the batch dimensions match between lhs and rhs. |
111 | TF_RETURN_IF_ERROR( |
112 | c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape)); |
113 | |
114 | DimensionHandle m; |
115 | // lhs and rhs have the same value for m to be compatible. |
116 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m)); |
117 | DimensionHandle n = c->Dim(lhs, -1); |
118 | if (square) { |
119 | TF_RETURN_IF_ERROR(c->Merge(m, n, &n)); |
120 | } |
121 | |
122 | ShapeHandle out; |
123 | // Build final shape (batch_shape + n + k) in <out>. |
124 | TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out)); |
125 | TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out)); |
126 | c->set_output(0, out); |
127 | return OkStatus(); |
128 | } |
129 | |
130 | // The first input is [...,M,M] and second input is [...,M,N]. |
131 | // Output is [...,M,N]. |
132 | Status MatrixTriangularSolveShapeFn(InferenceContext* c) { |
133 | ShapeHandle lhs; |
134 | ShapeHandle rhs; |
135 | TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); |
136 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); |
137 | |
138 | ShapeHandle lhs_batch_shape; |
139 | ShapeHandle rhs_batch_shape; |
140 | ShapeHandle output_batch_shape; |
141 | // Make the common batch subshape. |
142 | TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); |
143 | TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); |
144 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
145 | c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape)); |
146 | DimensionHandle m; |
147 | // lhs and rhs have the same value for m to be compatible. |
148 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m)); |
149 | |
150 | ShapeHandle out; |
151 | // Build final shape (batch_shape + m + n) in <out>. |
152 | TF_RETURN_IF_ERROR( |
153 | c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out)); |
154 | c->set_output(0, out); |
155 | return OkStatus(); |
156 | } |
157 | |
158 | // Input is [...,N,N]. Outputs are: |
159 | // [...,N];[0], if compute_v is false, |
160 | // [...,N];[...,N,N], if compute_v is true. |
161 | Status SelfAdjointEigV2ShapeFn(InferenceContext* c) { |
162 | ShapeHandle input; |
163 | TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input)); |
164 | DimensionHandle n; |
165 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n)); |
166 | ShapeHandle batch_shape; |
167 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); |
168 | ShapeHandle e_shape; |
169 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape)); |
170 | c->set_output(0, e_shape); |
171 | bool compute_v; |
172 | TF_RETURN_IF_ERROR(c->GetAttr("compute_v" , &compute_v)); |
173 | if (compute_v) { |
174 | ShapeHandle v_shape; |
175 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape)); |
176 | c->set_output(1, v_shape); |
177 | } else { |
178 | c->set_output(1, c->Vector(0ll)); |
179 | } |
180 | return OkStatus(); |
181 | } |
182 | |
183 | // Input is [...,N,N]. |
184 | // First and second outputs are: |
185 | // [...,N,N]; [...,N]. |
186 | Status LuShapeFn(InferenceContext* c) { |
187 | ShapeHandle input; |
188 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
189 | |
190 | DimensionHandle n; |
191 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n)); |
192 | |
193 | ShapeHandle batch_shape; |
194 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); |
195 | |
196 | ShapeHandle lu_shape; |
197 | ShapeHandle p_shape; |
198 | |
199 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape)); |
200 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape)); |
201 | |
202 | c->set_output(0, lu_shape); |
203 | c->set_output(1, p_shape); |
204 | return OkStatus(); |
205 | } |
206 | |
207 | // Input is [...,M,N]. |
208 | // First and second outputs are: |
209 | // [...,M,M]; [...,M,N], if full_matrices is true, |
210 | // [...,M,P]; [...,P,N], if full_matrices is false, |
211 | // where P = min(M,N). |
212 | Status QrShapeFn(InferenceContext* c) { |
213 | ShapeHandle input; |
214 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
215 | DimensionHandle m = c->Dim(input, -2); |
216 | DimensionHandle n = c->Dim(input, -1); |
217 | DimensionHandle p; |
218 | TF_RETURN_IF_ERROR(c->Min(m, n, &p)); |
219 | ShapeHandle batch_shape; |
220 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); |
221 | ShapeHandle q_shape; |
222 | ShapeHandle r_shape; |
223 | bool full_matrices; |
224 | TF_RETURN_IF_ERROR(c->GetAttr("full_matrices" , &full_matrices)); |
225 | if (full_matrices) { |
226 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape)); |
227 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape)); |
228 | } else { |
229 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape)); |
230 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape)); |
231 | } |
232 | c->set_output(0, q_shape); |
233 | c->set_output(1, r_shape); |
234 | return OkStatus(); |
235 | } |
236 | |
237 | // Input is [...,M,N]. First output is [...,min(M,N)]. |
238 | // Second and third outputs are: |
239 | // [0]; [0], if compute_uv is false. |
240 | // [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true, |
241 | // [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false, |
242 | // where P = min(M,N). |
243 | Status SvdShapeFn(InferenceContext* c) { |
244 | ShapeHandle input; |
245 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
246 | DimensionHandle m = c->Dim(input, -2); |
247 | DimensionHandle n = c->Dim(input, -1); |
248 | DimensionHandle p; |
249 | TF_RETURN_IF_ERROR(c->Min(m, n, &p)); |
250 | ShapeHandle batch_shape; |
251 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); |
252 | ShapeHandle e_shape; |
253 | TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape)); |
254 | c->set_output(0, e_shape); |
255 | bool compute_uv; |
256 | TF_RETURN_IF_ERROR(c->GetAttr("compute_uv" , &compute_uv)); |
257 | if (compute_uv) { |
258 | ShapeHandle u_shape; |
259 | ShapeHandle v_shape; |
260 | bool full_matrices; |
261 | TF_RETURN_IF_ERROR(c->GetAttr("full_matrices" , &full_matrices)); |
262 | if (full_matrices) { |
263 | TF_RETURN_IF_ERROR( |
264 | c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape)); |
265 | TF_RETURN_IF_ERROR( |
266 | c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape)); |
267 | } else { |
268 | TF_RETURN_IF_ERROR( |
269 | c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape)); |
270 | TF_RETURN_IF_ERROR( |
271 | c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape)); |
272 | } |
273 | c->set_output(1, u_shape); |
274 | c->set_output(2, v_shape); |
275 | } else { |
276 | c->set_output(1, c->Vector(0ll)); |
277 | c->set_output(2, c->Vector(0ll)); |
278 | } |
279 | return OkStatus(); |
280 | } |
281 | |
282 | // Inputs: [...,1,M], [...,1,M], [...,1,M],[...,M,N]. |
283 | // Output is [...,M,N]. |
284 | Status TridiagonalMatMulShapeFn(InferenceContext* c) { |
285 | ShapeHandle superdiag; |
286 | ShapeHandle maindiag; |
287 | ShapeHandle subdiag; |
288 | ShapeHandle rhs; |
289 | |
290 | // Check that rank is at least 2. |
291 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag)); |
292 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag)); |
293 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag)); |
294 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs)); |
295 | |
296 | // Extract batch dimensions and check they are the same. |
297 | ShapeHandle superdiag_batch_shape; |
298 | ShapeHandle maindiag_batch_shape; |
299 | ShapeHandle subdiag_batch_shape; |
300 | ShapeHandle rhs_batch_shape; |
301 | TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape)); |
302 | TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape)); |
303 | TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape)); |
304 | TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); |
305 | TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &superdiag)); |
306 | TF_RETURN_IF_ERROR( |
307 | c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape)); |
308 | TF_RETURN_IF_ERROR( |
309 | c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape)); |
310 | |
311 | // Check that diagonals have the same shape. |
312 | TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag)); |
313 | TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag)); |
314 | |
315 | // Check that size of tri-diagonal matrix is the same as height of matrix on |
316 | // the right. |
317 | DimensionHandle m_lhs = c->Dim(maindiag, -1); |
318 | DimensionHandle m_rhs = c->Dim(rhs, -2); |
319 | TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs)); |
320 | |
321 | // Check that next-to-last dimension of diagonals is 1. |
322 | DimensionHandle unused; |
323 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused)); |
324 | |
325 | // The output shape is the same as rhs shape. |
326 | c->set_output(0, rhs); |
327 | return OkStatus(); |
328 | } |
329 | |
330 | // The first input is [...,3,M] and second input is [...,M,K]. |
331 | // Output is [...,M,K]. |
332 | Status TridiagonalSolveShapeFn(InferenceContext* c) { |
333 | ShapeHandle lhs; |
334 | ShapeHandle rhs; |
335 | // Check that rank is at least 2. |
336 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); |
337 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); |
338 | |
339 | // Extract batch dimensions and check they are the same. |
340 | ShapeHandle lhs_batch_shape; |
341 | ShapeHandle rhs_batch_shape; |
342 | TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); |
343 | TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); |
344 | TF_RETURN_IF_ERROR( |
345 | c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape)); |
346 | |
347 | // Check that "M" is the same in both inputs. |
348 | DimensionHandle m_lhs = c->Dim(lhs, -1); |
349 | DimensionHandle m_rhs = c->Dim(rhs, -2); |
350 | TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs)); |
351 | |
352 | // Check that next-to-last dimension of the first input is 3. |
353 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs)); |
354 | |
355 | // The output shape is the same as rhs shape. |
356 | c->set_output(0, rhs); |
357 | return OkStatus(); |
358 | } |
359 | |
360 | } // namespace |
361 | |
362 | REGISTER_OP("MatrixDeterminant" ) |
363 | .Input("input: T" ) |
364 | .Output("output: T" ) |
365 | .Attr("T: {half, float, double, complex64, complex128}" ) |
366 | .SetShapeFn([](InferenceContext* c) { |
367 | ShapeHandle input; |
368 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
369 | |
370 | DimensionHandle unused; |
371 | TF_RETURN_IF_ERROR( |
372 | c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused)); |
373 | |
374 | ShapeHandle out; |
375 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out)); |
376 | c->set_output(0, out); |
377 | return OkStatus(); |
378 | }); |
379 | |
380 | REGISTER_OP("LogMatrixDeterminant" ) |
381 | .Input("input: T" ) |
382 | .Output("sign: T" ) |
383 | .Output("log_abs_determinant: T" ) |
384 | .Attr("T: {half, float, double, complex64, complex128}" ) |
385 | .SetShapeFn([](InferenceContext* c) { |
386 | ShapeHandle input; |
387 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); |
388 | |
389 | DimensionHandle unused; |
390 | TF_RETURN_IF_ERROR( |
391 | c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused)); |
392 | |
393 | ShapeHandle s; |
394 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s)); |
395 | c->set_output(0, s); |
396 | |
397 | ShapeHandle out; |
398 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out)); |
399 | c->set_output(1, out); |
400 | return OkStatus(); |
401 | }); |
402 | |
// The registrations below all use BatchUnchangedSquareShapeFn. Input 0 must be
// a batch of square matrices [..., M, M], and the single output has that same
// shape.

REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Only complex element types are registered for this op.
REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// NOTE: the shape function inspects only input 0 (`l`). The shape of `grad` is
// not validated at graph-construction time.
REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);
436 | |
437 | REGISTER_OP("SelfAdjointEig" ) |
438 | .Input("input: T" ) |
439 | .Output("output: T" ) |
440 | .Attr("T: {double, float, half}" ) |
441 | .Deprecated(11, "Use SelfAdjointEigV2 instead." ) |
442 | .SetShapeFn([](InferenceContext* c) { |
443 | ShapeHandle input; |
444 | TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input)); |
445 | |
446 | DimensionHandle d = c->Dim(input, -1); |
447 | DimensionHandle d_plus_1; |
448 | TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1)); |
449 | |
450 | ShapeHandle s; |
451 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s)); |
452 | TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s)); |
453 | c->set_output(0, s); |
454 | return OkStatus(); |
455 | }); |
456 | |
// Eig and SelfAdjointEigV2 share SelfAdjointEigV2ShapeFn. For input
// [..., N, N], output `e` is [..., N]. Output `v` is [..., N, N] when
// `compute_v` is true and [0] otherwise.
// Eig's outputs use the separate `Tout` type attr (complex only), while the
// input accepts real or complex `T`.
REGISTER_OP("Eig")
    .Input("input: T")
    .Output("e: Tout")
    .Output("v: Tout")
    .Attr("compute_v: bool = True")
    .Attr("T: {float, double, complex64, complex128}")
    .Attr("Tout: {complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

// For input [..., N, N], LuShapeFn gives `lu` the shape [..., N, N] and `p`
// the shape [..., N]. The element type of `p` is chosen by `output_idx_type`
// (int32 by default).
REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);
481 | |
482 | REGISTER_OP("MatrixSolve" ) |
483 | .Input("matrix: T" ) |
484 | .Input("rhs: T" ) |
485 | .Output("output: T" ) |
486 | .Attr("adjoint: bool = False" ) |
487 | .Attr("T: {double, float, half, complex64, complex128}" ) |
488 | .SetShapeFn([](InferenceContext* c) { |
489 | return MatrixSolveShapeFn(c, true /* square (*/); |
490 | }); |
491 | |
492 | REGISTER_OP("BandedTriangularSolve" ) |
493 | .Input("matrix: T" ) |
494 | .Input("rhs: T" ) |
495 | .Output("output: T" ) |
496 | .Attr("lower: bool = True" ) |
497 | .Attr("adjoint: bool = False" ) |
498 | .Attr("T: {double, float, half, complex64, complex128}" ) |
499 | .SetShapeFn([](InferenceContext* c) { |
500 | return BandedTriangularSolveShapeFn(c); |
501 | }); |
502 | |
503 | REGISTER_OP("MatrixTriangularSolve" ) |
504 | .Input("matrix: T" ) |
505 | .Input("rhs: T" ) |
506 | .Output("output: T" ) |
507 | .Attr("lower: bool = True" ) |
508 | .Attr("adjoint: bool = False" ) |
509 | .Attr("T: {bfloat16, double, float, half, complex64, complex128}" ) |
510 | .SetShapeFn([](InferenceContext* c) { |
511 | return MatrixTriangularSolveShapeFn(c); |
512 | }); |
513 | |
514 | REGISTER_OP("MatrixSolveLs" ) |
515 | .Input("matrix: T" ) |
516 | .Input("rhs: T" ) |
517 | .Input("l2_regularizer: double" ) |
518 | .Output("output: T" ) |
519 | .Attr("T: {double, float, half, complex64, complex128}" ) |
520 | .Attr("fast: bool = True" ) |
521 | .SetShapeFn([](InferenceContext* c) { |
522 | ShapeHandle l2_regularizer; |
523 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer)); |
524 | return MatrixSolveShapeFn(c, false /* square */); |
525 | }); |
526 | |
// Input and output are both a batch of square matrices [..., M, M].
REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// For input [..., M, N] with P = min(M, N), QrShapeFn gives q = [..., M, M]
// and r = [..., M, N] when full_matrices is true. Otherwise it gives
// q = [..., M, P] and r = [..., P, N].
REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

// For input [..., M, N] with P = min(M, N), SvdShapeFn gives s = [..., P].
// When compute_uv is false, u and v are both [0]. When compute_uv is true,
// u/v are [..., M, M]/[..., N, N] with full_matrices, and [..., M, P]/
// [..., N, P] without.
REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

// Each diagonal is [..., 1, M] and rhs is [..., M, N]. The output takes the
// shape of rhs.
REGISTER_OP("TridiagonalMatMul")
    .Input("superdiag: T")
    .Input("maindiag: T")
    .Input("subdiag: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalMatMulShapeFn);

// `diagonals` is [..., 3, M] and rhs is [..., M, K]. The output takes the
// shape of rhs.
REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("partial_pivoting: bool = True")
    .Attr("perturb_singular: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

// Takes a variadic list of N >= 1 inputs. Shape inference is delegated to
// shape_inference::EinsumShape.
REGISTER_OP("Einsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::EinsumShape);
576 | |
// Deprecated op registrations:
//
// Every legacy registration below is marked `.Deprecated(...)` with a pointer
// to its replacement, and uses shape_inference::UnknownShape rather than a
// real shape function. Their outputs therefore get no statically inferred
// shape.

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);
666 | |
667 | } // namespace tensorflow |
668 | |