1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <ostream> |
18 | |
19 | #include "tensorflow/core/framework/common_shape_fns.h" |
20 | #include "tensorflow/core/framework/kernel_shape_util.h" |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/framework/shape_inference.h" |
23 | #include "tensorflow/core/framework/tensor.pb.h" |
24 | #include "tensorflow/core/framework/types.h" |
25 | #include "tensorflow/core/framework/types.pb.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/platform/types.h" |
28 | #include "tensorflow/core/util/mirror_pad_mode.h" |
29 | #include "tensorflow/core/util/padding.h" |
30 | #include "tensorflow/core/util/strided_slice_op.h" |
31 | #include "tensorflow/core/util/tensor_format.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | using shape_inference::DimensionHandle; |
36 | using shape_inference::InferenceContext; |
37 | using shape_inference::ShapeHandle; |
38 | using shape_inference::UnchangedShape; |
39 | |
40 | namespace { |
41 | |
42 | Status GetAxisForPackAndUnpack(InferenceContext* c, int32_t rank_after_pack, |
43 | int32* axis) { |
44 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , axis)); |
45 | if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) { |
46 | return errors::InvalidArgument("Invalid axis: " , *axis, "; must be in [" , |
47 | -1 * rank_after_pack, "," , rank_after_pack, |
48 | ")" ); |
49 | } |
50 | if (*axis < 0) *axis = (rank_after_pack + *axis); |
51 | return OkStatus(); |
52 | } |
53 | |
54 | template <typename T> |
55 | std::vector<int64_t> AsInt64(const Tensor* tensor, int64_t num_elements) { |
56 | std::vector<int64_t> ret(num_elements); |
57 | auto data = tensor->vec<T>(); |
58 | for (int64_t i = 0; i < num_elements; ++i) { |
59 | ret[i] = data(i); |
60 | } |
61 | return ret; |
62 | } |
63 | |
64 | template <typename T> |
65 | Status PadKnown(InferenceContext* c, ShapeHandle input, |
66 | const Tensor* paddings_t, int64_t num_dims) { |
67 | // paddings_t is known. |
68 | std::vector<DimensionHandle> dims(num_dims); |
69 | auto paddings_data = paddings_t->matrix<T>(); |
70 | for (int64_t i = 0; i < num_dims; ++i) { |
71 | const T pad0 = paddings_data(i, 0); |
72 | const T pad1 = paddings_data(i, 1); |
73 | if (pad0 < 0 || pad1 < 0) { |
74 | return errors::InvalidArgument("Paddings must be non-negative" ); |
75 | } |
76 | TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i])); |
77 | } |
78 | c->set_output(0, c->MakeShape(dims)); |
79 | return OkStatus(); |
80 | } |
81 | |
82 | Status PadShapeFn(InferenceContext* c) { |
83 | // Paddings is a matrix of [input_rank, 2]. |
84 | ShapeHandle paddings; |
85 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); |
86 | DimensionHandle unused; |
87 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused)); |
88 | |
89 | // n_dim and input.rank are equivalent. |
90 | ShapeHandle input = c->input(0); |
91 | DimensionHandle n_dim = c->Dim(paddings, 0); |
92 | if (c->ValueKnown(n_dim)) { |
93 | TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input)); |
94 | } else if (c->RankKnown(input)) { |
95 | TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim)); |
96 | } |
97 | |
98 | const Tensor* paddings_t = c->input_tensor(1); |
99 | |
100 | // paddings_t is unknown |
101 | if (paddings_t == nullptr) { |
102 | if (c->ValueKnown(n_dim)) { |
103 | // Make output with n_dim unknown dims. |
104 | c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim))); |
105 | } else { |
106 | c->set_output(0, c->UnknownShape()); |
107 | } |
108 | return OkStatus(); |
109 | } |
110 | |
111 | const int64_t num_dims = paddings_t->shape().dim_size(0); |
112 | TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input)); |
113 | TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim)); |
114 | |
115 | if (paddings_t->dtype() == DT_INT32) { |
116 | return PadKnown<int32>(c, input, paddings_t, num_dims); |
117 | } else { |
118 | return PadKnown<int64_t>(c, input, paddings_t, num_dims); |
119 | } |
120 | } |
121 | |
122 | Status TransposeShapeFn(InferenceContext* c) { |
123 | ShapeHandle input = c->input(0); |
124 | ShapeHandle perm_shape = c->input(1); |
125 | const Tensor* perm = c->input_tensor(1); |
126 | DimensionHandle perm_elems = c->NumElements(perm_shape); |
127 | // If we don't have rank information on the input or value information on |
128 | // perm we can't return any shape information, otherwise we have enough |
129 | // information to at least find the rank of the output. |
130 | if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) { |
131 | c->set_output(0, c->UnknownShape()); |
132 | return OkStatus(); |
133 | } |
134 | |
135 | // Find our value of the rank. |
136 | int64_t rank; |
137 | if (c->RankKnown(input)) { |
138 | rank = c->Rank(input); |
139 | } else if (c->ValueKnown(perm_elems)) { |
140 | rank = c->Value(perm_elems); |
141 | } else { |
142 | rank = perm->NumElements(); |
143 | } |
144 | if (!c->RankKnown(input) && rank < 2) { |
145 | // A permutation array containing a single element is ambiguous. It could |
146 | // indicate either a scalar or a 1-dimensional array, both of which the |
147 | // transpose op returns unchanged. |
148 | c->set_output(0, input); |
149 | return OkStatus(); |
150 | } |
151 | |
152 | std::vector<DimensionHandle> dims; |
153 | dims.resize(rank); |
154 | TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); |
155 | // Ensure that perm is a vector and has rank elements. |
156 | TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape)); |
157 | TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems)); |
158 | |
159 | // If we know the rank of the input and the value of perm, we can return |
160 | // all shape information, otherwise we can only return rank information, |
161 | // but no information for the dimensions. |
162 | if (perm != nullptr) { |
163 | std::vector<int64_t> data; |
164 | if (perm->dtype() == DT_INT32) { |
165 | data = AsInt64<int32>(perm, rank); |
166 | } else { |
167 | data = AsInt64<int64_t>(perm, rank); |
168 | } |
169 | |
170 | for (int32_t i = 0; i < rank; ++i) { |
171 | int64_t in_idx = data[i]; |
172 | if (in_idx >= rank || in_idx <= -rank) { |
173 | return errors::InvalidArgument("perm dim " , in_idx, |
174 | " is out of range of input rank " , rank); |
175 | } |
176 | dims[i] = c->Dim(input, in_idx); |
177 | } |
178 | } else { |
179 | for (int i = 0; i < rank; ++i) { |
180 | dims[i] = c->UnknownDim(); |
181 | } |
182 | } |
183 | |
184 | c->set_output(0, c->MakeShape(dims)); |
185 | return OkStatus(); |
186 | } |
187 | |
// Shape function for reshaping input 0 into the shape given by the shape
// tensor in input 1 (read via MakeShapeFromShapeTensor).
//
// Besides forwarding the requested shape, this uses the identity
// num_elements(input) == num_elements(output) when the input rank is known:
//   * no unknown dims on either side: the element counts must match;
//   * input fully known, exactly one unknown output dim: that output dim is
//     inferred by (evenly divisible) division;
//   * output fully known, exactly one unknown input dim: the counts are
//     checked for even divisibility;
//   * exactly one unknown dim on each side and equal known products: the
//     unknown output dim is replaced by the unknown input dim handle.
// When more than one dim is unknown on either side, no inference is done.
// Returns InvalidArgument when the element counts are provably incompatible.
Status SetOutputShapeForReshape(InferenceContext* c) {
  ShapeHandle in = c->input(0);
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));

  if (!c->RankKnown(out)) {
    // We have no information about the shape of the output.
    c->set_output(0, out);
    return OkStatus();
  }
  if (c->RankKnown(in)) {
    // We don't know the number of output elements, but we can try to infer
    // the missing dimension.
    bool too_many_unknown = false;
    int32_t out_unknown_idx = -1;

    DimensionHandle known_out_elems = c->NumElements(out);
    if (!c->ValueKnown(known_out_elems)) {
      // Accumulate the product of the known output dims and remember the
      // index of the single unknown one (a second unknown stops the scan).
      known_out_elems = c->MakeDim(1);
      for (int32_t i = 0; i < c->Rank(out); ++i) {
        DimensionHandle dim = c->Dim(out, i);
        if (!c->ValueKnown(dim)) {
          if (out_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          out_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(
              c->Multiply(known_out_elems, dim, &known_out_elems));
        }
      }
    }
    int32_t in_unknown_idx = -1;
    DimensionHandle known_in_elems = c->NumElements(in);
    if (!c->ValueKnown(known_in_elems)) {
      // Same scan for the input shape.
      known_in_elems = c->MakeDim(1);
      for (int32_t i = 0; i < c->Rank(in); ++i) {
        DimensionHandle dim = c->Dim(in, i);
        if (!c->ValueKnown(dim)) {
          if (in_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          in_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
        }
      }
    }

    if (!too_many_unknown) {
      if (in_unknown_idx < 0 && out_unknown_idx < 0) {
        // Just check that the dimensions match.
        if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
          return errors::InvalidArgument(
              "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
              " elements to shape ", c->DebugString(out), " (",
              c->DebugString(known_out_elems), " elements)");
        }
      } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
                 c->Value(known_out_elems) > 0) {
        // Input fully known, infer the one missing output dim.
        // The `> 0` guard avoids dividing by a zero known-element count.
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));

      } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
                 c->Value(known_in_elems) != 0) {
        // Output fully known, infer the one missing input dim. The inferred
        // value is only merged for validation; `out` is left unchanged.
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
        TF_RETURN_IF_ERROR(
            c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
      } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
        // Exactly one unknown dimension in both input and output. These 2 are
        // equal iff the known elements are equal.
        if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
          DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
          TF_RETURN_IF_ERROR(
              c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
        }
      }
    }
  }
  c->set_output(0, out);
  return OkStatus();
}
282 | |
283 | } // namespace |
284 | |
285 | REGISTER_OP("ParallelConcat" ) |
286 | .Input("values: N * T" ) |
287 | .Output("output: T" ) |
288 | .Attr("N: int >= 1" ) |
289 | .Attr("T: type" ) |
290 | .Attr("shape: shape" ) |
291 | .SetShapeFn([](InferenceContext* c) { |
292 | // Validate that the shape attr is correct. |
293 | PartialTensorShape shape; |
294 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &shape)); |
295 | ShapeHandle passed_shape; |
296 | TF_RETURN_IF_ERROR( |
297 | c->MakeShapeFromPartialTensorShape(shape, &passed_shape)); |
298 | if (!c->FullyDefined(passed_shape)) { |
299 | return errors::InvalidArgument("shape attr must be fully defined." ); |
300 | } |
301 | ShapeHandle cur; |
302 | TF_RETURN_IF_ERROR(c->ReplaceDim( |
303 | passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)), |
304 | &cur)); |
305 | for (int i = 0; i < c->num_inputs(); ++i) { |
306 | if (!c->FullyDefined(c->input(i))) { |
307 | return errors::InvalidArgument( |
308 | "All input shapes must be fully defined." ); |
309 | } |
310 | DimensionHandle unused; |
311 | if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) { |
312 | return errors::InvalidArgument("Size of first dimension must be 1." ); |
313 | } |
314 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
315 | "From merging shape " , i, |
316 | " with other shapes." ); |
317 | } |
318 | |
319 | c->set_output(0, passed_shape); |
320 | |
321 | return OkStatus(); |
322 | }); |
323 | |
324 | REGISTER_OP("Pack" ) |
325 | .Input("values: N * T" ) |
326 | .Output("output: T" ) |
327 | .Attr("N: int >= 1" ) |
328 | .Attr("T: type" ) |
329 | .Attr("axis: int = 0" ) |
330 | .SetShapeFn([](InferenceContext* c) { |
331 | // Validate shapes of all inputs are compatible |
332 | ShapeHandle cur = c->input(c->num_inputs() - 1); |
333 | for (int i = c->num_inputs() - 2; i >= 0; --i) { |
334 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
335 | "From merging shape " , i, |
336 | " with other shapes." ); |
337 | } |
338 | if (!c->RankKnown(cur)) { |
339 | c->set_output(0, c->UnknownShape()); |
340 | return OkStatus(); |
341 | } |
342 | // Determine the axis that will be added, converting from negative |
343 | // axes to a positive point per negative indexing rules. |
344 | int32_t rank = c->Rank(cur); |
345 | int32_t axis; |
346 | TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis)); |
347 | |
348 | // Copy all dimensions over, inserting a dimension of value #inputs |
349 | // at <axis>. |
350 | std::vector<DimensionHandle> dims; |
351 | int index = 0; |
352 | while (index < axis) dims.push_back(c->Dim(cur, index++)); |
353 | dims.push_back(c->MakeDim(c->num_inputs())); |
354 | while (index < rank) dims.push_back(c->Dim(cur, index++)); |
355 | |
356 | c->set_output(0, c->MakeShape(dims)); |
357 | for (int i = 0; i < c->num_inputs(); ++i) { |
358 | auto* shape_and_type = c->input_handle_shapes_and_types(i); |
359 | if (shape_and_type) { |
360 | if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) { |
361 | c->set_output_handle_shapes_and_types( |
362 | 0, std::vector<shape_inference::ShapeAndType>({})); |
363 | break; |
364 | } |
365 | } |
366 | } |
367 | return OkStatus(); |
368 | }); |
369 | |
370 | REGISTER_OP("DeepCopy" ) |
371 | .Input("x: T" ) |
372 | .Output("y: T" ) |
373 | .Attr("T: type" ) |
374 | .SetIsStateful() |
375 | .SetShapeFn(UnchangedShape); |
376 | |
377 | REGISTER_OP("InplaceUpdate" ) |
378 | .Input("x: T" ) |
379 | .Input("i: int32" ) |
380 | .Input("v: T" ) |
381 | .Output("y: T" ) |
382 | .Attr("T: type" ) |
383 | .SetShapeFn(UnchangedShape); |
384 | |
385 | REGISTER_OP("InplaceAdd" ) |
386 | .Input("x: T" ) |
387 | .Input("i: int32" ) |
388 | .Input("v: T" ) |
389 | .Output("y: T" ) |
390 | .Attr("T: type" ) |
391 | .SetShapeFn(UnchangedShape); |
392 | |
393 | REGISTER_OP("InplaceSub" ) |
394 | .Input("x: T" ) |
395 | .Input("i: int32" ) |
396 | .Input("v: T" ) |
397 | .Output("y: T" ) |
398 | .Attr("T: type" ) |
399 | .SetShapeFn(UnchangedShape); |
400 | |
401 | REGISTER_OP("Empty" ) |
402 | .Input("shape: int32" ) |
403 | .Output("output: dtype" ) |
404 | .Attr("dtype: type" ) |
405 | .Attr("init: bool = false" ) |
406 | .SetDoNotOptimize() |
407 | .SetShapeFn([](InferenceContext* c) { |
408 | ShapeHandle out; |
409 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
410 | c->set_output(0, out); |
411 | return OkStatus(); |
412 | }); |
413 | |
414 | // -------------------------------------------------------------------------- |
415 | REGISTER_OP("Unpack" ) |
416 | .Input("value: T" ) |
417 | .Output("output: num * T" ) |
418 | .Attr("num: int >= 0" ) |
419 | .Attr("T: type" ) |
420 | .Attr("axis: int = 0" ) |
421 | .SetShapeFn([](InferenceContext* c) { |
422 | ShapeHandle s = c->input(0); |
423 | ShapeHandle out; |
424 | if (c->RankKnown(s)) { |
425 | // Determine the axis that will be removed, converting from negative |
426 | // axes to a positive point per negative indexing rules. |
427 | int32_t rank = c->Rank(s); |
428 | int32_t axis; |
429 | TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis)); |
430 | |
431 | // The axis dim matches the number of outputs. |
432 | DimensionHandle unused; |
433 | TF_RETURN_IF_ERROR( |
434 | c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused)); |
435 | |
436 | // Copy all dimensions, removing the <axis> dimension. |
437 | std::vector<DimensionHandle> dims; |
438 | for (int i = 0; i < rank; ++i) { |
439 | if (i != axis) dims.push_back(c->Dim(s, i)); |
440 | } |
441 | out = c->MakeShape(dims); |
442 | } else { |
443 | // All outputs are the same shape, but it's not known. |
444 | out = c->UnknownShape(); |
445 | } |
446 | for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out); |
447 | return OkStatus(); |
448 | }); |
449 | |
450 | REGISTER_OP("UnravelIndex" ) |
451 | .Input("indices: Tidx" ) |
452 | .Input("dims: Tidx" ) |
453 | .Output("output: Tidx" ) |
454 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
455 | .SetShapeFn([](InferenceContext* c) { |
456 | ShapeHandle indices = c->input(0); |
457 | ShapeHandle dims; |
458 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); |
459 | if (c->RankKnown(indices) && c->Rank(indices) == 0) { |
460 | c->set_output(0, c->Vector(c->Dim(dims, 0))); |
461 | } else if (c->RankKnown(indices)) { |
462 | c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices))); |
463 | } else { |
464 | c->set_output(0, c->UnknownShape()); |
465 | } |
466 | return OkStatus(); |
467 | }); |
468 | |
469 | REGISTER_OP("BroadcastTo" ) |
470 | .Input("input: T" ) |
471 | .Input("shape: Tidx" ) |
472 | .Output("output: T" ) |
473 | .Attr("T: type" ) |
474 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
475 | .SetShapeFn([](InferenceContext* c) { |
476 | ShapeHandle shape_in = c->input(1); |
477 | TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in)); |
478 | ShapeHandle out; |
479 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); |
480 | if (!c->RankKnown(out)) { |
481 | // We have no information about the shape of the output. |
482 | c->set_output(0, out); |
483 | return OkStatus(); |
484 | } |
485 | |
486 | ShapeHandle in = c->input(0); |
487 | if (!c->RankKnown(in)) { |
488 | // We have no information about the shape of the input, |
489 | // nothing to do here. |
490 | c->set_output(0, out); |
491 | return OkStatus(); |
492 | } |
493 | int out_rank = c->Rank(out); |
494 | TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in)); |
495 | int in_rank = c->Rank(in); |
496 | for (int i = 0; i < in_rank; ++i) { |
497 | auto in_dim = c->Dim(in, in_rank - i - 1); |
498 | if (c->Value(in_dim) > 1) { |
499 | // If the input dimension is greater than 1 then the output dimension |
500 | // must be equal to it, since we only broadcast "from left to right". |
501 | auto out_dim = c->Dim(out, out_rank - i - 1); |
502 | TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim)); |
503 | TF_RETURN_IF_ERROR( |
504 | c->ReplaceDim(out, out_rank - i - 1, out_dim, &out)); |
505 | } |
506 | } |
507 | c->set_output(0, out); |
508 | return OkStatus(); |
509 | }); |
510 | |
511 | // -------------------------------------------------------------------------- |
512 | // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph |
513 | // in the N == 1 case to remove the node. |
514 | REGISTER_OP("Concat" ) |
515 | .Input("concat_dim: int32" ) |
516 | .Input("values: N * T" ) |
517 | .Output("output: T" ) |
518 | .Attr("N: int >= 2" ) |
519 | .Attr("T: type" ) |
520 | .SetShapeFn([](InferenceContext* c) { |
521 | return shape_inference::ConcatShape(c, c->num_inputs() - 1); |
522 | }); |
523 | |
524 | REGISTER_OP("ConcatV2" ) |
525 | .Input("values: N * T" ) |
526 | .Input("axis: Tidx" ) |
527 | .Output("output: T" ) |
528 | .Attr("N: int >= 2" ) |
529 | .Attr("T: type" ) |
530 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
531 | .SetShapeFn(shape_inference::ConcatV2Shape); |
532 | |
// TODO: Prefix the op names with an underscore if these ops are not meant to
// be user-accessible.
#ifdef INTEL_MKL
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    // Shape inference is delegated to ConcatV2Shape, exactly as for ConcatV2.
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatV2Shape(c);
    })
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif  // INTEL_MKL
554 | |
555 | REGISTER_OP("ConcatOffset" ) |
556 | .Input("concat_dim: int32" ) |
557 | .Input("shape: N * int32" ) |
558 | .Output("offset: N * int32" ) |
559 | .Attr("N: int >= 2" ) |
560 | .SetShapeFn([](InferenceContext* c) { |
561 | for (int i = 1; i < c->num_inputs(); ++i) { |
562 | c->set_output(i - 1, c->input(i)); |
563 | } |
564 | return OkStatus(); |
565 | }); |
566 | |
567 | // -------------------------------------------------------------------------- |
568 | REGISTER_OP("Split" ) |
569 | .Input("split_dim: int32" ) |
570 | .Input("value: T" ) |
571 | .Output("output: num_split * T" ) |
572 | .Attr("num_split: int >= 1" ) |
573 | .Attr("T: type" ) |
574 | .SetShapeFn([](InferenceContext* c) { |
575 | DimensionHandle split_dimension; |
576 | ShapeHandle input = c->input(1); |
577 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( |
578 | 0, c->Rank(input), &split_dimension)); |
579 | int num_split = c->num_outputs(); |
580 | ShapeHandle out; |
581 | if (!c->ValueKnown(split_dimension)) { |
582 | if (c->RankKnown(input)) { |
583 | out = c->UnknownShapeOfRank(c->Rank(input)); |
584 | } else { |
585 | out = c->UnknownShape(); |
586 | } |
587 | } else { |
588 | int64_t split_dim = c->Value(split_dimension); |
589 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); |
590 | DimensionHandle split_dim_size; |
591 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
592 | c->Divide(c->Dim(input, split_dim), num_split, |
593 | true /* evenly_divisible */, &split_dim_size), |
594 | "Number of ways to split should evenly divide the split dimension" ); |
595 | TF_RETURN_IF_ERROR( |
596 | c->ReplaceDim(input, split_dim, split_dim_size, &out)); |
597 | } |
598 | for (int i = 0; i < num_split; ++i) c->set_output(i, out); |
599 | return OkStatus(); |
600 | }); |
601 | |
602 | REGISTER_OP("SplitV" ) |
603 | .Input("value: T" ) |
604 | .Input("size_splits: Tlen" ) |
605 | .Input("split_dim: int32" ) |
606 | .Output("output: num_split * T" ) |
607 | .Attr("num_split: int >= 1" ) |
608 | .Attr("T: type" ) |
609 | .Attr("Tlen: {int8, int32, int64} = DT_INT64" ) |
610 | .SetShapeFn([](InferenceContext* c) { |
611 | DimensionHandle split_dimension; |
612 | ShapeHandle input = c->input(0); |
613 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( |
614 | 2, c->Rank(input), &split_dimension)); |
615 | int32_t num_outputs = c->num_outputs(); |
616 | int32_t rank = c->Rank(input); |
617 | ShapeHandle output_shape; |
618 | const Tensor* size_splits = c->input_tensor(1); |
619 | if (rank == InferenceContext::kUnknownRank) { |
620 | // If the rank of input tensor is unknown, then return unknown shapes. |
621 | // Note that the shape of each output can be different. |
622 | for (int i = 0; i < num_outputs; ++i) { |
623 | c->set_output(i, c->UnknownShape()); |
624 | } |
625 | } else if (rank == 0) { |
626 | // Throw error if input is a scalar. |
627 | return errors::InvalidArgument("Can't split scalars" ); |
628 | } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) { |
629 | // If split dimension is known, but the sizes are unknown, then |
630 | // only the split dimension is unknown |
631 | output_shape = input; |
632 | for (int i = 0; i < num_outputs; ++i) { |
633 | TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, |
634 | c->Value(split_dimension), |
635 | c->UnknownDim(), &output_shape)); |
636 | c->set_output(i, output_shape); |
637 | } |
638 | } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) { |
639 | // If split dimension or tensor containing the split sizes is unknown, |
640 | // then return unknown shapes of same rank as input. Note that each |
641 | // output shape can be different since splitv doesn't always split |
642 | // tensors evenly. |
643 | for (int i = 0; i < num_outputs; ++i) { |
644 | c->set_output(i, c->UnknownShapeOfRank(rank)); |
645 | } |
646 | } else { |
647 | // Determine the output shape if split dimension and split sizes are |
648 | // known. |
649 | int64_t split_dim = c->Value(split_dimension); |
650 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); |
651 | std::vector<int64_t> data; |
652 | if (size_splits->dtype() == DT_INT32) { |
653 | data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0)); |
654 | } else { |
655 | data = |
656 | AsInt64<int64_t>(size_splits, size_splits->shape().dim_size(0)); |
657 | } |
658 | if (num_outputs != data.size()) { |
659 | return errors::InvalidArgument( |
660 | "Length of size_splits should be equal to num_outputs" ); |
661 | } |
662 | int64_t total_size = 0; |
663 | bool has_neg_one = false; |
664 | for (const auto size : data) { |
665 | if (size == -1) { |
666 | if (has_neg_one) { |
667 | return errors::InvalidArgument( |
668 | "size_splits can only have one -1" ); |
669 | } |
670 | has_neg_one = true; |
671 | } else { |
672 | total_size += size; |
673 | } |
674 | } |
675 | auto split_dim_size = c->Value(c->Dim(input, split_dim)); |
676 | // If the sizes of the splits are known, then |
677 | // make sure that the sizes add up to the expected |
678 | // dimension size, with the possibility of a -1. |
679 | // Specify the full output shapes. |
680 | for (int i = 0; i < num_outputs; ++i) { |
681 | auto size = data[i]; |
682 | if (data[i] == -1 && c->ValueKnown(split_dim_size)) { |
683 | size = split_dim_size - total_size; |
684 | } |
685 | // If we have a negative known size (either explicit, or computed |
686 | // via -1), then the split sizes are invalid. |
687 | if (size < -1 || (size == -1 && c->ValueKnown(split_dim_size))) { |
688 | return errors::InvalidArgument("Split size at index " , i, |
689 | " must be >= 0. Got: " , size); |
690 | } |
691 | TF_RETURN_IF_ERROR( |
692 | c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape)); |
693 | c->set_output(i, output_shape); |
694 | } |
695 | if (c->ValueKnown(split_dim_size)) { |
696 | if (has_neg_one ? total_size > split_dim_size |
697 | : total_size != split_dim_size) { |
698 | return errors::InvalidArgument( |
699 | "can't split axis of size " , split_dim_size, |
700 | " into pieces of size [" , absl::StrJoin(data, "," ), "]" ); |
701 | } |
702 | } |
703 | } |
704 | |
705 | return OkStatus(); |
706 | }); |
707 | |
708 | // -------------------------------------------------------------------------- |
709 | REGISTER_OP("Const" ) |
710 | .Output("output: dtype" ) |
711 | .Attr("value: tensor" ) |
712 | .Attr("dtype: type" ) |
713 | .SetShapeFn([](InferenceContext* c) { |
714 | const TensorProto* proto = nullptr; |
715 | TF_RETURN_IF_ERROR(c->GetAttr("value" , &proto)); |
716 | TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape())); |
717 | TensorShape shape(proto->tensor_shape()); |
718 | std::vector<DimensionHandle> dims; |
719 | dims.reserve(shape.dims()); |
720 | for (int i = 0; i < shape.dims(); ++i) { |
721 | dims.push_back(c->MakeDim(shape.dim_size(i))); |
722 | } |
723 | c->set_output(0, c->MakeShape(dims)); |
724 | return OkStatus(); |
725 | }); |
726 | |
727 | // Returns a constant tensor on the host. Useful for writing C++ tests |
728 | // and benchmarks which run on GPU but require arguments pinned to the host. |
729 | // Used by test::graph::HostConstant. |
730 | // value: Attr `value` is the tensor to return. |
731 | REGISTER_OP("HostConst" ) |
732 | .Output("output: dtype" ) |
733 | .Attr("value: tensor" ) |
734 | .Attr("dtype: type" ) |
735 | .SetShapeFn(shape_inference::UnknownShape); |
736 | |
// Used when executing op-by-op to copy constants to the current device without
// serializing tensors as TensorProtos, after a host tensor has been
// created. Same behavior as Identity, but no gradient and potentially relaxed
// copy semantics.
741 | REGISTER_OP("_EagerConst" ) |
742 | .Input("input: T" ) |
743 | .Output("output: T" ) |
744 | .Attr("T: type" ) |
745 | .SetShapeFn(shape_inference::UnchangedShape); |
746 | |
747 | // -------------------------------------------------------------------------- |
748 | // TODO(mgubin): Update the doc when the freeze_graph script supports converting |
749 | // into memmapped format. |
750 | REGISTER_OP("ImmutableConst" ) |
751 | .Attr("dtype: type" ) |
752 | .Attr("shape: shape" ) |
753 | .Attr("memory_region_name: string" ) |
754 | .Output("tensor: dtype" ) |
755 | .SetShapeFn(shape_inference::ExplicitShape); |
756 | |
757 | REGISTER_OP("GuaranteeConst" ) |
758 | .Input("input: T" ) |
759 | .Output("output: T" ) |
760 | .Attr("T: type" ) |
761 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
762 | return UnchangedShape(c); |
763 | }) |
764 | // We don't want this to be optimized away. |
765 | .SetDoNotOptimize(); |
766 | |
767 | // -------------------------------------------------------------------------- |
768 | REGISTER_OP("ZerosLike" ) |
769 | .Input("x: T" ) |
770 | .Output("y: T" ) |
771 | .Attr("T: type" ) |
772 | .SetShapeFn(shape_inference::UnchangedShape); |
773 | |
774 | // -------------------------------------------------------------------------- |
775 | REGISTER_OP("OnesLike" ) |
776 | .Input("x: T" ) |
777 | .Output("y: T" ) |
778 | .Attr( |
779 | "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, " |
780 | "uint32, int64, uint64, complex64, complex128, bool}" ) |
781 | .SetShapeFn(shape_inference::UnchangedShape); |
782 | |
783 | // -------------------------------------------------------------------------- |
784 | REGISTER_OP("Diag" ) |
785 | .Input("diagonal: T" ) |
786 | .Output("output: T" ) |
787 | .Attr( |
788 | "T: {bfloat16, half, float, double, int32, int64, complex64, " |
789 | "complex128}" ) |
790 | .SetShapeFn([](InferenceContext* c) { |
791 | ShapeHandle in = c->input(0); |
792 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in)); |
793 | // Output shape is original concatenated with itself. |
794 | ShapeHandle out; |
795 | TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out)); |
796 | c->set_output(0, out); |
797 | return OkStatus(); |
798 | }); |
799 | |
800 | // -------------------------------------------------------------------------- |
801 | REGISTER_OP("DiagPart" ) |
802 | .Input("input: T" ) |
803 | .Output("diagonal: T" ) |
804 | .Attr( |
805 | "T: {bfloat16, half, float, double, int32, int64, complex64, " |
806 | "complex128}" ) |
807 | .SetShapeFn([](InferenceContext* c) { |
808 | ShapeHandle in = c->input(0); |
809 | if (!c->RankKnown(in)) { |
810 | c->set_output(0, c->UnknownShape()); |
811 | return OkStatus(); |
812 | } |
813 | // Rank must be even, and result will have rank <rank/2>. |
814 | const int32_t rank = c->Rank(in); |
815 | if ((rank % 2) != 0 || rank <= 0) { |
816 | return errors::InvalidArgument( |
817 | "Input must have even and non-zero rank, input rank is " , rank); |
818 | } |
819 | const int32_t mid = rank / 2; |
820 | |
821 | // output dim[i] is the merge of in.dim[i] and in.dim[i+mid]. |
822 | std::vector<DimensionHandle> dims(mid); |
823 | for (int i = 0; i < mid; ++i) { |
824 | TF_RETURN_IF_ERROR( |
825 | c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i])); |
826 | } |
827 | c->set_output(0, c->MakeShape(dims)); |
828 | return OkStatus(); |
829 | }); |
830 | |
831 | // -------------------------------------------------------------------------- |
832 | REGISTER_OP("MatrixDiag" ) |
833 | .Input("diagonal: T" ) |
834 | .Output("output: T" ) |
835 | .Attr("T: type" ) |
836 | .SetShapeFn([](InferenceContext* c) { |
837 | ShapeHandle in; |
838 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); |
839 | if (!c->RankKnown(in)) { |
840 | c->set_output(0, c->UnknownShape()); |
841 | return OkStatus(); |
842 | } |
843 | const int32_t rank = c->Rank(in); |
844 | ShapeHandle out; |
845 | TF_RETURN_IF_ERROR( |
846 | c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out)); |
847 | c->set_output(0, out); |
848 | return OkStatus(); |
849 | }); |
850 | |
// MatrixDiagV2 and MatrixDiagV3 share shape_inference::MatrixDiagV2Shape.
// Their inputs, outputs and `T` attr are identical; V3 additionally declares
// the `align` attr (default 'RIGHT_LEFT'). The declaration order of the
// inputs/attrs below defines the op signature and must not change.
REGISTER_OP("MatrixDiagV2" )
    .Input("diagonal: T" )
    .Input("k: int32" )
    .Input("num_rows: int32" )
    .Input("num_cols: int32" )
    .Input("padding_value: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::MatrixDiagV2Shape);

// NOTE(review): how `align` affects the output is defined by the kernel, not
// by anything visible here — the shape function is the same as V2's.
REGISTER_OP("MatrixDiagV3" )
    .Input("diagonal: T" )
    .Input("k: int32" )
    .Input("num_rows: int32" )
    .Input("num_cols: int32" )
    .Input("padding_value: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'" )
    .SetShapeFn(shape_inference::MatrixDiagV2Shape);
873 | |
874 | // -------------------------------------------------------------------------- |
// Shape function: the output has `input`'s shape, refined where possible with
// the batch dimensions of `diagonal`.
REGISTER_OP("MatrixSetDiag" )
    .Input("input: T" )
    .Input("diagonal: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      ShapeHandle diag;
      // input must have rank >= 2 and diagonal rank >= 1.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
      // Once input's rank is known, diagonal must be exactly one rank lower.
      if (c->RankKnown(input)) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
      }
      // Validation only: diagonal's last dim must be compatible with the
      // smaller of input's last two dims. The merged value is not used below.
      DimensionHandle smallest_dim;
      TF_RETURN_IF_ERROR(
          c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
      TF_RETURN_IF_ERROR(
          c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));

      ShapeHandle output = input;
      if (c->RankKnown(diag) && !c->FullyDefined(input)) {
        // Try to infer parts of shape from diag.
        // Build diag.shape[:-1] + [?, ?] and merge it with input, so that
        // known batch dims of diagonal fill in unknown ones of input.
        ShapeHandle diag_batch_shape;
        TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_batch_shape));
        TF_RETURN_IF_ERROR(
            c->Concatenate(diag_batch_shape, c->UnknownShapeOfRank(2), &diag));
        TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
      }
      c->set_output(0, output);
      return OkStatus();
    });
906 | |
// MatrixSetDiagV2 and MatrixSetDiagV3 share
// shape_inference::MatrixSetDiagV2Shape. V3 differs only by declaring the
// `align` attr (default 'RIGHT_LEFT'). The input order below fixes the op
// signature (input, diagonal, k).
REGISTER_OP("MatrixSetDiagV2" )
    .Input("input: T" )
    .Input("diagonal: T" )
    .Input("k: int32" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);

REGISTER_OP("MatrixSetDiagV3" )
    .Input("input: T" )
    .Input("diagonal: T" )
    .Input("k: int32" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'" )
    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
925 | |
926 | // -------------------------------------------------------------------------- |
927 | REGISTER_OP("MatrixDiagPart" ) |
928 | .Input("input: T" ) |
929 | .Output("diagonal: T" ) |
930 | .Attr("T: type" ) |
931 | .SetShapeFn([](InferenceContext* c) { |
932 | ShapeHandle in; |
933 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in)); |
934 | if (!c->RankKnown(in)) { |
935 | c->set_output(0, c->UnknownShape()); |
936 | return OkStatus(); |
937 | } |
938 | const int32_t rank = c->Rank(in); |
939 | std::vector<DimensionHandle> dims; |
940 | dims.reserve(rank - 2); |
941 | for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i)); |
942 | |
943 | DimensionHandle min_dim; |
944 | TF_RETURN_IF_ERROR( |
945 | c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim)); |
946 | dims.push_back(min_dim); |
947 | c->set_output(0, c->MakeShape(dims)); |
948 | return OkStatus(); |
949 | }); |
950 | |
// MatrixDiagPartV2 and MatrixDiagPartV3 share
// shape_inference::MatrixDiagPartV2Shape. V3 differs only by declaring the
// `align` attr (default 'RIGHT_LEFT').
REGISTER_OP("MatrixDiagPartV2" )
    .Input("input: T" )
    .Input("k: int32" )
    .Input("padding_value: T" )
    .Output("diagonal: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);

REGISTER_OP("MatrixDiagPartV3" )
    .Input("input: T" )
    .Input("k: int32" )
    .Input("padding_value: T" )
    .Output("diagonal: T" )
    .Attr("T: type" )
    .Attr(
        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
        "'RIGHT_LEFT'" )
    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
969 | |
970 | // -------------------------------------------------------------------------- |
// The output `band` has exactly the shape of `input` (UnchangedShape).
// `num_lower`/`num_upper` share the index dtype `Tindex` (int32 or int64,
// default int64). Their values do not participate in shape inference.
REGISTER_OP("MatrixBandPart" )
    .Input("input: T" )
    .Input("num_lower: Tindex" )
    .Input("num_upper: Tindex" )
    .Output("band: T" )
    .Attr("T: type" )
    .Attr("Tindex: {int32, int64} = DT_INT64" )
    .SetShapeFn(shape_inference::UnchangedShape);
979 | |
980 | // -------------------------------------------------------------------------- |
981 | REGISTER_OP("Reverse" ) |
982 | .Input("tensor: T" ) |
983 | .Input("dims: bool" ) |
984 | .Output("output: T" ) |
985 | .Attr( |
986 | "T: {uint8, int8, uint16, int16, uint32, int32, uint64, int64, bool, " |
987 | "bfloat16, half, float, double, complex64, complex128, string}" ) |
988 | .SetShapeFn([](InferenceContext* c) { |
989 | ShapeHandle input = c->input(0); |
990 | ShapeHandle dims; |
991 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); |
992 | DimensionHandle dims_dim = c->Dim(dims, 0); |
993 | if (c->ValueKnown(dims_dim)) { |
994 | TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input)); |
995 | } |
996 | if (c->Rank(input) > 8) { |
997 | return errors::InvalidArgument( |
998 | "reverse does not work on tensors with more than 8 dimensions" ); |
999 | } |
1000 | c->set_output(0, input); |
1001 | return OkStatus(); |
1002 | }); |
1003 | |
1004 | // -------------------------------------------------------------------------- |
1005 | REGISTER_OP("ReverseV2" ) |
1006 | .Input("tensor: T" ) |
1007 | .Input("axis: Tidx" ) |
1008 | .Output("output: T" ) |
1009 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1010 | .Attr( |
1011 | "T: {uint8, int8, uint16, int16, int32, uint32, int64, uint64, bool, " |
1012 | "bfloat16, half, float, double, complex64, complex128, string}" ) |
1013 | .SetShapeFn([](InferenceContext* c) { |
1014 | ShapeHandle input = c->input(0); |
1015 | ShapeHandle axis; |
1016 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis)); |
1017 | if (c->Rank(input) > 8) { |
1018 | return errors::InvalidArgument( |
1019 | "reverse does not work on tensors with more than 8 dimensions" ); |
1020 | } |
1021 | const Tensor* axis_tensor = c->input_tensor(1); |
1022 | if (axis_tensor != nullptr && c->RankKnown(input)) { |
1023 | int32_t rank = c->Rank(input); |
1024 | std::vector<int64_t> axis_value; |
1025 | if (axis_tensor->dtype() == DT_INT32) { |
1026 | axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements()); |
1027 | } else { |
1028 | axis_value = |
1029 | AsInt64<int64_t>(axis_tensor, axis_tensor->NumElements()); |
1030 | } |
1031 | std::vector<bool> axes_dense(c->Rank(input), false); |
1032 | for (int i = 0; i < axis_value.size(); i++) { |
1033 | int64_t canonical_axis = |
1034 | axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i]; |
1035 | if (canonical_axis < 0 || canonical_axis >= rank) { |
1036 | return errors::InvalidArgument("'axis'[" , i, "] = " , axis_value[i], |
1037 | " is out of valid range [" , 0, ", " , |
1038 | rank - 1); |
1039 | } |
1040 | if (axes_dense[canonical_axis]) { |
1041 | return errors::InvalidArgument("axis " , canonical_axis, |
1042 | " specified more than once." ); |
1043 | } |
1044 | axes_dense[canonical_axis] = true; |
1045 | } |
1046 | } |
1047 | c->set_output(0, input); |
1048 | return OkStatus(); |
1049 | }); |
1050 | |
1051 | // -------------------------------------------------------------------------- |
1052 | REGISTER_OP("EditDistance" ) |
1053 | .Input("hypothesis_indices: int64" ) |
1054 | .Input("hypothesis_values: T" ) |
1055 | .Input("hypothesis_shape: int64" ) |
1056 | .Input("truth_indices: int64" ) |
1057 | .Input("truth_values: T" ) |
1058 | .Input("truth_shape: int64" ) |
1059 | .Attr("normalize: bool = true" ) |
1060 | .Attr("T: type" ) |
1061 | .Output("output: float" ) |
1062 | .SetShapeFn([](InferenceContext* c) { |
1063 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
1064 | c, c->input(0), c->input(1), c->input(2))); |
1065 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
1066 | c, c->input(3), c->input(4), c->input(5))); |
1067 | const Tensor* hypothesis_shape_t = c->input_tensor(2); |
1068 | const Tensor* truth_shape_t = c->input_tensor(5); |
1069 | if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) { |
1070 | // We need to know the runtime shape of the two tensors, |
1071 | // or else the output shape is unknown. |
1072 | return shape_inference::UnknownShape(c); |
1073 | } |
1074 | |
1075 | if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) { |
1076 | return errors::InvalidArgument( |
1077 | "Num elements of hypothesis_shape does not match truth_shape: " , |
1078 | hypothesis_shape_t->NumElements(), " vs. " , |
1079 | truth_shape_t->NumElements()); |
1080 | } |
1081 | |
1082 | auto h_values = hypothesis_shape_t->flat<int64_t>(); |
1083 | auto t_values = truth_shape_t->flat<int64_t>(); |
1084 | std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1); |
1085 | for (int i = 0; i < dims.size(); ++i) { |
1086 | dims[i] = c->MakeDim(std::max(h_values(i), t_values(i))); |
1087 | } |
1088 | |
1089 | c->set_output(0, c->MakeShape(dims)); |
1090 | return OkStatus(); |
1091 | }); |
1092 | |
1093 | // -------------------------------------------------------------------------- |
1094 | REGISTER_OP("Fill" ) |
1095 | .Input("dims: index_type" ) |
1096 | .Input("value: T" ) |
1097 | .Output("output: T" ) |
1098 | .Attr("T: type" ) |
1099 | .Attr("index_type: {int32, int64} = DT_INT32" ) |
1100 | .SetShapeFn([](InferenceContext* c) { |
1101 | DataType index_type = DT_INT32; |
1102 | Status s = c->GetAttr("index_type" , &index_type); |
1103 | if (!s.ok() && s.code() != error::NOT_FOUND) { |
1104 | return s; |
1105 | } |
1106 | ShapeHandle unused; |
1107 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
1108 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1109 | |
1110 | const Tensor* t = c->input_tensor(0); |
1111 | if (t != nullptr) { |
1112 | for (int i = 0; i < t->NumElements(); ++i) { |
1113 | if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) || |
1114 | (index_type == DT_INT64 && t->vec<int64_t>()(i) < 0)) { |
1115 | return errors::InvalidArgument("Fill dimensions must be >= 0" ); |
1116 | } |
1117 | } |
1118 | } |
1119 | |
1120 | ShapeHandle out; |
1121 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
1122 | c->set_output(0, out); |
1123 | |
1124 | auto* shape_and_type = c->input_handle_shapes_and_types(1); |
1125 | if (shape_and_type) { |
1126 | c->set_output_handle_shapes_and_types(0, *shape_and_type); |
1127 | } |
1128 | |
1129 | return OkStatus(); |
1130 | }); |
1131 | |
1132 | // -------------------------------------------------------------------------- |
// Internal op (leading underscore). Stateful. Its single output's shape comes
// from the `shape` attr via shape_inference::ExplicitShape.
// NOTE(review): the Doc text calls `shape` a "1-D `Tensor`", but it is
// declared as a `shape`-typed attr, not an input.
REGISTER_OP("_ParallelConcatStart" )
    .Output("output: dtype" )
    .Attr("shape: shape" )
    .Attr("dtype: type" )
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape)
    .Doc(R"doc(
Creates an empty Tensor with shape `shape` and type `dtype`.

The memory can optionally be initialized. This is usually useful in
conjunction with inplace operations.

shape: 1-D `Tensor` indicating the shape of the output.
dtype: The element type of the returned tensor.
output: An empty Tensor of the specified type.
)doc" );

// --------------------------------------------------------------------------
// Internal op. The output has `value`'s shape (UnchangedShape on input 0).
// NOTE(review): the Doc text describes `loc` as possibly non-scalar, but it is
// declared as an `int` attr, so it is always a scalar.
REGISTER_OP("_ParallelConcatUpdate" )
    .Input("value: T" )
    .Input("update: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("loc: int" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Updates input `value` at `loc` with `update`.

If you use this function you will almost certainly want to add
a control dependency as done in the implementation of parallel_stack to
avoid race conditions.

value: A `Tensor` object that will be updated in-place.
loc: A scalar indicating the index of the first dimension such that
         value[loc, :] is updated.
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
        otherwise of rank equal to `value` that contains the new values
        for `value`.
output: `value` that has been updated accordingly.
)doc" );
1173 | |
1174 | // -------------------------------------------------------------------------- |
1175 | REGISTER_OP("Gather" ) |
1176 | .Input("params: Tparams" ) |
1177 | .Input("indices: Tindices" ) |
1178 | .Attr("validate_indices: bool = true" ) |
1179 | .Output("output: Tparams" ) |
1180 | .Attr("Tparams: type" ) |
1181 | .Attr("Tindices: {int32,int64}" ) |
1182 | .SetShapeFn([](InferenceContext* c) { |
1183 | ShapeHandle unused; |
1184 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); |
1185 | ShapeHandle params_subshape; |
1186 | TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, ¶ms_subshape)); |
1187 | ShapeHandle indices_shape = c->input(1); |
1188 | ShapeHandle out; |
1189 | TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out)); |
1190 | c->set_output(0, out); |
1191 | return OkStatus(); |
1192 | }); |
1193 | |
1194 | // -------------------------------------------------------------------------- |
1195 | REGISTER_OP("GatherV2" ) |
1196 | .Input("params: Tparams" ) |
1197 | .Input("indices: Tindices" ) |
1198 | .Input("axis: Taxis" ) |
1199 | .Attr("batch_dims: int = 0" ) |
1200 | .Output("output: Tparams" ) |
1201 | .Attr("Tparams: type" ) |
1202 | .Attr("Tindices: {int16, int32,int64}" ) |
1203 | .Attr("Taxis: {int32,int64}" ) |
1204 | .SetShapeFn([](InferenceContext* c) { |
1205 | ShapeHandle params_shape; |
1206 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, ¶ms_shape)); |
1207 | |
1208 | ShapeHandle indices_shape = c->input(1); |
1209 | ShapeHandle unused_axis_shape; |
1210 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape)); |
1211 | const Tensor* axis_t = c->input_tensor(2); |
1212 | |
1213 | // If axis is unknown, we can only infer that the result is params_rank + |
1214 | // indices_rank - 1. |
1215 | if (axis_t == nullptr) { |
1216 | if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) { |
1217 | int32_t batch_dims; |
1218 | TF_RETURN_IF_ERROR(c->GetAttr("batch_dims" , &batch_dims)); |
1219 | c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) + |
1220 | c->Rank(indices_shape) - 1 - |
1221 | batch_dims)); |
1222 | } else { |
1223 | c->set_output(0, c->UnknownShape()); |
1224 | } |
1225 | return OkStatus(); |
1226 | } |
1227 | |
1228 | // Note, axis can be negative. |
1229 | int64_t axis = 0; |
1230 | if (axis_t->dtype() == DT_INT32) { |
1231 | axis = axis_t->scalar<int32>()(); |
1232 | } else { |
1233 | axis = axis_t->scalar<int64_t>()(); |
1234 | } |
1235 | |
1236 | // Check that params has rank of at least axis + 1. |
1237 | ShapeHandle unused; |
1238 | TF_RETURN_IF_ERROR(c->WithRankAtLeast( |
1239 | params_shape, axis < 0 ? -axis : axis + 1, &unused)); |
1240 | |
1241 | // Note, batch_dims can be negative. |
1242 | int32_t batch_dims; |
1243 | TF_RETURN_IF_ERROR(c->GetAttr("batch_dims" , &batch_dims)); |
1244 | // -rank(indices) <= batch_dims <= rank(indices) |
1245 | TF_RETURN_IF_ERROR( |
1246 | c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused)); |
1247 | if (batch_dims < 0) { |
1248 | batch_dims += c->Rank(indices_shape); |
1249 | } |
1250 | // rank(params) > batch_dims |
1251 | TF_RETURN_IF_ERROR( |
1252 | c->WithRankAtLeast(params_shape, batch_dims + 1, &unused)); |
1253 | |
1254 | ShapeHandle params_outer_subshape; |
1255 | TF_RETURN_IF_ERROR( |
1256 | c->Subshape(params_shape, 0, axis, ¶ms_outer_subshape)); |
1257 | |
1258 | ShapeHandle indices_inner_subshape; |
1259 | TF_RETURN_IF_ERROR( |
1260 | c->Subshape(indices_shape, batch_dims, &indices_inner_subshape)); |
1261 | |
1262 | ShapeHandle out; |
1263 | TF_RETURN_IF_ERROR( |
1264 | c->Concatenate(params_outer_subshape, indices_inner_subshape, &out)); |
1265 | |
1266 | // Slice from axis + 1 to the end of params_shape to collect the inner |
1267 | // dimensions of the result. Special case -1 here since -1 + 1 wraps, and |
1268 | // we slice from 0 to the end of shape. Subshape() handles all other |
1269 | // out-of-bounds checking. |
1270 | if (axis != -1) { |
1271 | ShapeHandle params_inner_subshape; |
1272 | TF_RETURN_IF_ERROR( |
1273 | c->Subshape(params_shape, axis + 1, ¶ms_inner_subshape)); |
1274 | TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out)); |
1275 | } |
1276 | |
1277 | c->set_output(0, out); |
1278 | return OkStatus(); |
1279 | }); |
1280 | |
1281 | // -------------------------------------------------------------------------- |
// Shape inference is delegated to shape_inference::GatherNdShape. Note that
// Tindices here also admits int16, unlike Gather.
REGISTER_OP("GatherNd" )
    .Input("params: Tparams" )
    .Input("indices: Tindices" )
    .Output("output: Tparams" )
    .Attr("Tparams: type" )
    .Attr("Tindices: {int16, int32,int64}" )
    .SetShapeFn(shape_inference::GatherNdShape);
1289 | |
1290 | // -------------------------------------------------------------------------- |
// Identity: output shape equals input shape. The forward type function
// (full_type::ReplicateInput) presumably copies the input's full type onto
// the output — TODO confirm against full_type docs.
REGISTER_OP("Identity" )
    .Input("input: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetForwardTypeFn(full_type::ReplicateInput())
    .SetShapeFn(shape_inference::UnchangedShape);

// Snapshot: same signature and shape function as Identity, but with no
// forward type function registered.
REGISTER_OP("Snapshot" )
    .Input("input: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape);
1303 | |
#ifdef INTEL_MKL
// MKL variant of Identity, only compiled with INTEL_MKL. The extra uint8
// input/output carry MKL metadata. Shape inference is UnchangedShape, as for
// Identity.
// NOTE(review): UnchangedShape sets output 0 only; whether `mkl_output`
// needs a shape is not visible here.
REGISTER_OP("_MklIdentity" )
    .Input("input: T" )
    .Input("mkl_input: uint8" )
    .Output("output: T" )
    .Output("mkl_output: uint8" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"Doc( Mkl implementation of IdentityOp
)Doc" );
#endif
1315 | |
1316 | REGISTER_OP("IdentityN" ) |
1317 | .Input("input: T" ) |
1318 | .Output("output: T" ) |
1319 | .Attr("T: list(type)" ) |
1320 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1321 | std::vector<ShapeHandle> input; |
1322 | TF_RETURN_IF_ERROR(c->input("input" , &input)); |
1323 | TF_RETURN_IF_ERROR(c->set_output("output" , input)); |
1324 | // If any of the input shapes are not known, we should return error. |
1325 | for (int i = 0; i < input.size(); i++) { |
1326 | if (!input[i].Handle()) { |
1327 | return errors::InvalidArgument(absl::StrCat( |
1328 | "Cannot infer output shape #" , i, |
1329 | " for IdentityN node because input shape #" , i, " is unknown." )); |
1330 | } |
1331 | } |
1332 | return OkStatus(); |
1333 | }); |
1334 | |
1335 | // -------------------------------------------------------------------------- |
// Identity over a reference input. Output shape equals input shape; the input
// may be uninitialized.
REGISTER_OP("RefIdentity" )
    .Input("input: Ref(T)" )
    .Output("output: Ref(T)" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();

// --------------------------------------------------------------------------
// Debug identity ops (value and reference forms). Output shape equals input
// shape, and both accept uninitialized inputs.
REGISTER_OP("DebugGradientIdentity" )
    .Input("input: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();

REGISTER_OP("DebugGradientRefIdentity" )
    .Input("input: Ref(T)" )
    .Output("output: Ref(T)" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .SetAllowsUninitializedInput();
1357 | |
1358 | // -------------------------------------------------------------------------- |
// Shape-preserving pass-through ops. Their gradient behavior is defined
// elsewhere (not visible here). PreventGradient additionally carries a
// `message` string attr, default empty.
REGISTER_OP("StopGradient" )
    .Input("input: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("PreventGradient" )
    .Input("input: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("message: string = ''" )
    .SetShapeFn(shape_inference::UnchangedShape);
1371 | |
1372 | // -------------------------------------------------------------------------- |
// Floating-point-only ops. They are marked stateful (SetIsStateful) and keep
// the input's shape. V1 and V2 have identical registrations here; their
// difference lies in the kernels (not visible here).
REGISTER_OP("CheckNumerics" )
    .Input("tensor: T" )
    .Output("output: T" )
    .Attr("T: {bfloat16, half, float, double}" )
    .Attr("message: string" )
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
REGISTER_OP("CheckNumericsV2" )
    .Input("tensor: T" )
    .Output("output: T" )
    .Attr("T: {bfloat16, half, float, double}" )
    .Attr("message: string" )
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);
1389 | |
1390 | // -------------------------------------------------------------------------- |
// Shape inference is delegated to SetOutputShapeForReshape (defined elsewhere
// in this file, not visible here). `shape` may be int32 (default) or int64.
REGISTER_OP("Reshape" )
    .Input("tensor: T" )
    .Input("shape: Tshape" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tshape: {int32, int64} = DT_INT32" )
    .SetShapeFn([](InferenceContext* c) {
      return SetOutputShapeForReshape(c);
    });
1400 | |
#ifdef INTEL_MKL
// MKL variant of Reshape (INTEL_MKL builds only). Same shape function as
// Reshape. The uint8 inputs/outputs carry MKL metadata.
REGISTER_OP("_MklReshape" )
    .Input("tensor: T" )
    .Input("shape: Tshape" )
    .Input("mkl_tensor: uint8" )
    .Input("mkl_shape: uint8" )
    .Output("output: T" )
    .Output("mkl_output: uint8" )
    .Attr("T: type" )
    .Attr("Tshape: {int32, int64} = DT_INT32" )
    .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
    .Doc(R"Doc( MKL implementation of ReshapeOp.
)Doc" );
#endif  // INTEL_MKL
1415 | |
1416 | // -------------------------------------------------------------------------- |
1417 | REGISTER_OP("InvertPermutation" ) |
1418 | .Input("x: T" ) |
1419 | .Output("y: T" ) |
1420 | .Attr("T: {int32, int64} = DT_INT32" ) |
1421 | .SetShapeFn([](InferenceContext* c) { |
1422 | ShapeHandle x; |
1423 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x)); |
1424 | c->set_output(0, x); |
1425 | return OkStatus(); |
1426 | }); |
1427 | |
1428 | // -------------------------------------------------------------------------- |
// Transpose and ConjugateTranspose (plus their MKL variants) share one
// signature and TransposeShapeFn (defined elsewhere in this file, not visible
// here). `perm` may be int32 (default) or int64.
REGISTER_OP("Transpose" )
    .Input("x: T" )
    .Input("perm: Tperm" )
    .Output("y: T" )
    .Attr("T: type" )
    .Attr("Tperm: {int32, int64} = DT_INT32" )
    .SetShapeFn(TransposeShapeFn);

#ifdef INTEL_MKL
REGISTER_OP("_MklTranspose" )
    .Input("x: T" )
    .Input("perm: Tperm" )
    .Output("y: T" )
    .Attr("T: type" )
    .Attr("Tperm: {int32, int64} = DT_INT32" )
    .SetShapeFn(TransposeShapeFn);
#endif  // INTEL_MKL

// --------------------------------------------------------------------------
REGISTER_OP("ConjugateTranspose" )
    .Input("x: T" )
    .Input("perm: Tperm" )
    .Output("y: T" )
    .Attr("T: type" )
    .Attr("Tperm: {int32, int64} = DT_INT32" )
    .SetShapeFn(TransposeShapeFn);

#ifdef INTEL_MKL
REGISTER_OP("_MklConjugateTranspose" )
    .Input("x: T" )
    .Input("perm: Tperm" )
    .Output("y: T" )
    .Attr("T: type" )
    .Attr("Tperm: {int32, int64} = DT_INT32" )
    .SetShapeFn(TransposeShapeFn);
#endif  // INTEL_MKL
1465 | |
1466 | // -------------------------------------------------------------------------- |
1467 | namespace { |
1468 | Status UniqueIdxShapeFn(InferenceContext* c) { |
1469 | ShapeHandle input = c->input(0); |
1470 | const Tensor* axis_t = c->input_tensor(1); |
1471 | if (axis_t == nullptr || !c->RankKnown(input)) { |
1472 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
1473 | return OkStatus(); |
1474 | } |
1475 | |
1476 | if (c->Rank(c->input(1)) != 1) { |
1477 | return errors::InvalidArgument("axis expects a 1D vector." ); |
1478 | } |
1479 | |
1480 | int32_t n = axis_t->NumElements(); |
1481 | if (n == 0) { |
1482 | if (c->Rank(input) != 1) { |
1483 | return errors::InvalidArgument("x expects a 1D vector." ); |
1484 | } |
1485 | c->set_output(1, input); |
1486 | return OkStatus(); |
1487 | } else if (n == 1) { |
1488 | int64_t axis; |
1489 | if (axis_t->dtype() == DT_INT32) { |
1490 | axis = static_cast<int64_t>(axis_t->flat<int32>()(0)); |
1491 | } else { |
1492 | axis = axis_t->flat<int64_t>()(0); |
1493 | } |
1494 | |
1495 | int64_t input_rank = c->Rank(input); |
1496 | if (axis < -input_rank || axis >= input_rank) { |
1497 | return errors::InvalidArgument("axis expects to be in the range [" , |
1498 | -input_rank, ", " , input_rank, ")" ); |
1499 | } |
1500 | if (axis < 0) { |
1501 | axis += input_rank; |
1502 | } |
1503 | c->set_output(1, c->Vector(c->Dim(input, axis))); |
1504 | return OkStatus(); |
1505 | } |
1506 | return errors::InvalidArgument( |
1507 | "axis does not support input tensors larger than 1 elements." ); |
1508 | } |
1509 | } // namespace |
1510 | |
1511 | REGISTER_OP("Unique" ) |
1512 | .Input("x: T" ) |
1513 | .Output("y: T" ) |
1514 | .Output("idx: out_idx" ) |
1515 | .Attr("T: type" ) |
1516 | .Attr("out_idx: {int32, int64} = DT_INT32" ) |
1517 | .SetShapeFn([](InferenceContext* c) { |
1518 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
1519 | c->set_output(1, c->input(0)); |
1520 | // Assert that the input rank is 1. |
1521 | ShapeHandle dummy; |
1522 | return c->WithRank(c->input(0), 1, &dummy); |
1523 | }); |
1524 | |
1525 | REGISTER_OP("UniqueV2" ) |
1526 | .Input("x: T" ) |
1527 | .Input("axis: Taxis" ) |
1528 | .Output("y: T" ) |
1529 | .Output("idx: out_idx" ) |
1530 | .Attr("T: type" ) |
1531 | .Attr("Taxis: {int32,int64} = DT_INT64" ) |
1532 | .Attr("out_idx: {int32, int64} = DT_INT32" ) |
1533 | .SetShapeFn([](InferenceContext* c) { |
1534 | c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); |
1535 | TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c)); |
1536 | return OkStatus(); |
1537 | }); |
1538 | |
1539 | // -------------------------------------------------------------------------- |
1540 | REGISTER_OP("UniqueWithCounts" ) |
1541 | .Input("x: T" ) |
1542 | .Output("y: T" ) |
1543 | .Output("idx: out_idx" ) |
1544 | .Output("count: out_idx" ) |
1545 | .Attr("T: type" ) |
1546 | .Attr("out_idx: {int32, int64} = DT_INT32" ) |
1547 | .SetShapeFn([](InferenceContext* c) { |
1548 | auto uniq = c->Vector(InferenceContext::kUnknownDim); |
1549 | c->set_output(0, uniq); |
1550 | c->set_output(1, c->input(0)); |
1551 | c->set_output(2, uniq); |
1552 | return OkStatus(); |
1553 | }); |
1554 | |
1555 | REGISTER_OP("UniqueWithCountsV2" ) |
1556 | .Input("x: T" ) |
1557 | .Input("axis: Taxis" ) |
1558 | .Output("y: T" ) |
1559 | .Output("idx: out_idx" ) |
1560 | .Output("count: out_idx" ) |
1561 | .Attr("T: type" ) |
1562 | .Attr("Taxis: {int32,int64} = DT_INT64" ) |
1563 | .Attr("out_idx: {int32, int64} = DT_INT32" ) |
1564 | .SetShapeFn([](InferenceContext* c) { |
1565 | c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); |
1566 | TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c)); |
1567 | c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); |
1568 | return OkStatus(); |
1569 | }); |
1570 | |
1571 | namespace { |
1572 | |
1573 | Status ShapeShapeFn(InferenceContext* c) { |
1574 | for (int i = 0; i < c->num_inputs(); ++i) { |
1575 | DimensionHandle dim; |
1576 | if (c->RankKnown(c->input(i))) { |
1577 | dim = c->MakeDim(c->Rank(c->input(i))); |
1578 | } else { |
1579 | dim = c->UnknownDim(); |
1580 | } |
1581 | c->set_output(i, c->Vector(dim)); |
1582 | } |
1583 | return OkStatus(); |
1584 | } |
1585 | |
1586 | } // namespace |
1587 | |
1588 | // -------------------------------------------------------------------------- |
// Shape and ShapeN both use ShapeShapeFn: each output is a vector whose length
// is the corresponding input's rank. `out_type` selects int32 (default) or
// int64 elements.
REGISTER_OP("Shape" )
    .Input("input: T" )
    .Output("output: out_type" )
    .Attr("T: type" )
    .Attr("out_type: {int32, int64} = DT_INT32" )
    .SetShapeFn(ShapeShapeFn);

// N inputs of a common type T produce N shape vectors.
REGISTER_OP("ShapeN" )
    .Input("input: N * T" )
    .Output("output: N * out_type" )
    .Attr("N: int" )
    .Attr("T: type" )
    .Attr("out_type: {int32, int64} = DT_INT32" )
    .SetShapeFn(ShapeShapeFn);
1603 | |
1604 | REGISTER_OP("EnsureShape" ) |
1605 | .Input("input: T" ) |
1606 | .Output("output: T" ) |
1607 | .Attr("shape: shape" ) |
1608 | .Attr("T: type" ) |
1609 | .SetShapeFn([](InferenceContext* c) { |
1610 | // Merges desired shape and statically known shape of input |
1611 | PartialTensorShape desired_shape; |
1612 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &desired_shape)); |
1613 | |
1614 | int rank = desired_shape.dims(); |
1615 | ShapeHandle input_shape_handle; |
1616 | ShapeHandle desired_shape_handle; |
1617 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle)); |
1618 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
1619 | desired_shape, &desired_shape_handle)); |
1620 | |
1621 | ShapeHandle merged_shape; |
1622 | TF_RETURN_IF_ERROR( |
1623 | c->Merge(desired_shape_handle, input_shape_handle, &merged_shape)); |
1624 | c->set_output(0, merged_shape); |
1625 | return OkStatus(); |
1626 | }); |
1627 | |
1628 | // -------------------------------------------------------------------------- |
1629 | REGISTER_OP("ReverseSequence" ) |
1630 | .Input("input: T" ) |
1631 | .Input("seq_lengths: Tlen" ) |
1632 | .Output("output: T" ) |
1633 | .Attr("seq_dim: int" ) |
1634 | .Attr("batch_dim: int = 0" ) |
1635 | .Attr("T: type" ) |
1636 | .Attr("Tlen: {int32, int64} = DT_INT64" ) |
1637 | .SetShapeFn([](InferenceContext* c) { |
1638 | ShapeHandle input = c->input(0); |
1639 | ShapeHandle seq_lens_shape; |
1640 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape)); |
1641 | |
1642 | int64_t seq_dim; |
1643 | TF_RETURN_IF_ERROR(c->GetAttr("seq_dim" , &seq_dim)); |
1644 | int64_t batch_dim; |
1645 | TF_RETURN_IF_ERROR(c->GetAttr("batch_dim" , &batch_dim)); |
1646 | |
1647 | if (!c->RankKnown(input)) { |
1648 | return shape_inference::UnknownShape(c); |
1649 | } |
1650 | |
1651 | // Validate batch_dim and seq_dim against input. |
1652 | const int32_t input_rank = c->Rank(input); |
1653 | if (batch_dim >= input_rank) { |
1654 | return errors::InvalidArgument( |
1655 | "batch_dim must be < input rank: " , batch_dim, " vs. " , input_rank); |
1656 | } |
1657 | |
1658 | if (seq_dim >= input_rank) { |
1659 | return errors::InvalidArgument( |
1660 | "seq_dim must be < input rank: " , seq_dim, " vs. " , input_rank); |
1661 | } |
1662 | |
1663 | // To prevent out of bound access when calling c->Dim(input, batch_dim), |
1664 | // batch_dim range [-1 * input rank, input rank) is allowed. However, |
1665 | // the op implementation has a stricter bound for batch_dim requiring >= 0 |
1666 | // value. Thus, perform strict check here. |
1667 | if (batch_dim < 0) { |
1668 | return errors::InvalidArgument("batch_dim must be >=0, got " , |
1669 | batch_dim); |
1670 | } |
1671 | |
1672 | DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); |
1673 | TF_RETURN_IF_ERROR( |
1674 | c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim)); |
1675 | |
1676 | // Replace batch_dim of input with batch_size |
1677 | ShapeHandle output_shape; |
1678 | TF_RETURN_IF_ERROR( |
1679 | c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape)); |
1680 | c->set_output(0, output_shape); |
1681 | return OkStatus(); |
1682 | }); |
1683 | |
1684 | // -------------------------------------------------------------------------- |
// Rank: takes a tensor of any type T and produces an int32 output. Shape
// inference is delegated to shape_inference::ScalarShape, so the output is
// reported as a scalar independent of the input's shape.
REGISTER_OP("Rank")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);
1690 | |
1691 | // -------------------------------------------------------------------------- |
// Size: takes a tensor of any type T and produces a scalar whose dtype is
// chosen by `out_type` (int32 by default, or int64). Shape inference uses
// shape_inference::ScalarShape.
REGISTER_OP("Size")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ScalarShape);
1698 | |
1699 | // -------------------------------------------------------------------------- |
// Slice: extracts a slice of `input` described by the `begin` and `size`
// inputs, both of dtype Index (int32 or int64; no default). The output shape
// is computed by shape_inference::SliceShape.
REGISTER_OP("Slice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
1708 | |
#ifdef INTEL_MKL
// _MklSlice: MKL variant of Slice, registered only when building with
// INTEL_MKL. It has Slice's three inputs plus one uint8 `mkl_*` companion
// input per data input, and one uint8 companion output. The companion
// tensors presumably carry MKL layout metadata (NOTE(review): confirm
// against the MKL kernels). Shape inference reuses
// shape_inference::SliceShape.
REGISTER_OP("_MklSlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Input("mkl_input: uint8")
    .Input("mkl_begin: uint8")
    .Input("mkl_size: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
#endif
1723 | |
1724 | REGISTER_OP("StridedSlice" ) |
1725 | .Input("input: T" ) |
1726 | .Input("begin: Index" ) |
1727 | .Input("end: Index" ) |
1728 | .Input("strides: Index" ) |
1729 | .Output("output: T" ) |
1730 | .Attr("T: type" ) |
1731 | .Attr("Index: {int16, int32, int64}" ) |
1732 | .Attr("begin_mask: int = 0" ) |
1733 | .Attr("end_mask: int = 0" ) |
1734 | .Attr("ellipsis_mask: int = 0" ) |
1735 | .Attr("new_axis_mask: int = 0" ) |
1736 | .Attr("shrink_axis_mask: int = 0" ) |
1737 | .SetShapeFn([](InferenceContext* c) { |
1738 | ShapeHandle input = c->input(0); |
1739 | ShapeHandle begin_shape, end_shape, strides_shape; |
1740 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); |
1741 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape)); |
1742 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape)); |
1743 | TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape)); |
1744 | TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape)); |
1745 | DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0); |
1746 | |
1747 | const Tensor* strides_value = c->input_tensor(3); |
1748 | // TODO(aselle,allenl): If we had a stride_mask it would be possible to do |
1749 | // more shape inference here (e.g. for x[3, ::T]). |
1750 | if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) || |
1751 | strides_value == nullptr) { |
1752 | c->set_output(0, c->UnknownShape()); |
1753 | return OkStatus(); |
1754 | } |
1755 | |
1756 | PartialTensorShape input_shape({}); |
1757 | for (int i = 0; i < c->Rank(input); ++i) { |
1758 | auto dim = c->Dim(input, i); |
1759 | input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1); |
1760 | } |
1761 | |
1762 | int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, |
1763 | shrink_axis_mask; |
1764 | TF_RETURN_IF_ERROR(c->GetAttr("begin_mask" , &begin_mask)); |
1765 | TF_RETURN_IF_ERROR(c->GetAttr("end_mask" , &end_mask)); |
1766 | TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask" , &ellipsis_mask)); |
1767 | TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask" , &new_axis_mask)); |
1768 | TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask" , &shrink_axis_mask)); |
1769 | |
1770 | const Tensor* begin_value = c->input_tensor(1); |
1771 | const Tensor* end_value = c->input_tensor(2); |
1772 | |
1773 | PartialTensorShape processing_shape, final_shape; |
1774 | bool is_identity, is_simple_slice, slice_dim0; |
1775 | gtl::InlinedVector<int64, 4> begin, end, strides; |
1776 | TF_RETURN_IF_ERROR(ValidateStridedSliceOp( |
1777 | begin_value, end_value, *strides_value, input_shape, begin_mask, |
1778 | end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, |
1779 | &processing_shape, &final_shape, &is_identity, &is_simple_slice, |
1780 | &slice_dim0, &begin, &end, &strides)); |
1781 | |
1782 | ShapeHandle out; |
1783 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out)); |
1784 | c->set_output(0, out); |
1785 | |
1786 | auto* shape_and_type = c->input_handle_shapes_and_types(0); |
1787 | if (shape_and_type) { |
1788 | c->set_output_handle_shapes_and_types(0, *shape_and_type); |
1789 | } |
1790 | |
1791 | return OkStatus(); |
1792 | }); |
1793 | |
1794 | REGISTER_OP("StridedSliceGrad" ) |
1795 | .Input("shape: Index" ) |
1796 | .Input("begin: Index" ) |
1797 | .Input("end: Index" ) |
1798 | .Input("strides: Index" ) |
1799 | .Input("dy: T" ) |
1800 | .Output("output: T" ) |
1801 | .Attr("T: type" ) |
1802 | .Attr("Index: {int32, int64}" ) |
1803 | .Attr("begin_mask: int = 0" ) |
1804 | .Attr("end_mask: int = 0" ) |
1805 | .Attr("ellipsis_mask: int = 0" ) |
1806 | .Attr("new_axis_mask: int = 0" ) |
1807 | .Attr("shrink_axis_mask: int = 0" ) |
1808 | .SetShapeFn([](InferenceContext* c) { |
1809 | ShapeHandle out; |
1810 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
1811 | c->set_output(0, out); |
1812 | return OkStatus(); |
1813 | }); |
1814 | |
// StridedSliceAssign: assigns `value` into a strided slice of the reference
// `ref` (selected by begin/end/strides and the mask attrs), and returns the
// reference. Shape inference uses shape_inference::UnchangedShape.
REGISTER_OP("StridedSliceAssign")
    .Input("ref: Ref(T)")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
// TODO(aselle): Fix this documentation once StridedSliceAssign supports
// broadcasting.
1832 | // -------------------------------------------------------------------------- |
1833 | |
// ResourceStridedSliceAssign: like StridedSliceAssign, but the target is
// given as a `resource` handle instead of a ref. The op declares no outputs,
// which shape_inference::NoOutputs reflects.
REGISTER_OP("ResourceStridedSliceAssign")
    .Input("ref: resource")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::NoOutputs);
1848 | |
// TensorStridedSliceUpdate: value-semantics counterpart of StridedSliceAssign.
// It takes a plain `input` tensor and produces a new `output` tensor. The
// output shape follows shape_inference::UnchangedShape. Unlike StridedSlice,
// Index is limited to int32/int64 here (no int16).
REGISTER_OP("TensorStridedSliceUpdate")
    .Input("input: T")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
1864 | |
1865 | REGISTER_OP("Tile" ) |
1866 | .Input("input: T" ) |
1867 | .Input("multiples: Tmultiples" ) |
1868 | .Output("output: T" ) |
1869 | .Attr("T: type" ) |
1870 | .Attr("Tmultiples: {int32, int64} = DT_INT32" ) |
1871 | .SetShapeFn([](InferenceContext* c) { |
1872 | ShapeHandle input = c->input(0); |
1873 | // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i) |
1874 | // it is a vector of non-negative integers, and (ii) doing so allows |
1875 | // us to handle partially-known multiples. |
1876 | ShapeHandle multiples; |
1877 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples)); |
1878 | if (c->RankKnown(input)) { |
1879 | TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples)); |
1880 | ShapeHandle dummy; |
1881 | TF_RETURN_IF_ERROR( |
1882 | c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy)); |
1883 | } |
1884 | |
1885 | if (!c->RankKnown(multiples)) { |
1886 | return shape_inference::UnknownShape(c); |
1887 | } |
1888 | |
1889 | int32_t rank = c->Rank(multiples); |
1890 | TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); |
1891 | std::vector<DimensionHandle> dims(rank); |
1892 | for (int i = 0; i < rank; ++i) { |
1893 | TF_RETURN_IF_ERROR( |
1894 | c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i])); |
1895 | } |
1896 | c->set_output(0, c->MakeShape(dims)); |
1897 | return OkStatus(); |
1898 | }); |
1899 | |
1900 | // -------------------------------------------------------------------------- |
// TileGrad: deprecated at GraphDef version 3 in favor of reduce_sum (see the
// Deprecated() message). Its output shape is always reported as unknown.
REGISTER_OP("TileGrad")
    .Input("input: T")
    .Input("multiples: int32")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(3, "TileGrad has been replaced with reduce_sum")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1908 | |
1909 | // -------------------------------------------------------------------------- |
1910 | REGISTER_OP("Where" ) |
1911 | .Input("input: T" ) |
1912 | .Attr("T: {numbertype, bool} = DT_BOOL" ) |
1913 | .Output("index: int64" ) |
1914 | .SetShapeFn([](InferenceContext* c) { |
1915 | c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); |
1916 | return OkStatus(); |
1917 | }); |
1918 | |
1919 | // -------------------------------------------------------------------------- |
1920 | REGISTER_OP("BroadcastArgs" ) |
1921 | .Input("s0: T" ) |
1922 | .Input("s1: T" ) |
1923 | .Output("r0: T" ) |
1924 | .Attr("T: {int32, int64} = DT_INT32" ) |
1925 | .SetShapeFn([](InferenceContext* c) { |
1926 | ShapeHandle unused; |
1927 | ShapeHandle shape_x = c->input(0); |
1928 | ShapeHandle shape_y = c->input(1); |
1929 | TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused)); |
1930 | TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused)); |
1931 | |
1932 | if (!c->ValueKnown(c->Dim(shape_x, 0)) || |
1933 | !c->ValueKnown(c->Dim(shape_y, 0))) { |
1934 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
1935 | return OkStatus(); |
1936 | } |
1937 | |
1938 | int64_t x_dim = c->Value(c->Dim(shape_x, 0)); |
1939 | int64_t y_dim = c->Value(c->Dim(shape_y, 0)); |
1940 | |
1941 | // Broadcasted shape is going to be as large as the largest dimension. |
1942 | c->set_output(0, c->Vector(std::max(x_dim, y_dim))); |
1943 | return OkStatus(); |
1944 | }); |
1945 | |
1946 | // -------------------------------------------------------------------------- |
1947 | REGISTER_OP("BroadcastGradientArgs" ) |
1948 | .Input("s0: T" ) |
1949 | .Input("s1: T" ) |
1950 | .Output("r0: T" ) |
1951 | .Output("r1: T" ) |
1952 | .Attr("T: {int32, int64} = DT_INT32" ) |
1953 | .SetShapeFn([](InferenceContext* c) { |
1954 | // TODO(mrry): Implement constant_value for BroadcastGradientArgs? |
1955 | ShapeHandle unused; |
1956 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
1957 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
1958 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
1959 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
1960 | return OkStatus(); |
1961 | }); |
1962 | |
1963 | // -------------------------------------------------------------------------- |
// Pad: pads `input` as described by `paddings`, whose dtype Tpaddings is
// int32 (default) or int64. The output shape is computed by PadShapeFn,
// which is shared with PadV2 and MirrorPad.
REGISTER_OP("Pad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(PadShapeFn);
1971 | |
1972 | // -------------------------------------------------------------------------- |
// PadV2: like Pad, with an additional `constant_values` input of type T. It
// uses the same PadShapeFn as Pad for shape inference.
REGISTER_OP("PadV2")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Input("constant_values: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(PadShapeFn);
1981 | |
1982 | // -------------------------------------------------------------------------- |
// MirrorPad: like Pad, plus a padding-mode attr whose spec string comes from
// GetMirrorPadModeAttrString(). Shape inference reuses PadShapeFn.
REGISTER_OP("MirrorPad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .Attr(GetMirrorPadModeAttrString())
    .SetShapeFn(PadShapeFn);
1991 | |
1992 | // -------------------------------------------------------------------------- |
1993 | namespace { |
1994 | template <typename T> |
1995 | Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, |
1996 | const Tensor* paddings_t, int64_t input_rank) { |
1997 | auto paddings_data = paddings_t->matrix<T>(); |
1998 | std::vector<DimensionHandle> dims(input_rank); |
1999 | for (int64_t i = 0; i < input_rank; ++i) { |
2000 | const int64_t pad0 = static_cast<int64_t>(paddings_data(i, 0)); |
2001 | const int64_t pad1 = static_cast<int64_t>(paddings_data(i, 1)); |
2002 | if (pad0 < 0 || pad1 < 0) { |
2003 | return errors::InvalidArgument("Paddings must be non-negative" ); |
2004 | } |
2005 | |
2006 | TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i])); |
2007 | } |
2008 | c->set_output(0, c->MakeShape(dims)); |
2009 | return OkStatus(); |
2010 | } |
2011 | |
2012 | } // namespace |
2013 | |
2014 | REGISTER_OP("MirrorPadGrad" ) |
2015 | .Input("input: T" ) |
2016 | .Input("paddings: Tpaddings" ) |
2017 | .Output("output: T" ) |
2018 | .Attr("T: type" ) |
2019 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
2020 | .Attr(GetMirrorPadModeAttrString()) |
2021 | .SetShapeFn([](InferenceContext* c) { |
2022 | ShapeHandle paddings; |
2023 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); |
2024 | DimensionHandle pad_0 = c->Dim(paddings, 0); |
2025 | if (!c->ValueKnown(pad_0)) { |
2026 | // We don't know the rank of the output since the first |
2027 | // padding dimension is unknown. |
2028 | c->set_output(0, c->UnknownShape()); |
2029 | return OkStatus(); |
2030 | } |
2031 | |
2032 | int64_t input_rank = c->Value(pad_0); |
2033 | ShapeHandle input; |
2034 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input)); |
2035 | TF_RETURN_IF_ERROR( |
2036 | c->Merge(paddings, c->Matrix(input_rank, 2), &paddings)); |
2037 | |
2038 | const Tensor* paddings_t = c->input_tensor(1); |
2039 | if (paddings_t == nullptr) { |
2040 | // Values of 'paddings' is not available, but we know the |
2041 | // input rank, so return the rank of the output with unknown |
2042 | // dimensions. |
2043 | c->set_output(0, c->UnknownShapeOfRank(input_rank)); |
2044 | return OkStatus(); |
2045 | } |
2046 | |
2047 | if (paddings_t->dtype() == DT_INT32) { |
2048 | return MirrorPadKnown<int32>(c, input, paddings_t, input_rank); |
2049 | } else { |
2050 | return MirrorPadKnown<int64_t>(c, input, paddings_t, input_rank); |
2051 | } |
2052 | }); |
2053 | |
2054 | // -------------------------------------------------------------------------- |
2055 | REGISTER_OP("Placeholder" ) |
2056 | .Output("output: dtype" ) |
2057 | .Attr("dtype: type" ) |
2058 | .Attr("shape: shape = { unknown_rank: true }" ) |
2059 | .SetShapeFn([](InferenceContext* c) { |
2060 | PartialTensorShape shape; |
2061 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &shape)); |
2062 | |
2063 | // Placeholder has legacy behavior where we cannot tell the difference |
2064 | // between a scalar shape attribute and 'unknown shape'. So if the shape |
2065 | // is a scalar, we return an unknown shape. |
2066 | if (c->graph_def_version() <= 21 && shape.dims() <= 0) { |
2067 | return shape_inference::UnknownShape(c); |
2068 | } |
2069 | |
2070 | ShapeHandle out; |
2071 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); |
2072 | c->set_output(0, out); |
2073 | return OkStatus(); |
2074 | }); |
2075 | |
2076 | // Placeholder was modified in a backwards compatible way to do what |
2077 | // PlaceholderV2 did, so we have deprecated V2 (no one was really |
2078 | // using it). |
// PlaceholderV2: same outputs and attrs as Placeholder, except that `shape`
// has no default. Shape inference uses shape_inference::ExplicitShape,
// presumably taking the shape from the `shape` attr (TODO confirm). It is
// deprecated at GraphDef version 23, since Placeholder now behaves the same
// way.
REGISTER_OP("PlaceholderV2")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn(shape_inference::ExplicitShape)
    .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
2085 | |
2086 | // -------------------------------------------------------------------------- |
2087 | REGISTER_OP("PlaceholderWithDefault" ) |
2088 | .Input("input: dtype" ) |
2089 | .Output("output: dtype" ) |
2090 | .Attr("dtype: type" ) |
2091 | .Attr("shape: shape" ) |
2092 | .SetShapeFn([](InferenceContext* c) { |
2093 | ShapeHandle input = c->input(0); |
2094 | PartialTensorShape shape; |
2095 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &shape)); |
2096 | ShapeHandle out; |
2097 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); |
2098 | |
2099 | // We merge for compatibility checking, but return the output, |
2100 | // since output_shape may be less precise than input_shape. |
2101 | ShapeHandle unused; |
2102 | TF_RETURN_IF_ERROR(c->Merge(input, out, &unused)); |
2103 | c->set_output(0, out); |
2104 | return OkStatus(); |
2105 | }); |
2106 | |
2107 | // -------------------------------------------------------------------------- |
2108 | REGISTER_OP("ExpandDims" ) |
2109 | .Input("input: T" ) |
2110 | .Input("dim: Tdim" ) |
2111 | .Output("output: T" ) |
2112 | .Attr("T: type" ) |
2113 | .Attr("Tdim: {int32, int64} = DT_INT32" ) |
2114 | .SetShapeFn([](InferenceContext* c) { |
2115 | ShapeHandle input = c->input(0); |
2116 | |
2117 | const Tensor* dim_t = c->input_tensor(1); |
2118 | if (dim_t != nullptr && dim_t->NumElements() != 1) { |
2119 | return errors::InvalidArgument( |
2120 | "'dim' input must be a tensor with a single value" ); |
2121 | } |
2122 | if (dim_t == nullptr || !c->RankKnown(input)) { |
2123 | c->set_output(0, c->UnknownShape()); |
2124 | return OkStatus(); |
2125 | } |
2126 | |
2127 | int64_t dim; |
2128 | if (dim_t->dtype() == DT_INT32) { |
2129 | dim = static_cast<int64_t>(dim_t->flat<int32>()(0)); |
2130 | } else { |
2131 | dim = dim_t->flat<int64_t>()(0); |
2132 | } |
2133 | |
2134 | const int32_t rank = c->Rank(input); |
2135 | const int32_t min_dim = -1 * rank - 1; |
2136 | if (dim < min_dim || dim > rank) { |
2137 | return errors::InvalidArgument("dim " , dim, " not in the interval [" , |
2138 | min_dim, ", " , rank, "]." ); |
2139 | } |
2140 | |
2141 | if (dim < 0) { |
2142 | dim += rank + 1; |
2143 | } |
2144 | |
2145 | ShapeHandle end; |
2146 | TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end)); |
2147 | |
2148 | // Build output as start + 1 + end. |
2149 | ShapeHandle output; |
2150 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output)); |
2151 | TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output)); |
2152 | TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output)); |
2153 | c->set_output(0, output); |
2154 | return OkStatus(); |
2155 | }); |
2156 | |
2157 | // -------------------------------------------------------------------------- |
2158 | REGISTER_OP("Squeeze" ) |
2159 | .Input("input: T" ) |
2160 | .Output("output: T" ) |
2161 | .Attr("T: type" ) |
2162 | .Attr("squeeze_dims: list(int) >= 0 = []" ) |
2163 | .SetShapeFn([](InferenceContext* c) { |
2164 | ShapeHandle input = c->input(0); |
2165 | if (!c->RankKnown(input)) { |
2166 | // Input shape unknown. |
2167 | return shape_inference::UnknownShape(c); |
2168 | } |
2169 | |
2170 | const int32_t input_rank = c->Rank(input); |
2171 | |
2172 | // Validate and wrap squeeze dimensions. |
2173 | std::vector<int32> squeeze_dims; |
2174 | TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims" , &squeeze_dims)); |
2175 | for (int i = 0; i < squeeze_dims.size(); ++i) { |
2176 | if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) { |
2177 | return errors::InvalidArgument("squeeze_dims[" , i, "] not in [" , |
2178 | -input_rank, "," , input_rank, ")." ); |
2179 | } |
2180 | |
2181 | if (squeeze_dims[i] < 0) { |
2182 | squeeze_dims[i] += input_rank; |
2183 | } |
2184 | } |
2185 | |
2186 | std::vector<DimensionHandle> result_shape; |
2187 | for (int i = 0; i < input_rank; ++i) { |
2188 | // True if squeeze_dims contains an entry to squeeze this |
2189 | // dimension. |
2190 | bool is_explicit_match = |
2191 | std::find(squeeze_dims.begin(), squeeze_dims.end(), i) != |
2192 | squeeze_dims.end(); |
2193 | |
2194 | DimensionHandle dim = c->Dim(input, i); |
2195 | |
2196 | if (!c->ValueKnown(dim)) { |
2197 | // Assume that the squeezed dimension will be 1 at runtime. |
2198 | if (is_explicit_match) continue; |
2199 | |
2200 | // If squeezing all 1 dimensions, and we see an unknown value, |
2201 | // give up and return Unknown Shape. |
2202 | if (squeeze_dims.empty()) { |
2203 | c->set_output(0, c->UnknownShape()); |
2204 | return OkStatus(); |
2205 | } |
2206 | } else if (c->Value(dim) == 1) { |
2207 | if (is_explicit_match || squeeze_dims.empty()) { |
2208 | // If explicitly squeezing, or squeezing all 1s, remove |
2209 | // this dimension. |
2210 | continue; |
2211 | } |
2212 | } else if (is_explicit_match) { |
2213 | return errors::InvalidArgument("Can not squeeze dim[" , i, |
2214 | "], expected a dimension of 1, got " , |
2215 | c->Value(c->Dim(input, i))); |
2216 | } |
2217 | |
2218 | result_shape.emplace_back(dim); |
2219 | } |
2220 | |
2221 | c->set_output(0, c->MakeShape(result_shape)); |
2222 | return OkStatus(); |
2223 | }); |
2224 | |
2225 | // -------------------------------------------------------------------------- |
2226 | REGISTER_OP("ListDiff" ) |
2227 | .Input("x: T" ) |
2228 | .Input("y: T" ) |
2229 | .Output("out: T" ) |
2230 | .Output("idx: out_idx" ) |
2231 | .Attr("T: type" ) |
2232 | .Attr("out_idx: {int32, int64} = DT_INT32" ) |
2233 | .SetShapeFn([](InferenceContext* c) { |
2234 | ShapeHandle unused; |
2235 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
2236 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
2237 | // TODO(mrry): Indicate that the length falls within an interval? |
2238 | ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); |
2239 | c->set_output(0, out); |
2240 | c->set_output(1, out); |
2241 | return OkStatus(); |
2242 | }); |
2243 | |
2244 | namespace { |
2245 | |
2246 | // Converts Tensor to flat std::vector<int64_t>. |
2247 | template <typename InputType> |
2248 | std::vector<int64_t> GetFlatInt64(const Tensor& t) { |
2249 | std::vector<int64_t> output(t.shape().num_elements()); |
2250 | if (t.shape().num_elements() > 0) { |
2251 | auto eigen_vec = t.flat<InputType>(); |
2252 | std::copy_n(&eigen_vec(0), output.size(), output.begin()); |
2253 | } |
2254 | return output; |
2255 | } |
2256 | |
2257 | // Converts int32 or int64 Tensor to flat std::vector<int64_t>. |
2258 | std::vector<int64_t> GetFlatInt64(const Tensor& t) { |
2259 | if (t.dtype() == DT_INT32) { |
2260 | return GetFlatInt64<int32>(t); |
2261 | } else { |
2262 | return GetFlatInt64<int64_t>(t); |
2263 | } |
2264 | } |
2265 | |
2266 | Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape, |
2267 | ShapeHandle block_shape_shape, |
2268 | const Tensor* block_shape_t, |
2269 | ShapeHandle paddings_shape, |
2270 | const Tensor* paddings_t) { |
2271 | if (c->Rank(block_shape_shape) != 1) { |
2272 | return errors::InvalidArgument("block_shape must have rank 1." ); |
2273 | } |
2274 | |
2275 | const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); |
2276 | if (!c->ValueKnown(num_block_dims_handle)) { |
2277 | return errors::InvalidArgument("block_shape must have known size." ); |
2278 | } |
2279 | |
2280 | const int64_t num_block_dims = c->Value(num_block_dims_handle); |
2281 | |
2282 | TF_RETURN_IF_ERROR( |
2283 | c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); |
2284 | |
2285 | TF_RETURN_IF_ERROR( |
2286 | c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape)); |
2287 | |
2288 | DimensionHandle batch_size = c->Dim(input_shape, 0); |
2289 | std::vector<int64_t> block_shape_vec; |
2290 | if (block_shape_t && (block_shape_t->NumElements() > 0)) { |
2291 | block_shape_vec = GetFlatInt64(*block_shape_t); |
2292 | for (int64_t dim = 0; dim < num_block_dims; ++dim) { |
2293 | const int64_t block_shape_value = block_shape_vec[dim]; |
2294 | if (block_shape_value < 1) { |
2295 | return errors::InvalidArgument("block_shape must be positive" ); |
2296 | } |
2297 | if (c->ValueKnown(batch_size)) { |
2298 | TF_RETURN_IF_ERROR( |
2299 | c->Multiply(batch_size, block_shape_value, &batch_size)); |
2300 | } else { |
2301 | batch_size = c->UnknownDim(); |
2302 | } |
2303 | } |
2304 | } else if (num_block_dims > 0) { |
2305 | batch_size = c->UnknownDim(); |
2306 | } |
2307 | |
2308 | std::vector<DimensionHandle> output_dims{batch_size}; |
2309 | output_dims.resize(num_block_dims + 1, c->UnknownDim()); |
2310 | |
2311 | if (paddings_t && (paddings_t->NumElements() > 0)) { |
2312 | const std::vector<int64_t> paddings_vec = GetFlatInt64(*paddings_t); |
2313 | for (int64_t dim = 0; dim < num_block_dims; ++dim) { |
2314 | const int64_t pad_start = paddings_vec[dim * 2], |
2315 | pad_end = paddings_vec[dim * 2 + 1]; |
2316 | if (pad_start < 0 || pad_end < 0) { |
2317 | return errors::InvalidArgument("paddings cannot be negative" ); |
2318 | } |
2319 | if (block_shape_t) { |
2320 | DimensionHandle padded_size; |
2321 | TF_RETURN_IF_ERROR( |
2322 | c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size)); |
2323 | TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size)); |
2324 | TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim], |
2325 | /*evenly_divisible=*/true, |
2326 | &output_dims[dim + 1])); |
2327 | } |
2328 | } |
2329 | } |
2330 | |
2331 | ShapeHandle remaining_input_shape; |
2332 | TF_RETURN_IF_ERROR( |
2333 | c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); |
2334 | |
2335 | ShapeHandle result; |
2336 | TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), |
2337 | remaining_input_shape, &result)); |
2338 | c->set_output(0, result); |
2339 | return OkStatus(); |
2340 | } |
2341 | |
2342 | Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape, |
2343 | ShapeHandle block_shape_shape, |
2344 | const Tensor* block_shape_t, |
2345 | ShapeHandle crops_shape, const Tensor* crops_t) { |
2346 | if (c->Rank(block_shape_shape) != 1) { |
2347 | return errors::InvalidArgument("block_shape must have rank 1." ); |
2348 | } |
2349 | |
2350 | const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); |
2351 | if (!c->ValueKnown(num_block_dims_handle)) { |
2352 | return errors::InvalidArgument("block_shape must have known size." ); |
2353 | } |
2354 | |
2355 | const int64_t num_block_dims = c->Value(num_block_dims_handle); |
2356 | |
2357 | TF_RETURN_IF_ERROR( |
2358 | c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); |
2359 | |
2360 | TF_RETURN_IF_ERROR( |
2361 | c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape)); |
2362 | |
2363 | DimensionHandle batch_size = c->Dim(input_shape, 0); |
2364 | std::vector<int64_t> block_shape_vec; |
2365 | if (block_shape_t) { |
2366 | block_shape_vec = GetFlatInt64(*block_shape_t); |
2367 | for (int64_t dim = 0; dim < num_block_dims; ++dim) { |
2368 | const int64_t block_shape_value = block_shape_vec[dim]; |
2369 | if (block_shape_value < 1) { |
2370 | return errors::InvalidArgument("block_shape must be positive" ); |
2371 | } |
2372 | if (c->ValueKnown(batch_size)) { |
2373 | TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value, |
2374 | /*evenly_divisible=*/true, &batch_size)); |
2375 | } else { |
2376 | batch_size = c->UnknownDim(); |
2377 | } |
2378 | } |
2379 | } else if (num_block_dims > 0) { |
2380 | batch_size = c->UnknownDim(); |
2381 | } |
2382 | |
2383 | std::vector<DimensionHandle> output_dims{batch_size}; |
2384 | output_dims.resize(num_block_dims + 1, c->UnknownDim()); |
2385 | |
2386 | if (crops_t) { |
2387 | const std::vector<int64_t> crops_vec = GetFlatInt64(*crops_t); |
2388 | for (int64_t dim = 0; dim < num_block_dims; ++dim) { |
2389 | const int64_t crop_start = crops_vec[dim * 2], |
2390 | crop_end = crops_vec[dim * 2 + 1]; |
2391 | if (crop_start < 0 || crop_end < 0) { |
2392 | return errors::InvalidArgument("crops cannot be negative" ); |
2393 | } |
2394 | if (block_shape_t) { |
2395 | DimensionHandle cropped_size; |
2396 | TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1), |
2397 | block_shape_vec[dim], &cropped_size)); |
2398 | TF_RETURN_IF_ERROR( |
2399 | c->Subtract(cropped_size, crop_start, &cropped_size)); |
2400 | TF_RETURN_IF_ERROR( |
2401 | c->Subtract(cropped_size, crop_end, &output_dims[dim + 1])); |
2402 | } |
2403 | } |
2404 | } |
2405 | |
2406 | ShapeHandle remaining_input_shape; |
2407 | TF_RETURN_IF_ERROR( |
2408 | c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); |
2409 | |
2410 | ShapeHandle result; |
2411 | TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), |
2412 | remaining_input_shape, &result)); |
2413 | c->set_output(0, result); |
2414 | return OkStatus(); |
2415 | } |
2416 | |
2417 | } // namespace |
2418 | |
2419 | // -------------------------------------------------------------------------- |
2420 | REGISTER_OP("SpaceToBatchND" ) |
2421 | .Input("input: T" ) |
2422 | .Input("block_shape: Tblock_shape" ) |
2423 | .Input("paddings: Tpaddings" ) |
2424 | .Output("output: T" ) |
2425 | .Attr("T: type" ) |
2426 | .Attr("Tblock_shape: {int32, int64} = DT_INT32" ) |
2427 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
2428 | .SetShapeFn([](InferenceContext* c) { |
2429 | return SpaceToBatchShapeHelper(c, c->input(0), c->input(1), |
2430 | c->input_tensor(1), c->input(2), |
2431 | c->input_tensor(2)); |
2432 | }); |
2433 | |
2434 | // -------------------------------------------------------------------------- |
2435 | REGISTER_OP("SpaceToBatch" ) |
2436 | .Input("input: T" ) |
2437 | .Input("paddings: Tpaddings" ) |
2438 | .Output("output: T" ) |
2439 | .Attr("T: type" ) |
2440 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
2441 | .Attr("block_size: int >= 2" ) |
2442 | .SetShapeFn([](InferenceContext* c) { |
2443 | ShapeHandle input_shape; |
2444 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
2445 | |
2446 | int32_t block_size; |
2447 | TF_RETURN_IF_ERROR(c->GetAttr("block_size" , &block_size)); |
2448 | |
2449 | Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); |
2450 | auto block_shape_vec = block_shape.vec<int64_t>(); |
2451 | block_shape_vec(0) = block_size; |
2452 | block_shape_vec(1) = block_size; |
2453 | |
2454 | return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}), |
2455 | &block_shape, c->input(1), |
2456 | c->input_tensor(1)); |
2457 | }); |
2458 | |
2459 | // -------------------------------------------------------------------------- |
2460 | REGISTER_OP("BatchToSpaceND" ) |
2461 | .Input("input: T" ) |
2462 | .Input("block_shape: Tblock_shape" ) |
2463 | .Input("crops: Tcrops" ) |
2464 | .Output("output: T" ) |
2465 | .Attr("T: type" ) |
2466 | .Attr("Tblock_shape: {int32, int64} = DT_INT32" ) |
2467 | .Attr("Tcrops: {int32, int64} = DT_INT32" ) |
2468 | .SetShapeFn([](InferenceContext* c) { |
2469 | return BatchToSpaceShapeHelper(c, c->input(0), c->input(1), |
2470 | c->input_tensor(1), c->input(2), |
2471 | c->input_tensor(2)); |
2472 | }); |
2473 | |
2474 | // -------------------------------------------------------------------------- |
2475 | REGISTER_OP("BatchToSpace" ) |
2476 | .Input("input: T" ) |
2477 | .Input("crops: Tidx" ) |
2478 | .Output("output: T" ) |
2479 | .Attr("T: type" ) |
2480 | .Attr("block_size: int >= 2" ) |
2481 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
2482 | .SetShapeFn([](InferenceContext* c) { |
2483 | ShapeHandle input_shape; |
2484 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
2485 | |
2486 | int32_t block_size; |
2487 | TF_RETURN_IF_ERROR(c->GetAttr("block_size" , &block_size)); |
2488 | |
2489 | Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); |
2490 | auto block_shape_vec = block_shape.vec<int64_t>(); |
2491 | block_shape_vec(0) = block_size; |
2492 | block_shape_vec(1) = block_size; |
2493 | |
2494 | return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}), |
2495 | &block_shape, c->input(1), |
2496 | c->input_tensor(1)); |
2497 | }); |
2498 | |
2499 | // -------------------------------------------------------------------------- |
2500 | REGISTER_OP("SpaceToDepth" ) |
2501 | .Input("input: T" ) |
2502 | .Output("output: T" ) |
2503 | .Attr("T: type" ) |
2504 | .Attr("block_size: int >= 2" ) |
2505 | .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'" ) |
2506 | // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C. |
2507 | .SetShapeFn([](InferenceContext* c) { |
2508 | string data_format_str; |
2509 | TF_RETURN_IF_ERROR(c->GetAttr("data_format" , &data_format_str)); |
2510 | TensorFormat data_format; |
2511 | FormatFromString(data_format_str, &data_format); |
2512 | |
2513 | constexpr int num_spatial_dims = 2; |
2514 | const int dims = |
2515 | GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); |
2516 | ShapeHandle input; |
2517 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); |
2518 | |
2519 | int32_t block_size; |
2520 | TF_RETURN_IF_ERROR(c->GetAttr("block_size" , &block_size)); |
2521 | |
2522 | DimensionHandle batch_size = |
2523 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
2524 | DimensionHandle input_height = |
2525 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
2526 | DimensionHandle input_width = |
2527 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
2528 | DimensionHandle input_depth = |
2529 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
2530 | |
2531 | DimensionHandle output_height; |
2532 | DimensionHandle output_width; |
2533 | DimensionHandle output_depth; |
2534 | // Will return an error if input height or width are not evenly divisible. |
2535 | TF_RETURN_IF_ERROR(c->Divide(input_height, block_size, |
2536 | true /* evenly_divisible */, |
2537 | &output_height)); |
2538 | TF_RETURN_IF_ERROR(c->Divide(input_width, block_size, |
2539 | true /* evenly_divisible */, &output_width)); |
2540 | |
2541 | TF_RETURN_IF_ERROR( |
2542 | c->Multiply(input_depth, block_size * block_size, &output_depth)); |
2543 | |
2544 | ShapeHandle output_shape; |
2545 | TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, |
2546 | {output_height, output_width}, |
2547 | output_depth, &output_shape, c)); |
2548 | |
2549 | c->set_output(0, output_shape); |
2550 | return OkStatus(); |
2551 | }); |
2552 | |
2553 | // -------------------------------------------------------------------------- |
2554 | REGISTER_OP("DepthToSpace" ) |
2555 | .Input("input: T" ) |
2556 | .Output("output: T" ) |
2557 | .Attr("T: type" ) |
2558 | .Attr("block_size: int >= 2" ) |
2559 | .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'" ) |
2560 | // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C. |
2561 | .SetShapeFn([](InferenceContext* c) { |
2562 | string data_format_str; |
2563 | TF_RETURN_IF_ERROR(c->GetAttr("data_format" , &data_format_str)); |
2564 | TensorFormat data_format; |
2565 | FormatFromString(data_format_str, &data_format); |
2566 | |
2567 | constexpr int num_spatial_dims = 2; |
2568 | const int dims = |
2569 | GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); |
2570 | |
2571 | ShapeHandle input; |
2572 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); |
2573 | |
2574 | int32_t block_size; |
2575 | TF_RETURN_IF_ERROR(c->GetAttr("block_size" , &block_size)); |
2576 | |
2577 | DimensionHandle batch_size = |
2578 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); |
2579 | DimensionHandle input_height = |
2580 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); |
2581 | DimensionHandle input_width = |
2582 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); |
2583 | DimensionHandle input_depth = |
2584 | c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); |
2585 | |
2586 | DimensionHandle output_height; |
2587 | DimensionHandle output_width; |
2588 | DimensionHandle output_depth; |
2589 | TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height)); |
2590 | TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width)); |
2591 | |
2592 | // Will return an error if input_depth is not evenly divisible. |
2593 | TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size, |
2594 | true /* evenly_divisible */, &output_depth)); |
2595 | |
2596 | ShapeHandle output_shape; |
2597 | TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, |
2598 | {output_height, output_width}, |
2599 | output_depth, &output_shape, c)); |
2600 | |
2601 | c->set_output(0, output_shape); |
2602 | return OkStatus(); |
2603 | }); |
2604 | |
2605 | // -------------------------------------------------------------------------- |
2606 | |
2607 | REGISTER_OP("ExtractImagePatches" ) |
2608 | .Input("images: T" ) |
2609 | .Output("patches: T" ) |
2610 | .Attr("ksizes: list(int) >= 4" ) |
2611 | .Attr("strides: list(int) >= 4" ) |
2612 | .Attr("rates: list(int) >= 4" ) |
2613 | .Attr( |
2614 | "T: {bfloat16, half, float, double, int8, int16, int32, int64, " |
2615 | "uint8, uint16, uint32, uint64, complex64, complex128, bool}" ) |
2616 | .Attr(GetPaddingAttrString()) |
2617 | .SetShapeFn([](InferenceContext* c) { |
2618 | ShapeHandle input_shape; |
2619 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
2620 | |
2621 | std::vector<int32> ksizes; |
2622 | TF_RETURN_IF_ERROR(c->GetAttr("ksizes" , &ksizes)); |
2623 | if (ksizes.size() != 4) { |
2624 | return errors::InvalidArgument( |
2625 | "ExtractImagePatches requires the ksizes attribute to contain 4 " |
2626 | "values, but got: " , |
2627 | ksizes.size()); |
2628 | } |
2629 | |
2630 | std::vector<int32> strides; |
2631 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
2632 | if (strides.size() != 4) { |
2633 | return errors::InvalidArgument( |
2634 | "ExtractImagePatches requires the stride attribute to contain 4 " |
2635 | "values, but got: " , |
2636 | strides.size()); |
2637 | } |
2638 | |
2639 | std::vector<int32> rates; |
2640 | TF_RETURN_IF_ERROR(c->GetAttr("rates" , &rates)); |
2641 | if (rates.size() != 4) { |
2642 | return errors::InvalidArgument( |
2643 | "ExtractImagePatches requires the rates attribute to contain 4 " |
2644 | "values, but got: " , |
2645 | rates.size()); |
2646 | } |
2647 | |
2648 | int32_t ksize_rows = ksizes[1]; |
2649 | int32_t ksize_cols = ksizes[2]; |
2650 | |
2651 | int32_t stride_rows = strides[1]; |
2652 | int32_t stride_cols = strides[2]; |
2653 | |
2654 | int32_t rate_rows = rates[1]; |
2655 | int32_t rate_cols = rates[2]; |
2656 | |
2657 | int32_t ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); |
2658 | int32_t ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); |
2659 | |
2660 | DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
2661 | DimensionHandle in_rows_dim = c->Dim(input_shape, 1); |
2662 | DimensionHandle in_cols_dim = c->Dim(input_shape, 2); |
2663 | DimensionHandle output_depth_dim; |
2664 | TF_RETURN_IF_ERROR(c->Multiply( |
2665 | c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim)); |
2666 | |
2667 | if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) { |
2668 | ShapeHandle output_shape = |
2669 | c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, |
2670 | InferenceContext::kUnknownDim, output_depth_dim}); |
2671 | c->set_output(0, output_shape); |
2672 | return OkStatus(); |
2673 | } |
2674 | auto in_rows = c->Value(in_rows_dim); |
2675 | auto in_cols = c->Value(in_cols_dim); |
2676 | |
2677 | Padding padding; |
2678 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
2679 | |
2680 | int64_t output_rows, output_cols; |
2681 | int64_t padding_before, padding_after; |
2682 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
2683 | in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, |
2684 | &padding_before, &padding_after)); |
2685 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
2686 | in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, |
2687 | &padding_before, &padding_after)); |
2688 | ShapeHandle output_shape = c->MakeShape( |
2689 | {batch_size_dim, output_rows, output_cols, output_depth_dim}); |
2690 | c->set_output(0, output_shape); |
2691 | return OkStatus(); |
2692 | }); |
2693 | |
2694 | // -------------------------------------------------------------------------- |
2695 | |
2696 | // To enable rates, uncomment all lines commented below and use ksize_*_eff |
2697 | // as the second parameter of all GetWindowedOutputSizeVerbose calls instead |
2698 | // of ksize_*. |
2699 | REGISTER_OP("ExtractVolumePatches" ) |
2700 | .Input("input: T" ) |
2701 | .Output("patches: T" ) |
2702 | .Attr("ksizes: list(int) >= 5" ) |
2703 | .Attr("strides: list(int) >= 5" ) |
2704 | /* .Attr("rates: list(int) >= 5") */ |
2705 | .Attr("T: realnumbertype" ) |
2706 | .Attr(GetPaddingAttrString()) |
2707 | .SetShapeFn([](InferenceContext* c) { |
2708 | ShapeHandle input_shape; |
2709 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); |
2710 | |
2711 | std::vector<int32> ksizes; |
2712 | TF_RETURN_IF_ERROR(c->GetAttr("ksizes" , &ksizes)); |
2713 | if (ksizes.size() != 5) { |
2714 | return errors::InvalidArgument( |
2715 | "ExtractVolumePatches requires the ksizes attribute to contain 5 " |
2716 | "values, but got: " , |
2717 | ksizes.size()); |
2718 | } |
2719 | |
2720 | std::vector<int32> strides; |
2721 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
2722 | if (strides.size() != 5) { |
2723 | return errors::InvalidArgument( |
2724 | "ExtractVolumePatches requires the stride attribute to contain 5 " |
2725 | "values, but got: " , |
2726 | strides.size()); |
2727 | } |
2728 | |
2729 | /* |
2730 | // TODO(hsgkim): Enable rates. |
2731 | // See extract_volume_patches_op.cc for why rates are disabled now. |
2732 | |
2733 | std::vector<int32> rates; |
2734 | TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); |
2735 | if (rates.size() != 5) { |
2736 | return errors::InvalidArgument( |
2737 | "ExtractVolumePatches requires the rates attribute to contain 5 " |
2738 | "values, but got: ", |
2739 | rates.size()); |
2740 | } |
2741 | */ |
2742 | |
2743 | int32_t ksize_planes = ksizes[1]; |
2744 | int32_t ksize_rows = ksizes[2]; |
2745 | int32_t ksize_cols = ksizes[3]; |
2746 | |
2747 | int32_t stride_planes = strides[1]; |
2748 | int32_t stride_rows = strides[2]; |
2749 | int32_t stride_cols = strides[3]; |
2750 | |
2751 | /* |
2752 | int32 rate_planes = rates[1]; |
2753 | int32 rate_rows = rates[2]; |
2754 | int32 rate_cols = rates[3]; |
2755 | |
2756 | int32 ksize_planes_eff = ksize_planes + |
2757 | (ksize_planes - 1) * (rate_planes - 1); |
2758 | int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); |
2759 | int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); |
2760 | */ |
2761 | |
2762 | DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
2763 | DimensionHandle in_planes_dim = c->Dim(input_shape, 1); |
2764 | DimensionHandle in_rows_dim = c->Dim(input_shape, 2); |
2765 | DimensionHandle in_cols_dim = c->Dim(input_shape, 3); |
2766 | DimensionHandle output_depth_dim; |
2767 | TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4), |
2768 | ksize_planes * ksize_rows * ksize_cols, |
2769 | &output_depth_dim)); |
2770 | |
2771 | if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || |
2772 | !c->ValueKnown(in_cols_dim)) { |
2773 | ShapeHandle output_shape = |
2774 | c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, |
2775 | InferenceContext::kUnknownDim, output_depth_dim}); |
2776 | c->set_output(0, output_shape); |
2777 | return OkStatus(); |
2778 | } |
2779 | auto in_planes = c->Value(in_planes_dim); |
2780 | auto in_rows = c->Value(in_rows_dim); |
2781 | auto in_cols = c->Value(in_cols_dim); |
2782 | |
2783 | Padding padding; |
2784 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
2785 | |
2786 | int64_t output_planes, output_rows, output_cols; |
2787 | int64_t padding_before, padding_after; |
2788 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
2789 | in_planes, ksize_planes, stride_planes, padding, &output_planes, |
2790 | &padding_before, &padding_after)); |
2791 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
2792 | in_rows, ksize_rows, stride_rows, padding, &output_rows, |
2793 | &padding_before, &padding_after)); |
2794 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
2795 | in_cols, ksize_cols, stride_cols, padding, &output_cols, |
2796 | &padding_before, &padding_after)); |
2797 | ShapeHandle output_shape = |
2798 | c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, |
2799 | output_depth_dim}); |
2800 | c->set_output(0, output_shape); |
2801 | return OkStatus(); |
2802 | }); |
2803 | |
2804 | // -------------------------------------------------------------------------- |
2805 | |
2806 | REGISTER_OP("OneHot" ) |
2807 | .Input("indices: TI" ) |
2808 | .Input("depth: int32" ) |
2809 | .Input("on_value: T" ) |
2810 | .Input("off_value: T" ) |
2811 | .Attr("axis: int = -1" ) |
2812 | .Output("output: T" ) |
2813 | .Attr("T: type" ) |
2814 | .Attr("TI: {uint8, int8, int32, int64} = DT_INT64" ) |
2815 | .SetShapeFn([](InferenceContext* c) { |
2816 | int32_t axis; |
2817 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , &axis)); |
2818 | if (axis < -1) return errors::InvalidArgument("axis must be >= -1" ); |
2819 | |
2820 | DimensionHandle depth; |
2821 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth)); |
2822 | |
2823 | ShapeHandle indices = c->input(0); |
2824 | if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c); |
2825 | |
2826 | int32_t new_rank = c->Rank(indices) + 1; |
2827 | // We need to add new_rank to axis in the case the axis is -1 because |
2828 | // C++ returns negative values from % if the dividend is negative. |
2829 | int32_t depth_index = (axis + new_rank) % new_rank; |
2830 | // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:]. |
2831 | ShapeHandle front; |
2832 | ShapeHandle back; |
2833 | ShapeHandle out; |
2834 | TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front)); |
2835 | TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back)); |
2836 | TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front)); |
2837 | TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out)); |
2838 | c->set_output(0, out); |
2839 | return OkStatus(); |
2840 | }); |
2841 | |
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
// Original quantize/dequantize op: the input range is supplied through the
// `input_min` / `input_max` attrs rather than as tensor inputs. Registered as
// deprecated at GraphDef version 22; the output shape equals the input shape.
REGISTER_OP("QuantizeAndDequantize" )
    .Input("input: T" )
    .Attr("signed_input: bool = true" )
    .Attr("num_bits: int = 8" )
    .Attr("range_given: bool = false" )
    .Attr("input_min: float = 0" )
    .Attr("input_max: float = 0" )
    .Output("output: T" )
    .Attr("T: {bfloat16, half, float, double}" )
    .SetShapeFn(shape_inference::UnchangedShape)
    .Deprecated(22, "Replaced by QuantizeAndDequantizeV2" );
2854 | |
2855 | // TODO(suharshs): Deprecate QuantizeAndDequantizeV2. |
2856 | REGISTER_OP("QuantizeAndDequantizeV2" ) |
2857 | .Input("input: T" ) |
2858 | .Input("input_min: T" ) |
2859 | .Input("input_max: T" ) |
2860 | .Attr("signed_input: bool = true" ) |
2861 | .Attr("num_bits: int = 8" ) |
2862 | .Attr("range_given: bool = false" ) |
2863 | .Output("output: T" ) |
2864 | .Attr("T: {bfloat16, half, float, double}" ) |
2865 | .Attr( |
2866 | "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = " |
2867 | "'HALF_TO_EVEN'" ) |
2868 | .Attr("narrow_range: bool = false" ) |
2869 | .Attr("axis: int = -1" ) |
2870 | .SetShapeFn([](InferenceContext* c) { |
2871 | int axis; |
2872 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , &axis)); |
2873 | const int minmax_rank = (axis == -1) ? 0 : 1; |
2874 | ShapeHandle minmax; |
2875 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); |
2876 | TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax)); |
2877 | if (axis < -1) { |
2878 | return errors::InvalidArgument("axis should be at least -1, got " , |
2879 | axis); |
2880 | } else if (axis != -1) { |
2881 | ShapeHandle input; |
2882 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
2883 | DimensionHandle depth; |
2884 | TF_RETURN_IF_ERROR( |
2885 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
2886 | } |
2887 | c->set_output(0, c->input(0)); |
2888 | return OkStatus(); |
2889 | }); |
2890 | |
2891 | REGISTER_OP("QuantizeAndDequantizeV4" ) |
2892 | .Input("input: T" ) |
2893 | .Input("input_min: T" ) |
2894 | .Input("input_max: T" ) |
2895 | .Attr("signed_input: bool = true" ) |
2896 | .Attr("num_bits: int = 8" ) |
2897 | .Attr("range_given: bool = false" ) |
2898 | .Output("output: T" ) |
2899 | .Attr("T: {bfloat16, half, float, double}" ) |
2900 | .Attr( |
2901 | "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = " |
2902 | "'HALF_TO_EVEN'" ) |
2903 | .Attr("narrow_range: bool = false" ) |
2904 | .Attr("axis: int = -1" ) |
2905 | .SetShapeFn([](InferenceContext* c) { |
2906 | int axis; |
2907 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , &axis)); |
2908 | const int minmax_rank = (axis == -1) ? 0 : 1; |
2909 | ShapeHandle minmax; |
2910 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); |
2911 | TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax)); |
2912 | if (axis < -1) { |
2913 | return errors::InvalidArgument("axis should be at least -1, got " , |
2914 | axis); |
2915 | } else if (axis != -1) { |
2916 | ShapeHandle input; |
2917 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
2918 | DimensionHandle depth; |
2919 | TF_RETURN_IF_ERROR( |
2920 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
2921 | } |
2922 | c->set_output(0, c->input(0)); |
2923 | return OkStatus(); |
2924 | }); |
2925 | |
2926 | REGISTER_OP("QuantizeAndDequantizeV4Grad" ) |
2927 | .Input("gradients: T" ) |
2928 | .Input("input: T" ) |
2929 | .Input("input_min: T" ) |
2930 | .Input("input_max: T" ) |
2931 | .Output("input_backprop: T" ) |
2932 | .Output("input_min_backprop: T" ) |
2933 | .Output("input_max_backprop: T" ) |
2934 | .Attr("T: {bfloat16, half, float, double}" ) |
2935 | .Attr("axis: int = -1" ) |
2936 | .SetShapeFn([](InferenceContext* c) { |
2937 | int axis; |
2938 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , &axis)); |
2939 | const int minmax_rank = (axis == -1) ? 0 : 1; |
2940 | ShapeHandle minmax; |
2941 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax)); |
2942 | TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax)); |
2943 | if (axis < -1) { |
2944 | return errors::InvalidArgument("axis should be at least -1, got " , |
2945 | axis); |
2946 | } else if (axis != -1) { |
2947 | ShapeHandle input; |
2948 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
2949 | DimensionHandle depth; |
2950 | TF_RETURN_IF_ERROR( |
2951 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
2952 | } |
2953 | ShapeHandle inputs; |
2954 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs)); |
2955 | c->set_output(0, inputs); |
2956 | c->set_output(1, minmax); |
2957 | c->set_output(2, minmax); |
2958 | return OkStatus(); |
2959 | }); |
2960 | |
2961 | REGISTER_OP("QuantizeAndDequantizeV3" ) |
2962 | .Input("input: T" ) |
2963 | .Input("input_min: T" ) |
2964 | .Input("input_max: T" ) |
2965 | .Input("num_bits: int32" ) |
2966 | .Attr("signed_input: bool = true" ) |
2967 | .Attr("range_given: bool = true" ) |
2968 | .Output("output: T" ) |
2969 | .Attr("T: {bfloat16, half, float, double}" ) |
2970 | .Attr("narrow_range: bool = false" ) |
2971 | .Attr("axis: int = -1" ) |
2972 | .SetShapeFn([](InferenceContext* c) { |
2973 | int axis; |
2974 | TF_RETURN_IF_ERROR(c->GetAttr("axis" , &axis)); |
2975 | const int minmax_rank = (axis == -1) ? 0 : 1; |
2976 | ShapeHandle minmax; |
2977 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); |
2978 | TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax)); |
2979 | if (axis < -1) { |
2980 | return errors::InvalidArgument("axis should be at least -1, got " , |
2981 | axis); |
2982 | } else if (axis != -1) { |
2983 | ShapeHandle input; |
2984 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
2985 | DimensionHandle depth; |
2986 | TF_RETURN_IF_ERROR( |
2987 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
2988 | } |
2989 | ShapeHandle unused; |
2990 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2991 | c->set_output(0, c->input(0)); |
2992 | return OkStatus(); |
2993 | }); |
2994 | |
// Takes a float `input` plus float `min_range` / `max_range` inputs and
// produces an output of quantized type T together with float `output_min` /
// `output_max`. Shape inference is delegated entirely to
// shape_inference::QuantizeV2Shape.
REGISTER_OP("QuantizeV2" )
    .Input("input: float" )
    .Input("min_range: float" )
    .Input("max_range: float" )
    .Output("output: T" )
    .Output("output_min: float" )
    .Output("output_max: float" )
    .Attr("T: quantizedtype" )
    .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'" )
    .Attr(
        "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
        "'HALF_AWAY_FROM_ZERO'" )
    .Attr("narrow_range: bool = false" )
    .Attr("axis: int = -1" )
    .Attr("ensure_minimum_range: float = 0.01" )
    .SetShapeFn(shape_inference::QuantizeV2Shape);
3011 | |
3012 | REGISTER_OP("Dequantize" ) |
3013 | .Input("input: T" ) |
3014 | .Input("min_range: float" ) |
3015 | .Input("max_range: float" ) |
3016 | .Output("output: dtype" ) |
3017 | .Attr("T: quantizedtype" ) |
3018 | .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'" ) |
3019 | .Attr("narrow_range: bool = false" ) |
3020 | .Attr("axis: int = -1" ) |
3021 | .Attr("dtype: {bfloat16, float} = DT_FLOAT" ) |
3022 | .SetShapeFn([](InferenceContext* c) { |
3023 | int axis = -1; |
3024 | Status s = c->GetAttr("axis" , &axis); |
3025 | if (!s.ok() && s.code() != error::NOT_FOUND) { |
3026 | return s; |
3027 | } |
3028 | if (axis < -1) { |
3029 | return errors::InvalidArgument("axis should be at least -1, got " , |
3030 | axis); |
3031 | } |
3032 | auto input_dims = c->Rank(c->input(0)); |
3033 | if (axis > input_dims) { |
3034 | return errors::InvalidArgument( |
3035 | "Axis must be less than input dimension(" , input_dims, "), got " , |
3036 | axis); |
3037 | } |
3038 | const int minmax_rank = (axis == -1) ? 0 : 1; |
3039 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
3040 | ShapeHandle minmax; |
3041 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); |
3042 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax)); |
3043 | if (axis != -1) { |
3044 | ShapeHandle input; |
3045 | if (axis >= kint32max) { |
3046 | // Check int32 max bound for a corner case to prevent integer flow |
3047 | // when input actually has kint32max rank and above bound check is not |
3048 | // triggered. |
3049 | return errors::InvalidArgument( |
3050 | "Axis cannot be >= kint32max value, got " , axis); |
3051 | } |
3052 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); |
3053 | DimensionHandle depth; |
3054 | TF_RETURN_IF_ERROR( |
3055 | c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); |
3056 | } |
3057 | return OkStatus(); |
3058 | }); |
3059 | |
3060 | REGISTER_OP("QuantizedConcat" ) |
3061 | .Input("concat_dim: int32" ) |
3062 | .Input("values: N * T" ) |
3063 | .Input("input_mins: N * float32" ) |
3064 | .Input("input_maxes: N * float32" ) |
3065 | .Output("output: T" ) |
3066 | .Output("output_min: float" ) |
3067 | .Output("output_max: float" ) |
3068 | .Attr("N: int >= 2" ) |
3069 | .Attr("T: type" ) |
3070 | .SetShapeFn([](InferenceContext* c) { |
3071 | const int n = (c->num_inputs() - 1) / 3; |
3072 | TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n)); |
3073 | ShapeHandle unused; |
3074 | for (int i = n + 1; i < c->num_inputs(); ++i) { |
3075 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); |
3076 | } |
3077 | c->set_output(1, c->Scalar()); |
3078 | c->set_output(2, c->Scalar()); |
3079 | return OkStatus(); |
3080 | }); |
3081 | |
3082 | REGISTER_OP("QuantizedReshape" ) |
3083 | .Input("tensor: T" ) |
3084 | .Input("shape: Tshape" ) |
3085 | .Input("input_min: float" ) |
3086 | .Input("input_max: float" ) |
3087 | .Output("output: T" ) |
3088 | .Output("output_min: float" ) |
3089 | .Output("output_max: float" ) |
3090 | .Attr("T: type" ) |
3091 | .Attr("Tshape: {int32, int64} = DT_INT32" ) |
3092 | .SetShapeFn([](InferenceContext* c) { |
3093 | TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c)); |
3094 | ShapeHandle unused; |
3095 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
3096 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3097 | c->set_output(1, c->Scalar()); |
3098 | c->set_output(2, c->Scalar()); |
3099 | return OkStatus(); |
3100 | }); |
3101 | |
3102 | REGISTER_OP("QuantizedInstanceNorm" ) |
3103 | .Input("x: T" ) |
3104 | .Input("x_min: float" ) |
3105 | .Input("x_max: float" ) |
3106 | .Output("y: T" ) |
3107 | .Output("y_min: float" ) |
3108 | .Output("y_max: float" ) |
3109 | .Attr("T: quantizedtype" ) |
3110 | .Attr("output_range_given: bool = false" ) |
3111 | .Attr("given_y_min: float = 0" ) |
3112 | .Attr("given_y_max: float = 0" ) |
3113 | .Attr("variance_epsilon: float = 1e-5" ) |
3114 | .Attr("min_separation: float = 1e-3" ) |
3115 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
3116 | shape_inference::ShapeHandle unused; |
3117 | // x should be a rank 4 tensor. |
3118 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused)); |
3119 | // Assert x_min and x_max are scalars (rank 0). |
3120 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
3121 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
3122 | // y has the same shape as x. |
3123 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
3124 | // y_min and y_max are scalars. |
3125 | c->set_output(1, c->Scalar()); |
3126 | c->set_output(2, c->Scalar()); |
3127 | return OkStatus(); |
3128 | }); |
3129 | |
3130 | namespace { |
3131 | |
3132 | Status ScatterNdTensorShape(InferenceContext* c) { |
3133 | ShapeHandle output_shape; |
3134 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape)); |
3135 | ShapeHandle indices_shape; |
3136 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); |
3137 | ShapeHandle updates_shape; |
3138 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 0, &updates_shape)); |
3139 | return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, |
3140 | output_shape); |
3141 | } |
3142 | |
3143 | } // namespace |
3144 | |
3145 | REGISTER_OP("UpperBound" ) |
3146 | .Input("sorted_inputs: T" ) |
3147 | .Input("values: T" ) |
3148 | .Output("output: out_type" ) |
3149 | .Attr("T: type" ) |
3150 | .Attr("out_type: {int32, int64} = DT_INT32" ) |
3151 | .SetShapeFn([](InferenceContext* c) { |
3152 | ShapeHandle unused_shape; |
3153 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); |
3154 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); |
3155 | c->set_output(0, c->input(1)); |
3156 | return OkStatus(); |
3157 | }); |
3158 | |
3159 | REGISTER_OP("LowerBound" ) |
3160 | .Input("sorted_inputs: T" ) |
3161 | .Input("values: T" ) |
3162 | .Output("output: out_type" ) |
3163 | .Attr("T: type" ) |
3164 | .Attr("out_type: {int32, int64} = DT_INT32" ) |
3165 | .SetShapeFn([](InferenceContext* c) { |
3166 | ShapeHandle unused_shape; |
3167 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); |
3168 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); |
3169 | c->set_output(0, c->input(1)); |
3170 | return OkStatus(); |
3171 | }); |
3172 | |
3173 | REGISTER_OP("ScatterNd" ) |
3174 | .Input("indices: Tindices" ) |
3175 | .Input("updates: T" ) |
3176 | .Input("shape: Tindices" ) |
3177 | .Output("output: T" ) |
3178 | .Attr("T: type" ) |
3179 | .Attr("Tindices: {int16, int32, int64}" ) |
3180 | .SetShapeFn([](InferenceContext* c) { |
3181 | ShapeHandle indices_shape; |
3182 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); |
3183 | ShapeHandle updates_shape; |
3184 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); |
3185 | ShapeHandle output_shape; |
3186 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); |
3187 | return shape_inference::ScatterNdShapeHelper(c, indices_shape, |
3188 | updates_shape, output_shape); |
3189 | }); |
3190 | |
// Inputs (tensor, indices, updates); the output has element type T and its
// shape is inferred by ScatterNdTensorShape. Unlike the other TensorScatter*
// ops registered here, Tindices also admits int16 and uint16.
REGISTER_OP("TensorScatterUpdate" )
    .Input("tensor: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tindices: {int16, int32, int64, uint16}" )
    .SetShapeFn(ScatterNdTensorShape);
3199 | |
// Inputs (tensor, indices, updates) with int32/int64 indices; the output
// shape is inferred by ScatterNdTensorShape.
REGISTER_OP("TensorScatterAdd" )
    .Input("tensor: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tindices: {int32, int64}" )
    .SetShapeFn(ScatterNdTensorShape);
3208 | |
// Inputs (tensor, indices, updates) with int32/int64 indices; the output
// shape is inferred by ScatterNdTensorShape.
REGISTER_OP("TensorScatterSub" )
    .Input("tensor: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tindices: {int32, int64}" )
    .SetShapeFn(ScatterNdTensorShape);
3217 | |
// Inputs (tensor, indices, updates) with int32/int64 indices; the output
// shape is inferred by ScatterNdTensorShape.
REGISTER_OP("TensorScatterMin" )
    .Input("tensor: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tindices: {int32, int64}" )
    .SetShapeFn(ScatterNdTensorShape);
3226 | |
// Inputs (tensor, indices, updates) with int32/int64 indices; the output
// shape is inferred by ScatterNdTensorShape.
REGISTER_OP("TensorScatterMax" )
    .Input("tensor: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Attr("Tindices: {int32, int64}" )
    .SetShapeFn(ScatterNdTensorShape);
3235 | |
// Inputs (input, indices, updates); T is restricted to numeric types or bool.
// Shares ScatterNdTensorShape with the TensorScatter* ops, so the output shape
// is derived from `input`.
REGISTER_OP("ScatterNdNonAliasingAdd" )
    .Input("input: T" )
    .Input("indices: Tindices" )
    .Input("updates: T" )
    .Output("output: T" )
    .Attr("T: {numbertype, bool}" )
    .Attr("Tindices: {int32, int64}" )
    .SetShapeFn(ScatterNdTensorShape);
3244 | |
// Fake-quant variant whose range is fixed by the `min` / `max` attrs
// (defaults -6.0 / 6.0) rather than tensor inputs. The float output has the
// same shape as `inputs`.
REGISTER_OP("FakeQuantWithMinMaxArgs" )
    .Attr("min: float = -6.0" )
    .Attr("max: float = 6.0" )
    .Attr("num_bits: int = 8" )
    .Attr("narrow_range: bool = false" )
    .Input("inputs: float" )
    .Output("outputs: float" )
    .SetShapeFn(shape_inference::UnchangedShape);
3253 | |
// Gradient of FakeQuantWithMinMaxArgs (same attrs). Takes `gradients` and
// `inputs`. UnchangedShape gives `backprops` the shape of input 0
// (`gradients`); no shape agreement with `inputs` is checked here.
REGISTER_OP("FakeQuantWithMinMaxArgsGradient" )
    .Attr("min: float = -6.0" )
    .Attr("max: float = 6.0" )
    .Attr("num_bits: int = 8" )
    .Attr("narrow_range: bool = false" )
    .Input("gradients: float" )
    .Input("inputs: float" )
    .Output("backprops: float" )
    .SetShapeFn(shape_inference::UnchangedShape);
3263 | |
3264 | REGISTER_OP("FakeQuantWithMinMaxVars" ) |
3265 | .Attr("num_bits: int = 8" ) |
3266 | .Attr("narrow_range: bool = false" ) |
3267 | .Input("inputs: float" ) |
3268 | .Input("min: float" ) |
3269 | .Input("max: float" ) |
3270 | .Output("outputs: float" ) |
3271 | .SetShapeFn([](InferenceContext* c) { |
3272 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
3273 | ShapeHandle unused; |
3274 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
3275 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
3276 | return OkStatus(); |
3277 | }); |
3278 | |
3279 | REGISTER_OP("FakeQuantWithMinMaxVarsGradient" ) |
3280 | .Attr("num_bits: int = 8" ) |
3281 | .Attr("narrow_range: bool = false" ) |
3282 | .Input("gradients: float" ) |
3283 | .Input("inputs: float" ) |
3284 | .Input("min: float" ) |
3285 | .Input("max: float" ) |
3286 | .Output("backprops_wrt_input: float" ) |
3287 | .Output("backprop_wrt_min: float" ) |
3288 | .Output("backprop_wrt_max: float" ) |
3289 | .SetShapeFn([](InferenceContext* c) { |
3290 | // gradients and inputs are same size. |
3291 | ShapeHandle inputs; |
3292 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs)); |
3293 | |
3294 | // min and max are scalars |
3295 | ShapeHandle min_max; |
3296 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max)); |
3297 | TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max)); |
3298 | |
3299 | c->set_output(0, inputs); |
3300 | c->set_output(1, min_max); |
3301 | c->set_output(2, min_max); |
3302 | return OkStatus(); |
3303 | }); |
3304 | |
3305 | REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel" ) |
3306 | .Attr("num_bits: int = 8" ) |
3307 | .Attr("narrow_range: bool = false" ) |
3308 | .Input("inputs: float" ) |
3309 | .Input("min: float" ) |
3310 | .Input("max: float" ) |
3311 | .Output("outputs: float" ) |
3312 | .SetShapeFn([](InferenceContext* c) { |
3313 | ShapeHandle input, min, max; |
3314 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); |
3315 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min)); |
3316 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max)); |
3317 | |
3318 | DimensionHandle unused; |
3319 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused)); |
3320 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused)); |
3321 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused)); |
3322 | |
3323 | c->set_output(0, input); |
3324 | return OkStatus(); |
3325 | }); |
3326 | |
// FakeQuantWithMinMaxVarsPerChannelGradient: gradient of
// FakeQuantWithMinMaxVarsPerChannel. It produces a backprop for the data
// input plus per-channel backprops for the `min` and `max` vectors.
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient" )
    .Attr("num_bits: int = 8" )
    .Attr("narrow_range: bool = false" )
    .Input("gradients: float" )
    .Input("inputs: float" )
    .Input("min: float" )
    .Input("max: float" )
    .Output("backprops_wrt_input: float" )
    .Output("backprop_wrt_min: float" )
    .Output("backprop_wrt_max: float" )
    .SetShapeFn([](InferenceContext* c) {
      // `gradients` must have rank in [1, 4] and be compatible with `inputs`.
      // NOTE(review): the forward per-channel op only requires rank >= 1; the
      // upper bound of 4 presumably reflects a kernel restriction -- confirm.
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
      TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));

      // One-element vector shape holding the channel (last) dimension.
      ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));

      // `min` must be a vector whose length matches the channel dimension,
      // and `max` must match `min`. The merged shape is reused for both range
      // backprops.
      ShapeHandle min_max;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));

      c->set_output(0, inputs);
      c->set_output(1, min_max);
      c->set_output(2, min_max);
      return OkStatus();
    });
3355 | |
3356 | REGISTER_OP("Fingerprint" ) |
3357 | .Input("data: T" ) |
3358 | .Input("method: string" ) |
3359 | .Output("fingerprint: uint8" ) |
3360 | .Attr("T: type" ) |
3361 | .SetShapeFn([](InferenceContext* c) { |
3362 | ShapeHandle unused; |
3363 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); |
3364 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
3365 | |
3366 | DimensionHandle fingerprint_size; |
3367 | const Tensor* method = c->input_tensor(1); |
3368 | if (method == nullptr) { |
3369 | fingerprint_size = c->UnknownDim(); |
3370 | } else { |
3371 | if (method->dims() != 0) { |
3372 | return errors::InvalidArgument("`method` must be rank 0: " , |
3373 | method->shape()); |
3374 | } |
3375 | const string& method_string = method->scalar<tstring>()(); |
3376 | if (method_string != "farmhash64" ) { |
3377 | return errors::InvalidArgument("Unsupported method: " , method_string); |
3378 | } |
3379 | fingerprint_size = c->MakeDim(sizeof(uint64)); |
3380 | } |
3381 | |
3382 | DimensionHandle batch = c->Dim(c->input(0), 0); |
3383 | c->set_output(0, c->MakeShape({batch, fingerprint_size})); |
3384 | return OkStatus(); |
3385 | }); |
3386 | |
#ifdef INTEL_MKL
// _MklConcat: MKL variant of Concat. Per its Doc string, it is meant to be
// inserted by a graph rewrite pass rather than built directly. The declared
// inputs are, in order:
//   concat_dim, values[0..N), mkl_concat_dim, mkl_values[0..N)
// so the op has 2 * N + 2 inputs in total.
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Only the N `values` tensors that follow `concat_dim` are concatenated.
      // Because of the trailing uint8 MKL metadata inputs, num_inputs() is
      // 2 * N + 2 and cannot be used to derive the count as plain Concat
      // does. Read it from the `N` attr instead. ConcatShape is assumed to
      // use input 0 as the axis and inputs 1..n as the values, as it does for
      // Concat -- NOTE(review): confirm in common_shape_fns.
      int32 num_values;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &num_values));
      return shape_inference::ConcatShape(c, num_values);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
3407 | |
// Deprecated op registrations:

// The following can be deleted after 10mar2017.
// NOTE(review): that date has passed, yet the Batch* ops below are still
// registered. Each is marked `.Deprecated(14, ...)`; confirm that no
// supported GraphDef producer version still needs them before removing.
// All of them use UnknownShape, so no output shape information is inferred.
//
// BatchMatrixDiag: superseded by MatrixDiag.
REGISTER_OP("BatchMatrixDiag" )
    .Input("diagonal: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Deprecated(14, "Use MatrixDiag" )
    .SetShapeFn(shape_inference::UnknownShape);
// BatchMatrixSetDiag: deprecated since version 14; superseded by
// MatrixSetDiag. The output shape is left unknown.
REGISTER_OP("BatchMatrixSetDiag" )
    .Input("input: T" )
    .Input("diagonal: T" )
    .Output("output: T" )
    .Attr("T: type" )
    .Deprecated(14, "Use MatrixSetDiag" )
    .SetShapeFn(shape_inference::UnknownShape);
// BatchMatrixDiagPart: deprecated since version 14; superseded by
// MatrixDiagPart. The output shape is left unknown.
REGISTER_OP("BatchMatrixDiagPart" )
    .Input("input: T" )
    .Output("diagonal: T" )
    .Attr("T: type" )
    .Deprecated(14, "Use MatrixDiagPart" )
    .SetShapeFn(shape_inference::UnknownShape);
// BatchMatrixBandPart: deprecated since version 14; superseded by
// MatrixBandPart. It takes int64 `num_lower`/`num_upper` band limits, and the
// output shape is left unknown.
REGISTER_OP("BatchMatrixBandPart" )
    .Input("input: T" )
    .Input("num_lower: int64" )
    .Input("num_upper: int64" )
    .Output("band: T" )
    .Attr("T: type" )
    .Deprecated(14, "Use MatrixBandPart" )
    .SetShapeFn(shape_inference::UnknownShape);
3438 | |
3439 | } // namespace tensorflow |
3440 | |