1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <ostream>
18
19#include "tensorflow/core/framework/common_shape_fns.h"
20#include "tensorflow/core/framework/kernel_shape_util.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/shape_inference.h"
23#include "tensorflow/core/framework/tensor.pb.h"
24#include "tensorflow/core/framework/types.h"
25#include "tensorflow/core/framework/types.pb.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/platform/types.h"
28#include "tensorflow/core/util/mirror_pad_mode.h"
29#include "tensorflow/core/util/padding.h"
30#include "tensorflow/core/util/strided_slice_op.h"
31#include "tensorflow/core/util/tensor_format.h"
32
33namespace tensorflow {
34
35using shape_inference::DimensionHandle;
36using shape_inference::InferenceContext;
37using shape_inference::ShapeHandle;
38using shape_inference::UnchangedShape;
39
40namespace {
41
42Status GetAxisForPackAndUnpack(InferenceContext* c, int32_t rank_after_pack,
43 int32* axis) {
44 TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
45 if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
46 return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
47 -1 * rank_after_pack, ",", rank_after_pack,
48 ")");
49 }
50 if (*axis < 0) *axis = (rank_after_pack + *axis);
51 return OkStatus();
52}
53
54template <typename T>
55std::vector<int64_t> AsInt64(const Tensor* tensor, int64_t num_elements) {
56 std::vector<int64_t> ret(num_elements);
57 auto data = tensor->vec<T>();
58 for (int64_t i = 0; i < num_elements; ++i) {
59 ret[i] = data(i);
60 }
61 return ret;
62}
63
64template <typename T>
65Status PadKnown(InferenceContext* c, ShapeHandle input,
66 const Tensor* paddings_t, int64_t num_dims) {
67 // paddings_t is known.
68 std::vector<DimensionHandle> dims(num_dims);
69 auto paddings_data = paddings_t->matrix<T>();
70 for (int64_t i = 0; i < num_dims; ++i) {
71 const T pad0 = paddings_data(i, 0);
72 const T pad1 = paddings_data(i, 1);
73 if (pad0 < 0 || pad1 < 0) {
74 return errors::InvalidArgument("Paddings must be non-negative");
75 }
76 TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
77 }
78 c->set_output(0, c->MakeShape(dims));
79 return OkStatus();
80}
81
82Status PadShapeFn(InferenceContext* c) {
83 // Paddings is a matrix of [input_rank, 2].
84 ShapeHandle paddings;
85 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
86 DimensionHandle unused;
87 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
88
89 // n_dim and input.rank are equivalent.
90 ShapeHandle input = c->input(0);
91 DimensionHandle n_dim = c->Dim(paddings, 0);
92 if (c->ValueKnown(n_dim)) {
93 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
94 } else if (c->RankKnown(input)) {
95 TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
96 }
97
98 const Tensor* paddings_t = c->input_tensor(1);
99
100 // paddings_t is unknown
101 if (paddings_t == nullptr) {
102 if (c->ValueKnown(n_dim)) {
103 // Make output with n_dim unknown dims.
104 c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
105 } else {
106 c->set_output(0, c->UnknownShape());
107 }
108 return OkStatus();
109 }
110
111 const int64_t num_dims = paddings_t->shape().dim_size(0);
112 TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
113 TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
114
115 if (paddings_t->dtype() == DT_INT32) {
116 return PadKnown<int32>(c, input, paddings_t, num_dims);
117 } else {
118 return PadKnown<int64_t>(c, input, paddings_t, num_dims);
119 }
120}
121
122Status TransposeShapeFn(InferenceContext* c) {
123 ShapeHandle input = c->input(0);
124 ShapeHandle perm_shape = c->input(1);
125 const Tensor* perm = c->input_tensor(1);
126 DimensionHandle perm_elems = c->NumElements(perm_shape);
127 // If we don't have rank information on the input or value information on
128 // perm we can't return any shape information, otherwise we have enough
129 // information to at least find the rank of the output.
130 if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
131 c->set_output(0, c->UnknownShape());
132 return OkStatus();
133 }
134
135 // Find our value of the rank.
136 int64_t rank;
137 if (c->RankKnown(input)) {
138 rank = c->Rank(input);
139 } else if (c->ValueKnown(perm_elems)) {
140 rank = c->Value(perm_elems);
141 } else {
142 rank = perm->NumElements();
143 }
144 if (!c->RankKnown(input) && rank < 2) {
145 // A permutation array containing a single element is ambiguous. It could
146 // indicate either a scalar or a 1-dimensional array, both of which the
147 // transpose op returns unchanged.
148 c->set_output(0, input);
149 return OkStatus();
150 }
151
152 std::vector<DimensionHandle> dims;
153 dims.resize(rank);
154 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
155 // Ensure that perm is a vector and has rank elements.
156 TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
157 TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
158
159 // If we know the rank of the input and the value of perm, we can return
160 // all shape information, otherwise we can only return rank information,
161 // but no information for the dimensions.
162 if (perm != nullptr) {
163 std::vector<int64_t> data;
164 if (perm->dtype() == DT_INT32) {
165 data = AsInt64<int32>(perm, rank);
166 } else {
167 data = AsInt64<int64_t>(perm, rank);
168 }
169
170 for (int32_t i = 0; i < rank; ++i) {
171 int64_t in_idx = data[i];
172 if (in_idx >= rank || in_idx <= -rank) {
173 return errors::InvalidArgument("perm dim ", in_idx,
174 " is out of range of input rank ", rank);
175 }
176 dims[i] = c->Dim(input, in_idx);
177 }
178 } else {
179 for (int i = 0; i < rank; ++i) {
180 dims[i] = c->UnknownDim();
181 }
182 }
183
184 c->set_output(0, c->MakeShape(dims));
185 return OkStatus();
186}
187
188Status SetOutputShapeForReshape(InferenceContext* c) {
189 ShapeHandle in = c->input(0);
190 ShapeHandle out;
191 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
192
193 if (!c->RankKnown(out)) {
194 // We have no information about the shape of the output.
195 c->set_output(0, out);
196 return OkStatus();
197 }
198 if (c->RankKnown(in)) {
199 // We don't know the number of output elements, but we can try to infer
200 // the missing dimension.
201 bool too_many_unknown = false;
202 int32_t out_unknown_idx = -1;
203
204 DimensionHandle known_out_elems = c->NumElements(out);
205 if (!c->ValueKnown(known_out_elems)) {
206 known_out_elems = c->MakeDim(1);
207 for (int32_t i = 0; i < c->Rank(out); ++i) {
208 DimensionHandle dim = c->Dim(out, i);
209 if (!c->ValueKnown(dim)) {
210 if (out_unknown_idx >= 0) {
211 too_many_unknown = true;
212 break;
213 }
214 out_unknown_idx = i;
215 } else {
216 TF_RETURN_IF_ERROR(
217 c->Multiply(known_out_elems, dim, &known_out_elems));
218 }
219 }
220 }
221 int32_t in_unknown_idx = -1;
222 DimensionHandle known_in_elems = c->NumElements(in);
223 if (!c->ValueKnown(known_in_elems)) {
224 known_in_elems = c->MakeDim(1);
225 for (int32_t i = 0; i < c->Rank(in); ++i) {
226 DimensionHandle dim = c->Dim(in, i);
227 if (!c->ValueKnown(dim)) {
228 if (in_unknown_idx >= 0) {
229 too_many_unknown = true;
230 break;
231 }
232 in_unknown_idx = i;
233 } else {
234 TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
235 }
236 }
237 }
238
239 if (!too_many_unknown) {
240 if (in_unknown_idx < 0 && out_unknown_idx < 0) {
241 // Just check that the dimensions match.
242 if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
243 return errors::InvalidArgument(
244 "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
245 " elements to shape ", c->DebugString(out), " (",
246 c->DebugString(known_out_elems), " elements)");
247 }
248 } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
249 c->Value(known_out_elems) > 0) {
250 // Input fully known, infer the one missing output dim
251 DimensionHandle inferred_dim;
252 TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
253 true /* evenly_divisible */,
254 &inferred_dim));
255 TF_RETURN_IF_ERROR(
256 c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
257
258 } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
259 c->Value(known_in_elems) != 0) {
260 // Output fully known, infer the one missing input dim
261 DimensionHandle inferred_dim;
262 TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
263 true /* evenly_divisible */,
264 &inferred_dim));
265 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
266 TF_RETURN_IF_ERROR(
267 c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
268 } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
269 // Exactly one unknown dimension in both input and output. These 2 are
270 // equal iff the known elements are equal.
271 if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
272 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
273 TF_RETURN_IF_ERROR(
274 c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
275 }
276 }
277 }
278 }
279 c->set_output(0, out);
280 return OkStatus();
281}
282
283} // namespace
284
285REGISTER_OP("ParallelConcat")
286 .Input("values: N * T")
287 .Output("output: T")
288 .Attr("N: int >= 1")
289 .Attr("T: type")
290 .Attr("shape: shape")
291 .SetShapeFn([](InferenceContext* c) {
292 // Validate that the shape attr is correct.
293 PartialTensorShape shape;
294 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
295 ShapeHandle passed_shape;
296 TF_RETURN_IF_ERROR(
297 c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
298 if (!c->FullyDefined(passed_shape)) {
299 return errors::InvalidArgument("shape attr must be fully defined.");
300 }
301 ShapeHandle cur;
302 TF_RETURN_IF_ERROR(c->ReplaceDim(
303 passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
304 &cur));
305 for (int i = 0; i < c->num_inputs(); ++i) {
306 if (!c->FullyDefined(c->input(i))) {
307 return errors::InvalidArgument(
308 "All input shapes must be fully defined.");
309 }
310 DimensionHandle unused;
311 if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
312 return errors::InvalidArgument("Size of first dimension must be 1.");
313 }
314 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
315 "From merging shape ", i,
316 " with other shapes.");
317 }
318
319 c->set_output(0, passed_shape);
320
321 return OkStatus();
322 });
323
324REGISTER_OP("Pack")
325 .Input("values: N * T")
326 .Output("output: T")
327 .Attr("N: int >= 1")
328 .Attr("T: type")
329 .Attr("axis: int = 0")
330 .SetShapeFn([](InferenceContext* c) {
331 // Validate shapes of all inputs are compatible
332 ShapeHandle cur = c->input(c->num_inputs() - 1);
333 for (int i = c->num_inputs() - 2; i >= 0; --i) {
334 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
335 "From merging shape ", i,
336 " with other shapes.");
337 }
338 if (!c->RankKnown(cur)) {
339 c->set_output(0, c->UnknownShape());
340 return OkStatus();
341 }
342 // Determine the axis that will be added, converting from negative
343 // axes to a positive point per negative indexing rules.
344 int32_t rank = c->Rank(cur);
345 int32_t axis;
346 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
347
348 // Copy all dimensions over, inserting a dimension of value #inputs
349 // at <axis>.
350 std::vector<DimensionHandle> dims;
351 int index = 0;
352 while (index < axis) dims.push_back(c->Dim(cur, index++));
353 dims.push_back(c->MakeDim(c->num_inputs()));
354 while (index < rank) dims.push_back(c->Dim(cur, index++));
355
356 c->set_output(0, c->MakeShape(dims));
357 for (int i = 0; i < c->num_inputs(); ++i) {
358 auto* shape_and_type = c->input_handle_shapes_and_types(i);
359 if (shape_and_type) {
360 if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
361 c->set_output_handle_shapes_and_types(
362 0, std::vector<shape_inference::ShapeAndType>({}));
363 break;
364 }
365 }
366 }
367 return OkStatus();
368 });
369
370REGISTER_OP("DeepCopy")
371 .Input("x: T")
372 .Output("y: T")
373 .Attr("T: type")
374 .SetIsStateful()
375 .SetShapeFn(UnchangedShape);
376
377REGISTER_OP("InplaceUpdate")
378 .Input("x: T")
379 .Input("i: int32")
380 .Input("v: T")
381 .Output("y: T")
382 .Attr("T: type")
383 .SetShapeFn(UnchangedShape);
384
385REGISTER_OP("InplaceAdd")
386 .Input("x: T")
387 .Input("i: int32")
388 .Input("v: T")
389 .Output("y: T")
390 .Attr("T: type")
391 .SetShapeFn(UnchangedShape);
392
393REGISTER_OP("InplaceSub")
394 .Input("x: T")
395 .Input("i: int32")
396 .Input("v: T")
397 .Output("y: T")
398 .Attr("T: type")
399 .SetShapeFn(UnchangedShape);
400
401REGISTER_OP("Empty")
402 .Input("shape: int32")
403 .Output("output: dtype")
404 .Attr("dtype: type")
405 .Attr("init: bool = false")
406 .SetDoNotOptimize()
407 .SetShapeFn([](InferenceContext* c) {
408 ShapeHandle out;
409 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
410 c->set_output(0, out);
411 return OkStatus();
412 });
413
414// --------------------------------------------------------------------------
415REGISTER_OP("Unpack")
416 .Input("value: T")
417 .Output("output: num * T")
418 .Attr("num: int >= 0")
419 .Attr("T: type")
420 .Attr("axis: int = 0")
421 .SetShapeFn([](InferenceContext* c) {
422 ShapeHandle s = c->input(0);
423 ShapeHandle out;
424 if (c->RankKnown(s)) {
425 // Determine the axis that will be removed, converting from negative
426 // axes to a positive point per negative indexing rules.
427 int32_t rank = c->Rank(s);
428 int32_t axis;
429 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
430
431 // The axis dim matches the number of outputs.
432 DimensionHandle unused;
433 TF_RETURN_IF_ERROR(
434 c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
435
436 // Copy all dimensions, removing the <axis> dimension.
437 std::vector<DimensionHandle> dims;
438 for (int i = 0; i < rank; ++i) {
439 if (i != axis) dims.push_back(c->Dim(s, i));
440 }
441 out = c->MakeShape(dims);
442 } else {
443 // All outputs are the same shape, but it's not known.
444 out = c->UnknownShape();
445 }
446 for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
447 return OkStatus();
448 });
449
450REGISTER_OP("UnravelIndex")
451 .Input("indices: Tidx")
452 .Input("dims: Tidx")
453 .Output("output: Tidx")
454 .Attr("Tidx: {int32, int64} = DT_INT32")
455 .SetShapeFn([](InferenceContext* c) {
456 ShapeHandle indices = c->input(0);
457 ShapeHandle dims;
458 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
459 if (c->RankKnown(indices) && c->Rank(indices) == 0) {
460 c->set_output(0, c->Vector(c->Dim(dims, 0)));
461 } else if (c->RankKnown(indices)) {
462 c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
463 } else {
464 c->set_output(0, c->UnknownShape());
465 }
466 return OkStatus();
467 });
468
469REGISTER_OP("BroadcastTo")
470 .Input("input: T")
471 .Input("shape: Tidx")
472 .Output("output: T")
473 .Attr("T: type")
474 .Attr("Tidx: {int32, int64} = DT_INT32")
475 .SetShapeFn([](InferenceContext* c) {
476 ShapeHandle shape_in = c->input(1);
477 TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
478 ShapeHandle out;
479 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
480 if (!c->RankKnown(out)) {
481 // We have no information about the shape of the output.
482 c->set_output(0, out);
483 return OkStatus();
484 }
485
486 ShapeHandle in = c->input(0);
487 if (!c->RankKnown(in)) {
488 // We have no information about the shape of the input,
489 // nothing to do here.
490 c->set_output(0, out);
491 return OkStatus();
492 }
493 int out_rank = c->Rank(out);
494 TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
495 int in_rank = c->Rank(in);
496 for (int i = 0; i < in_rank; ++i) {
497 auto in_dim = c->Dim(in, in_rank - i - 1);
498 if (c->Value(in_dim) > 1) {
499 // If the input dimension is greater than 1 then the output dimension
500 // must be equal to it, since we only broadcast "from left to right".
501 auto out_dim = c->Dim(out, out_rank - i - 1);
502 TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
503 TF_RETURN_IF_ERROR(
504 c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
505 }
506 }
507 c->set_output(0, out);
508 return OkStatus();
509 });
510
511// --------------------------------------------------------------------------
512// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
513// in the N == 1 case to remove the node.
514REGISTER_OP("Concat")
515 .Input("concat_dim: int32")
516 .Input("values: N * T")
517 .Output("output: T")
518 .Attr("N: int >= 2")
519 .Attr("T: type")
520 .SetShapeFn([](InferenceContext* c) {
521 return shape_inference::ConcatShape(c, c->num_inputs() - 1);
522 });
523
524REGISTER_OP("ConcatV2")
525 .Input("values: N * T")
526 .Input("axis: Tidx")
527 .Output("output: T")
528 .Attr("N: int >= 2")
529 .Attr("T: type")
530 .Attr("Tidx: {int32, int64} = DT_INT32")
531 .SetShapeFn(shape_inference::ConcatV2Shape);
532
533// TODO([email protected]): Prefix the op names with underscore if the ops
534// are not to be made user-accessible.
535#ifdef INTEL_MKL
536REGISTER_OP("_MklConcatV2")
537 .Input("values: N * T")
538 .Input("axis: Tidx")
539 .Input("mkl_values: N * uint8")
540 .Input("mkl_axis: uint8")
541 .Output("output: T")
542 .Output("mkl_output: uint8")
543 .Attr("N: int >= 2")
544 .Attr("T: type")
545 .Attr("Tidx: {int32, int64} = DT_INT32")
546 .SetShapeFn(shape_inference::ConcatV2Shape)
547 .Doc(R"doc(
548MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
549
550NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
551expected to invoke these operators.
552)doc");
553#endif
554
555REGISTER_OP("ConcatOffset")
556 .Input("concat_dim: int32")
557 .Input("shape: N * int32")
558 .Output("offset: N * int32")
559 .Attr("N: int >= 2")
560 .SetShapeFn([](InferenceContext* c) {
561 for (int i = 1; i < c->num_inputs(); ++i) {
562 c->set_output(i - 1, c->input(i));
563 }
564 return OkStatus();
565 });
566
567// --------------------------------------------------------------------------
568REGISTER_OP("Split")
569 .Input("split_dim: int32")
570 .Input("value: T")
571 .Output("output: num_split * T")
572 .Attr("num_split: int >= 1")
573 .Attr("T: type")
574 .SetShapeFn([](InferenceContext* c) {
575 DimensionHandle split_dimension;
576 ShapeHandle input = c->input(1);
577 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
578 0, c->Rank(input), &split_dimension));
579 int num_split = c->num_outputs();
580 ShapeHandle out;
581 if (!c->ValueKnown(split_dimension)) {
582 if (c->RankKnown(input)) {
583 out = c->UnknownShapeOfRank(c->Rank(input));
584 } else {
585 out = c->UnknownShape();
586 }
587 } else {
588 int64_t split_dim = c->Value(split_dimension);
589 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
590 DimensionHandle split_dim_size;
591 TF_RETURN_WITH_CONTEXT_IF_ERROR(
592 c->Divide(c->Dim(input, split_dim), num_split,
593 true /* evenly_divisible */, &split_dim_size),
594 "Number of ways to split should evenly divide the split dimension");
595 TF_RETURN_IF_ERROR(
596 c->ReplaceDim(input, split_dim, split_dim_size, &out));
597 }
598 for (int i = 0; i < num_split; ++i) c->set_output(i, out);
599 return OkStatus();
600 });
601
602REGISTER_OP("SplitV")
603 .Input("value: T")
604 .Input("size_splits: Tlen")
605 .Input("split_dim: int32")
606 .Output("output: num_split * T")
607 .Attr("num_split: int >= 1")
608 .Attr("T: type")
609 .Attr("Tlen: {int8, int32, int64} = DT_INT64")
610 .SetShapeFn([](InferenceContext* c) {
611 DimensionHandle split_dimension;
612 ShapeHandle input = c->input(0);
613 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
614 2, c->Rank(input), &split_dimension));
615 int32_t num_outputs = c->num_outputs();
616 int32_t rank = c->Rank(input);
617 ShapeHandle output_shape;
618 const Tensor* size_splits = c->input_tensor(1);
619 if (rank == InferenceContext::kUnknownRank) {
620 // If the rank of input tensor is unknown, then return unknown shapes.
621 // Note that the shape of each output can be different.
622 for (int i = 0; i < num_outputs; ++i) {
623 c->set_output(i, c->UnknownShape());
624 }
625 } else if (rank == 0) {
626 // Throw error if input is a scalar.
627 return errors::InvalidArgument("Can't split scalars");
628 } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
629 // If split dimension is known, but the sizes are unknown, then
630 // only the split dimension is unknown
631 output_shape = input;
632 for (int i = 0; i < num_outputs; ++i) {
633 TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
634 c->Value(split_dimension),
635 c->UnknownDim(), &output_shape));
636 c->set_output(i, output_shape);
637 }
638 } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
639 // If split dimension or tensor containing the split sizes is unknown,
640 // then return unknown shapes of same rank as input. Note that each
641 // output shape can be different since splitv doesn't always split
642 // tensors evenly.
643 for (int i = 0; i < num_outputs; ++i) {
644 c->set_output(i, c->UnknownShapeOfRank(rank));
645 }
646 } else {
647 // Determine the output shape if split dimension and split sizes are
648 // known.
649 int64_t split_dim = c->Value(split_dimension);
650 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
651 std::vector<int64_t> data;
652 if (size_splits->dtype() == DT_INT32) {
653 data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
654 } else {
655 data =
656 AsInt64<int64_t>(size_splits, size_splits->shape().dim_size(0));
657 }
658 if (num_outputs != data.size()) {
659 return errors::InvalidArgument(
660 "Length of size_splits should be equal to num_outputs");
661 }
662 int64_t total_size = 0;
663 bool has_neg_one = false;
664 for (const auto size : data) {
665 if (size == -1) {
666 if (has_neg_one) {
667 return errors::InvalidArgument(
668 "size_splits can only have one -1");
669 }
670 has_neg_one = true;
671 } else {
672 total_size += size;
673 }
674 }
675 auto split_dim_size = c->Value(c->Dim(input, split_dim));
676 // If the sizes of the splits are known, then
677 // make sure that the sizes add up to the expected
678 // dimension size, with the possibility of a -1.
679 // Specify the full output shapes.
680 for (int i = 0; i < num_outputs; ++i) {
681 auto size = data[i];
682 if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
683 size = split_dim_size - total_size;
684 }
685 // If we have a negative known size (either explicit, or computed
686 // via -1), then the split sizes are invalid.
687 if (size < -1 || (size == -1 && c->ValueKnown(split_dim_size))) {
688 return errors::InvalidArgument("Split size at index ", i,
689 " must be >= 0. Got: ", size);
690 }
691 TF_RETURN_IF_ERROR(
692 c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
693 c->set_output(i, output_shape);
694 }
695 if (c->ValueKnown(split_dim_size)) {
696 if (has_neg_one ? total_size > split_dim_size
697 : total_size != split_dim_size) {
698 return errors::InvalidArgument(
699 "can't split axis of size ", split_dim_size,
700 " into pieces of size [", absl::StrJoin(data, ","), "]");
701 }
702 }
703 }
704
705 return OkStatus();
706 });
707
708// --------------------------------------------------------------------------
709REGISTER_OP("Const")
710 .Output("output: dtype")
711 .Attr("value: tensor")
712 .Attr("dtype: type")
713 .SetShapeFn([](InferenceContext* c) {
714 const TensorProto* proto = nullptr;
715 TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
716 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
717 TensorShape shape(proto->tensor_shape());
718 std::vector<DimensionHandle> dims;
719 dims.reserve(shape.dims());
720 for (int i = 0; i < shape.dims(); ++i) {
721 dims.push_back(c->MakeDim(shape.dim_size(i)));
722 }
723 c->set_output(0, c->MakeShape(dims));
724 return OkStatus();
725 });
726
727// Returns a constant tensor on the host. Useful for writing C++ tests
728// and benchmarks which run on GPU but require arguments pinned to the host.
729// Used by test::graph::HostConstant.
730// value: Attr `value` is the tensor to return.
731REGISTER_OP("HostConst")
732 .Output("output: dtype")
733 .Attr("value: tensor")
734 .Attr("dtype: type")
735 .SetShapeFn(shape_inference::UnknownShape);
736
737// Used executing op-by-op to copy constants to the current device without
738// serializing tensors as TensorProtos, after a host tensor has been
739// created. Same behavior as Identity, but no gradient and potentially relaxed
740// copy semantics.
741REGISTER_OP("_EagerConst")
742 .Input("input: T")
743 .Output("output: T")
744 .Attr("T: type")
745 .SetShapeFn(shape_inference::UnchangedShape);
746
747// --------------------------------------------------------------------------
748// TODO(mgubin): Update the doc when the freeze_graph script supports converting
749// into memmapped format.
750REGISTER_OP("ImmutableConst")
751 .Attr("dtype: type")
752 .Attr("shape: shape")
753 .Attr("memory_region_name: string")
754 .Output("tensor: dtype")
755 .SetShapeFn(shape_inference::ExplicitShape);
756
757REGISTER_OP("GuaranteeConst")
758 .Input("input: T")
759 .Output("output: T")
760 .Attr("T: type")
761 .SetShapeFn([](shape_inference::InferenceContext* c) {
762 return UnchangedShape(c);
763 })
764 // We don't want this to be optimized away.
765 .SetDoNotOptimize();
766
767// --------------------------------------------------------------------------
768REGISTER_OP("ZerosLike")
769 .Input("x: T")
770 .Output("y: T")
771 .Attr("T: type")
772 .SetShapeFn(shape_inference::UnchangedShape);
773
774// --------------------------------------------------------------------------
775REGISTER_OP("OnesLike")
776 .Input("x: T")
777 .Output("y: T")
778 .Attr(
779 "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
780 "uint32, int64, uint64, complex64, complex128, bool}")
781 .SetShapeFn(shape_inference::UnchangedShape);
782
783// --------------------------------------------------------------------------
784REGISTER_OP("Diag")
785 .Input("diagonal: T")
786 .Output("output: T")
787 .Attr(
788 "T: {bfloat16, half, float, double, int32, int64, complex64, "
789 "complex128}")
790 .SetShapeFn([](InferenceContext* c) {
791 ShapeHandle in = c->input(0);
792 TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
793 // Output shape is original concatenated with itself.
794 ShapeHandle out;
795 TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
796 c->set_output(0, out);
797 return OkStatus();
798 });
799
800// --------------------------------------------------------------------------
801REGISTER_OP("DiagPart")
802 .Input("input: T")
803 .Output("diagonal: T")
804 .Attr(
805 "T: {bfloat16, half, float, double, int32, int64, complex64, "
806 "complex128}")
807 .SetShapeFn([](InferenceContext* c) {
808 ShapeHandle in = c->input(0);
809 if (!c->RankKnown(in)) {
810 c->set_output(0, c->UnknownShape());
811 return OkStatus();
812 }
813 // Rank must be even, and result will have rank <rank/2>.
814 const int32_t rank = c->Rank(in);
815 if ((rank % 2) != 0 || rank <= 0) {
816 return errors::InvalidArgument(
817 "Input must have even and non-zero rank, input rank is ", rank);
818 }
819 const int32_t mid = rank / 2;
820
821 // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
822 std::vector<DimensionHandle> dims(mid);
823 for (int i = 0; i < mid; ++i) {
824 TF_RETURN_IF_ERROR(
825 c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
826 }
827 c->set_output(0, c->MakeShape(dims));
828 return OkStatus();
829 });
830
831// --------------------------------------------------------------------------
832REGISTER_OP("MatrixDiag")
833 .Input("diagonal: T")
834 .Output("output: T")
835 .Attr("T: type")
836 .SetShapeFn([](InferenceContext* c) {
837 ShapeHandle in;
838 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
839 if (!c->RankKnown(in)) {
840 c->set_output(0, c->UnknownShape());
841 return OkStatus();
842 }
843 const int32_t rank = c->Rank(in);
844 ShapeHandle out;
845 TF_RETURN_IF_ERROR(
846 c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
847 c->set_output(0, out);
848 return OkStatus();
849 });
850
851REGISTER_OP("MatrixDiagV2")
852 .Input("diagonal: T")
853 .Input("k: int32")
854 .Input("num_rows: int32")
855 .Input("num_cols: int32")
856 .Input("padding_value: T")
857 .Output("output: T")
858 .Attr("T: type")
859 .SetShapeFn(shape_inference::MatrixDiagV2Shape);
860
861REGISTER_OP("MatrixDiagV3")
862 .Input("diagonal: T")
863 .Input("k: int32")
864 .Input("num_rows: int32")
865 .Input("num_cols: int32")
866 .Input("padding_value: T")
867 .Output("output: T")
868 .Attr("T: type")
869 .Attr(
870 "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
871 "'RIGHT_LEFT'")
872 .SetShapeFn(shape_inference::MatrixDiagV2Shape);
873
874// --------------------------------------------------------------------------
875REGISTER_OP("MatrixSetDiag")
876 .Input("input: T")
877 .Input("diagonal: T")
878 .Output("output: T")
879 .Attr("T: type")
880 .SetShapeFn([](InferenceContext* c) {
881 ShapeHandle input;
882 ShapeHandle diag;
883 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
884 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
885 if (c->RankKnown(input)) {
886 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
887 }
888 DimensionHandle smallest_dim;
889 TF_RETURN_IF_ERROR(
890 c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
891 TF_RETURN_IF_ERROR(
892 c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
893
894 ShapeHandle output = input;
895 if (c->RankKnown(diag) && !c->FullyDefined(input)) {
896 // Try to infer parts of shape from diag.
897 ShapeHandle diag_batch_shape;
898 TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_batch_shape));
899 TF_RETURN_IF_ERROR(
900 c->Concatenate(diag_batch_shape, c->UnknownShapeOfRank(2), &diag));
901 TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
902 }
903 c->set_output(0, output);
904 return OkStatus();
905 });
906
907REGISTER_OP("MatrixSetDiagV2")
908 .Input("input: T")
909 .Input("diagonal: T")
910 .Input("k: int32")
911 .Output("output: T")
912 .Attr("T: type")
913 .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
914
915REGISTER_OP("MatrixSetDiagV3")
916 .Input("input: T")
917 .Input("diagonal: T")
918 .Input("k: int32")
919 .Output("output: T")
920 .Attr("T: type")
921 .Attr(
922 "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
923 "'RIGHT_LEFT'")
924 .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
925
926// --------------------------------------------------------------------------
927REGISTER_OP("MatrixDiagPart")
928 .Input("input: T")
929 .Output("diagonal: T")
930 .Attr("T: type")
931 .SetShapeFn([](InferenceContext* c) {
932 ShapeHandle in;
933 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
934 if (!c->RankKnown(in)) {
935 c->set_output(0, c->UnknownShape());
936 return OkStatus();
937 }
938 const int32_t rank = c->Rank(in);
939 std::vector<DimensionHandle> dims;
940 dims.reserve(rank - 2);
941 for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
942
943 DimensionHandle min_dim;
944 TF_RETURN_IF_ERROR(
945 c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
946 dims.push_back(min_dim);
947 c->set_output(0, c->MakeShape(dims));
948 return OkStatus();
949 });
950
951REGISTER_OP("MatrixDiagPartV2")
952 .Input("input: T")
953 .Input("k: int32")
954 .Input("padding_value: T")
955 .Output("diagonal: T")
956 .Attr("T: type")
957 .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
958
959REGISTER_OP("MatrixDiagPartV3")
960 .Input("input: T")
961 .Input("k: int32")
962 .Input("padding_value: T")
963 .Output("diagonal: T")
964 .Attr("T: type")
965 .Attr(
966 "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
967 "'RIGHT_LEFT'")
968 .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
969
970// --------------------------------------------------------------------------
971REGISTER_OP("MatrixBandPart")
972 .Input("input: T")
973 .Input("num_lower: Tindex")
974 .Input("num_upper: Tindex")
975 .Output("band: T")
976 .Attr("T: type")
977 .Attr("Tindex: {int32, int64} = DT_INT64")
978 .SetShapeFn(shape_inference::UnchangedShape);
979
980// --------------------------------------------------------------------------
981REGISTER_OP("Reverse")
982 .Input("tensor: T")
983 .Input("dims: bool")
984 .Output("output: T")
985 .Attr(
986 "T: {uint8, int8, uint16, int16, uint32, int32, uint64, int64, bool, "
987 "bfloat16, half, float, double, complex64, complex128, string}")
988 .SetShapeFn([](InferenceContext* c) {
989 ShapeHandle input = c->input(0);
990 ShapeHandle dims;
991 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
992 DimensionHandle dims_dim = c->Dim(dims, 0);
993 if (c->ValueKnown(dims_dim)) {
994 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
995 }
996 if (c->Rank(input) > 8) {
997 return errors::InvalidArgument(
998 "reverse does not work on tensors with more than 8 dimensions");
999 }
1000 c->set_output(0, input);
1001 return OkStatus();
1002 });
1003
1004// --------------------------------------------------------------------------
1005REGISTER_OP("ReverseV2")
1006 .Input("tensor: T")
1007 .Input("axis: Tidx")
1008 .Output("output: T")
1009 .Attr("Tidx: {int32, int64} = DT_INT32")
1010 .Attr(
1011 "T: {uint8, int8, uint16, int16, int32, uint32, int64, uint64, bool, "
1012 "bfloat16, half, float, double, complex64, complex128, string}")
1013 .SetShapeFn([](InferenceContext* c) {
1014 ShapeHandle input = c->input(0);
1015 ShapeHandle axis;
1016 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
1017 if (c->Rank(input) > 8) {
1018 return errors::InvalidArgument(
1019 "reverse does not work on tensors with more than 8 dimensions");
1020 }
1021 const Tensor* axis_tensor = c->input_tensor(1);
1022 if (axis_tensor != nullptr && c->RankKnown(input)) {
1023 int32_t rank = c->Rank(input);
1024 std::vector<int64_t> axis_value;
1025 if (axis_tensor->dtype() == DT_INT32) {
1026 axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
1027 } else {
1028 axis_value =
1029 AsInt64<int64_t>(axis_tensor, axis_tensor->NumElements());
1030 }
1031 std::vector<bool> axes_dense(c->Rank(input), false);
1032 for (int i = 0; i < axis_value.size(); i++) {
1033 int64_t canonical_axis =
1034 axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
1035 if (canonical_axis < 0 || canonical_axis >= rank) {
1036 return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
1037 " is out of valid range [", 0, ", ",
1038 rank - 1);
1039 }
1040 if (axes_dense[canonical_axis]) {
1041 return errors::InvalidArgument("axis ", canonical_axis,
1042 " specified more than once.");
1043 }
1044 axes_dense[canonical_axis] = true;
1045 }
1046 }
1047 c->set_output(0, input);
1048 return OkStatus();
1049 });
1050
1051// --------------------------------------------------------------------------
1052REGISTER_OP("EditDistance")
1053 .Input("hypothesis_indices: int64")
1054 .Input("hypothesis_values: T")
1055 .Input("hypothesis_shape: int64")
1056 .Input("truth_indices: int64")
1057 .Input("truth_values: T")
1058 .Input("truth_shape: int64")
1059 .Attr("normalize: bool = true")
1060 .Attr("T: type")
1061 .Output("output: float")
1062 .SetShapeFn([](InferenceContext* c) {
1063 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1064 c, c->input(0), c->input(1), c->input(2)));
1065 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1066 c, c->input(3), c->input(4), c->input(5)));
1067 const Tensor* hypothesis_shape_t = c->input_tensor(2);
1068 const Tensor* truth_shape_t = c->input_tensor(5);
1069 if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
1070 // We need to know the runtime shape of the two tensors,
1071 // or else the output shape is unknown.
1072 return shape_inference::UnknownShape(c);
1073 }
1074
1075 if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
1076 return errors::InvalidArgument(
1077 "Num elements of hypothesis_shape does not match truth_shape: ",
1078 hypothesis_shape_t->NumElements(), " vs. ",
1079 truth_shape_t->NumElements());
1080 }
1081
1082 auto h_values = hypothesis_shape_t->flat<int64_t>();
1083 auto t_values = truth_shape_t->flat<int64_t>();
1084 std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
1085 for (int i = 0; i < dims.size(); ++i) {
1086 dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1087 }
1088
1089 c->set_output(0, c->MakeShape(dims));
1090 return OkStatus();
1091 });
1092
1093// --------------------------------------------------------------------------
1094REGISTER_OP("Fill")
1095 .Input("dims: index_type")
1096 .Input("value: T")
1097 .Output("output: T")
1098 .Attr("T: type")
1099 .Attr("index_type: {int32, int64} = DT_INT32")
1100 .SetShapeFn([](InferenceContext* c) {
1101 DataType index_type = DT_INT32;
1102 Status s = c->GetAttr("index_type", &index_type);
1103 if (!s.ok() && s.code() != error::NOT_FOUND) {
1104 return s;
1105 }
1106 ShapeHandle unused;
1107 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1108 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1109
1110 const Tensor* t = c->input_tensor(0);
1111 if (t != nullptr) {
1112 for (int i = 0; i < t->NumElements(); ++i) {
1113 if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1114 (index_type == DT_INT64 && t->vec<int64_t>()(i) < 0)) {
1115 return errors::InvalidArgument("Fill dimensions must be >= 0");
1116 }
1117 }
1118 }
1119
1120 ShapeHandle out;
1121 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1122 c->set_output(0, out);
1123
1124 auto* shape_and_type = c->input_handle_shapes_and_types(1);
1125 if (shape_and_type) {
1126 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1127 }
1128
1129 return OkStatus();
1130 });
1131
1132// --------------------------------------------------------------------------
1133REGISTER_OP("_ParallelConcatStart")
1134 .Output("output: dtype")
1135 .Attr("shape: shape")
1136 .Attr("dtype: type")
1137 .SetIsStateful()
1138 .SetShapeFn(shape_inference::ExplicitShape)
1139 .Doc(R"doc(
1140Creates an empty Tensor with shape `shape` and type `dtype`.
1141
1142The memory can optionally be initialized. This is usually useful in
1143conjunction with inplace operations.
1144
1145shape: 1-D `Tensor` indicating the shape of the output.
1146dtype: The element type of the returned tensor.
1147output: An empty Tensor of the specified type.
1148)doc");
1149
1150// --------------------------------------------------------------------------
1151REGISTER_OP("_ParallelConcatUpdate")
1152 .Input("value: T")
1153 .Input("update: T")
1154 .Output("output: T")
1155 .Attr("T: type")
1156 .Attr("loc: int")
1157 .SetShapeFn(shape_inference::UnchangedShape)
1158 .Doc(R"doc(
1159Updates input `value` at `loc` with `update`.
1160
1161If you use this function you will almost certainly want to add
1162a control dependency as done in the implementation of parallel_stack to
1163avoid race conditions.
1164
1165value: A `Tensor` object that will be updated in-place.
1166loc: A scalar indicating the index of the first dimension such that
1167 value[loc, :] is updated.
1168update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
1169 otherwise of rank equal to `value` that contains the new values
1170 for `value`.
1171output: `value` that has been updated accordingly.
1172)doc");
1173
1174// --------------------------------------------------------------------------
1175REGISTER_OP("Gather")
1176 .Input("params: Tparams")
1177 .Input("indices: Tindices")
1178 .Attr("validate_indices: bool = true")
1179 .Output("output: Tparams")
1180 .Attr("Tparams: type")
1181 .Attr("Tindices: {int32,int64}")
1182 .SetShapeFn([](InferenceContext* c) {
1183 ShapeHandle unused;
1184 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1185 ShapeHandle params_subshape;
1186 TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
1187 ShapeHandle indices_shape = c->input(1);
1188 ShapeHandle out;
1189 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1190 c->set_output(0, out);
1191 return OkStatus();
1192 });
1193
1194// --------------------------------------------------------------------------
1195REGISTER_OP("GatherV2")
1196 .Input("params: Tparams")
1197 .Input("indices: Tindices")
1198 .Input("axis: Taxis")
1199 .Attr("batch_dims: int = 0")
1200 .Output("output: Tparams")
1201 .Attr("Tparams: type")
1202 .Attr("Tindices: {int16, int32,int64}")
1203 .Attr("Taxis: {int32,int64}")
1204 .SetShapeFn([](InferenceContext* c) {
1205 ShapeHandle params_shape;
1206 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &params_shape));
1207
1208 ShapeHandle indices_shape = c->input(1);
1209 ShapeHandle unused_axis_shape;
1210 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1211 const Tensor* axis_t = c->input_tensor(2);
1212
1213 // If axis is unknown, we can only infer that the result is params_rank +
1214 // indices_rank - 1.
1215 if (axis_t == nullptr) {
1216 if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1217 int32_t batch_dims;
1218 TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1219 c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1220 c->Rank(indices_shape) - 1 -
1221 batch_dims));
1222 } else {
1223 c->set_output(0, c->UnknownShape());
1224 }
1225 return OkStatus();
1226 }
1227
1228 // Note, axis can be negative.
1229 int64_t axis = 0;
1230 if (axis_t->dtype() == DT_INT32) {
1231 axis = axis_t->scalar<int32>()();
1232 } else {
1233 axis = axis_t->scalar<int64_t>()();
1234 }
1235
1236 // Check that params has rank of at least axis + 1.
1237 ShapeHandle unused;
1238 TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1239 params_shape, axis < 0 ? -axis : axis + 1, &unused));
1240
1241 // Note, batch_dims can be negative.
1242 int32_t batch_dims;
1243 TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1244 // -rank(indices) <= batch_dims <= rank(indices)
1245 TF_RETURN_IF_ERROR(
1246 c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
1247 if (batch_dims < 0) {
1248 batch_dims += c->Rank(indices_shape);
1249 }
1250 // rank(params) > batch_dims
1251 TF_RETURN_IF_ERROR(
1252 c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
1253
1254 ShapeHandle params_outer_subshape;
1255 TF_RETURN_IF_ERROR(
1256 c->Subshape(params_shape, 0, axis, &params_outer_subshape));
1257
1258 ShapeHandle indices_inner_subshape;
1259 TF_RETURN_IF_ERROR(
1260 c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));
1261
1262 ShapeHandle out;
1263 TF_RETURN_IF_ERROR(
1264 c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));
1265
1266 // Slice from axis + 1 to the end of params_shape to collect the inner
1267 // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1268 // we slice from 0 to the end of shape. Subshape() handles all other
1269 // out-of-bounds checking.
1270 if (axis != -1) {
1271 ShapeHandle params_inner_subshape;
1272 TF_RETURN_IF_ERROR(
1273 c->Subshape(params_shape, axis + 1, &params_inner_subshape));
1274 TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1275 }
1276
1277 c->set_output(0, out);
1278 return OkStatus();
1279 });
1280
1281// --------------------------------------------------------------------------
1282REGISTER_OP("GatherNd")
1283 .Input("params: Tparams")
1284 .Input("indices: Tindices")
1285 .Output("output: Tparams")
1286 .Attr("Tparams: type")
1287 .Attr("Tindices: {int16, int32,int64}")
1288 .SetShapeFn(shape_inference::GatherNdShape);
1289
1290// --------------------------------------------------------------------------
1291REGISTER_OP("Identity")
1292 .Input("input: T")
1293 .Output("output: T")
1294 .Attr("T: type")
1295 .SetForwardTypeFn(full_type::ReplicateInput())
1296 .SetShapeFn(shape_inference::UnchangedShape);
1297
1298REGISTER_OP("Snapshot")
1299 .Input("input: T")
1300 .Output("output: T")
1301 .Attr("T: type")
1302 .SetShapeFn(shape_inference::UnchangedShape);
1303
1304#ifdef INTEL_MKL
1305REGISTER_OP("_MklIdentity")
1306 .Input("input: T")
1307 .Input("mkl_input: uint8")
1308 .Output("output: T")
1309 .Output("mkl_output: uint8")
1310 .Attr("T: type")
1311 .SetShapeFn(shape_inference::UnchangedShape)
1312 .Doc(R"Doc( Mkl implementation of IdentityOp
1313)Doc");
1314#endif
1315
1316REGISTER_OP("IdentityN")
1317 .Input("input: T")
1318 .Output("output: T")
1319 .Attr("T: list(type)")
1320 .SetShapeFn([](shape_inference::InferenceContext* c) {
1321 std::vector<ShapeHandle> input;
1322 TF_RETURN_IF_ERROR(c->input("input", &input));
1323 TF_RETURN_IF_ERROR(c->set_output("output", input));
1324 // If any of the input shapes are not known, we should return error.
1325 for (int i = 0; i < input.size(); i++) {
1326 if (!input[i].Handle()) {
1327 return errors::InvalidArgument(absl::StrCat(
1328 "Cannot infer output shape #", i,
1329 " for IdentityN node because input shape #", i, " is unknown."));
1330 }
1331 }
1332 return OkStatus();
1333 });
1334
1335// --------------------------------------------------------------------------
1336REGISTER_OP("RefIdentity")
1337 .Input("input: Ref(T)")
1338 .Output("output: Ref(T)")
1339 .Attr("T: type")
1340 .SetShapeFn(shape_inference::UnchangedShape)
1341 .SetAllowsUninitializedInput();
1342
1343// --------------------------------------------------------------------------
1344REGISTER_OP("DebugGradientIdentity")
1345 .Input("input: T")
1346 .Output("output: T")
1347 .Attr("T: type")
1348 .SetShapeFn(shape_inference::UnchangedShape)
1349 .SetAllowsUninitializedInput();
1350
1351REGISTER_OP("DebugGradientRefIdentity")
1352 .Input("input: Ref(T)")
1353 .Output("output: Ref(T)")
1354 .Attr("T: type")
1355 .SetShapeFn(shape_inference::UnchangedShape)
1356 .SetAllowsUninitializedInput();
1357
1358// --------------------------------------------------------------------------
1359REGISTER_OP("StopGradient")
1360 .Input("input: T")
1361 .Output("output: T")
1362 .Attr("T: type")
1363 .SetShapeFn(shape_inference::UnchangedShape);
1364
1365REGISTER_OP("PreventGradient")
1366 .Input("input: T")
1367 .Output("output: T")
1368 .Attr("T: type")
1369 .Attr("message: string = ''")
1370 .SetShapeFn(shape_inference::UnchangedShape);
1371
1372// --------------------------------------------------------------------------
1373REGISTER_OP("CheckNumerics")
1374 .Input("tensor: T")
1375 .Output("output: T")
1376 .Attr("T: {bfloat16, half, float, double}")
1377 .Attr("message: string")
1378 .SetIsStateful()
1379 .SetShapeFn(shape_inference::UnchangedShape);
1380
1381// --------------------------------------------------------------------------
1382REGISTER_OP("CheckNumericsV2")
1383 .Input("tensor: T")
1384 .Output("output: T")
1385 .Attr("T: {bfloat16, half, float, double}")
1386 .Attr("message: string")
1387 .SetIsStateful()
1388 .SetShapeFn(shape_inference::UnchangedShape);
1389
1390// --------------------------------------------------------------------------
1391REGISTER_OP("Reshape")
1392 .Input("tensor: T")
1393 .Input("shape: Tshape")
1394 .Output("output: T")
1395 .Attr("T: type")
1396 .Attr("Tshape: {int32, int64} = DT_INT32")
1397 .SetShapeFn([](InferenceContext* c) {
1398 return SetOutputShapeForReshape(c);
1399 });
1400
1401#ifdef INTEL_MKL
1402REGISTER_OP("_MklReshape")
1403 .Input("tensor: T")
1404 .Input("shape: Tshape")
1405 .Input("mkl_tensor: uint8")
1406 .Input("mkl_shape: uint8")
1407 .Output("output: T")
1408 .Output("mkl_output: uint8")
1409 .Attr("T: type")
1410 .Attr("Tshape: {int32, int64} = DT_INT32")
1411 .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
1412 .Doc(R"Doc( MKL implementation of ReshapeOp.
1413)Doc");
1414#endif // INTEL_MKL
1415
1416// --------------------------------------------------------------------------
1417REGISTER_OP("InvertPermutation")
1418 .Input("x: T")
1419 .Output("y: T")
1420 .Attr("T: {int32, int64} = DT_INT32")
1421 .SetShapeFn([](InferenceContext* c) {
1422 ShapeHandle x;
1423 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1424 c->set_output(0, x);
1425 return OkStatus();
1426 });
1427
1428// --------------------------------------------------------------------------
1429REGISTER_OP("Transpose")
1430 .Input("x: T")
1431 .Input("perm: Tperm")
1432 .Output("y: T")
1433 .Attr("T: type")
1434 .Attr("Tperm: {int32, int64} = DT_INT32")
1435 .SetShapeFn(TransposeShapeFn);
1436
1437#ifdef INTEL_MKL
1438REGISTER_OP("_MklTranspose")
1439 .Input("x: T")
1440 .Input("perm: Tperm")
1441 .Output("y: T")
1442 .Attr("T: type")
1443 .Attr("Tperm: {int32, int64} = DT_INT32")
1444 .SetShapeFn(TransposeShapeFn);
1445#endif // INTEL_MKL
1446
1447// --------------------------------------------------------------------------
1448REGISTER_OP("ConjugateTranspose")
1449 .Input("x: T")
1450 .Input("perm: Tperm")
1451 .Output("y: T")
1452 .Attr("T: type")
1453 .Attr("Tperm: {int32, int64} = DT_INT32")
1454 .SetShapeFn(TransposeShapeFn);
1455
1456#ifdef INTEL_MKL
1457REGISTER_OP("_MklConjugateTranspose")
1458 .Input("x: T")
1459 .Input("perm: Tperm")
1460 .Output("y: T")
1461 .Attr("T: type")
1462 .Attr("Tperm: {int32, int64} = DT_INT32")
1463 .SetShapeFn(TransposeShapeFn);
1464#endif // INTEL_MKL
1465
1466// --------------------------------------------------------------------------
1467namespace {
1468Status UniqueIdxShapeFn(InferenceContext* c) {
1469 ShapeHandle input = c->input(0);
1470 const Tensor* axis_t = c->input_tensor(1);
1471 if (axis_t == nullptr || !c->RankKnown(input)) {
1472 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1473 return OkStatus();
1474 }
1475
1476 if (c->Rank(c->input(1)) != 1) {
1477 return errors::InvalidArgument("axis expects a 1D vector.");
1478 }
1479
1480 int32_t n = axis_t->NumElements();
1481 if (n == 0) {
1482 if (c->Rank(input) != 1) {
1483 return errors::InvalidArgument("x expects a 1D vector.");
1484 }
1485 c->set_output(1, input);
1486 return OkStatus();
1487 } else if (n == 1) {
1488 int64_t axis;
1489 if (axis_t->dtype() == DT_INT32) {
1490 axis = static_cast<int64_t>(axis_t->flat<int32>()(0));
1491 } else {
1492 axis = axis_t->flat<int64_t>()(0);
1493 }
1494
1495 int64_t input_rank = c->Rank(input);
1496 if (axis < -input_rank || axis >= input_rank) {
1497 return errors::InvalidArgument("axis expects to be in the range [",
1498 -input_rank, ", ", input_rank, ")");
1499 }
1500 if (axis < 0) {
1501 axis += input_rank;
1502 }
1503 c->set_output(1, c->Vector(c->Dim(input, axis)));
1504 return OkStatus();
1505 }
1506 return errors::InvalidArgument(
1507 "axis does not support input tensors larger than 1 elements.");
1508}
1509} // namespace
1510
1511REGISTER_OP("Unique")
1512 .Input("x: T")
1513 .Output("y: T")
1514 .Output("idx: out_idx")
1515 .Attr("T: type")
1516 .Attr("out_idx: {int32, int64} = DT_INT32")
1517 .SetShapeFn([](InferenceContext* c) {
1518 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1519 c->set_output(1, c->input(0));
1520 // Assert that the input rank is 1.
1521 ShapeHandle dummy;
1522 return c->WithRank(c->input(0), 1, &dummy);
1523 });
1524
1525REGISTER_OP("UniqueV2")
1526 .Input("x: T")
1527 .Input("axis: Taxis")
1528 .Output("y: T")
1529 .Output("idx: out_idx")
1530 .Attr("T: type")
1531 .Attr("Taxis: {int32,int64} = DT_INT64")
1532 .Attr("out_idx: {int32, int64} = DT_INT32")
1533 .SetShapeFn([](InferenceContext* c) {
1534 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1535 TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1536 return OkStatus();
1537 });
1538
1539// --------------------------------------------------------------------------
1540REGISTER_OP("UniqueWithCounts")
1541 .Input("x: T")
1542 .Output("y: T")
1543 .Output("idx: out_idx")
1544 .Output("count: out_idx")
1545 .Attr("T: type")
1546 .Attr("out_idx: {int32, int64} = DT_INT32")
1547 .SetShapeFn([](InferenceContext* c) {
1548 auto uniq = c->Vector(InferenceContext::kUnknownDim);
1549 c->set_output(0, uniq);
1550 c->set_output(1, c->input(0));
1551 c->set_output(2, uniq);
1552 return OkStatus();
1553 });
1554
1555REGISTER_OP("UniqueWithCountsV2")
1556 .Input("x: T")
1557 .Input("axis: Taxis")
1558 .Output("y: T")
1559 .Output("idx: out_idx")
1560 .Output("count: out_idx")
1561 .Attr("T: type")
1562 .Attr("Taxis: {int32,int64} = DT_INT64")
1563 .Attr("out_idx: {int32, int64} = DT_INT32")
1564 .SetShapeFn([](InferenceContext* c) {
1565 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1566 TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1567 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
1568 return OkStatus();
1569 });
1570
1571namespace {
1572
1573Status ShapeShapeFn(InferenceContext* c) {
1574 for (int i = 0; i < c->num_inputs(); ++i) {
1575 DimensionHandle dim;
1576 if (c->RankKnown(c->input(i))) {
1577 dim = c->MakeDim(c->Rank(c->input(i)));
1578 } else {
1579 dim = c->UnknownDim();
1580 }
1581 c->set_output(i, c->Vector(dim));
1582 }
1583 return OkStatus();
1584}
1585
1586} // namespace
1587
1588// --------------------------------------------------------------------------
1589REGISTER_OP("Shape")
1590 .Input("input: T")
1591 .Output("output: out_type")
1592 .Attr("T: type")
1593 .Attr("out_type: {int32, int64} = DT_INT32")
1594 .SetShapeFn(ShapeShapeFn);
1595
1596REGISTER_OP("ShapeN")
1597 .Input("input: N * T")
1598 .Output("output: N * out_type")
1599 .Attr("N: int")
1600 .Attr("T: type")
1601 .Attr("out_type: {int32, int64} = DT_INT32")
1602 .SetShapeFn(ShapeShapeFn);
1603
1604REGISTER_OP("EnsureShape")
1605 .Input("input: T")
1606 .Output("output: T")
1607 .Attr("shape: shape")
1608 .Attr("T: type")
1609 .SetShapeFn([](InferenceContext* c) {
1610 // Merges desired shape and statically known shape of input
1611 PartialTensorShape desired_shape;
1612 TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1613
1614 int rank = desired_shape.dims();
1615 ShapeHandle input_shape_handle;
1616 ShapeHandle desired_shape_handle;
1617 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1618 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1619 desired_shape, &desired_shape_handle));
1620
1621 ShapeHandle merged_shape;
1622 TF_RETURN_IF_ERROR(
1623 c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1624 c->set_output(0, merged_shape);
1625 return OkStatus();
1626 });
1627
1628// --------------------------------------------------------------------------
1629REGISTER_OP("ReverseSequence")
1630 .Input("input: T")
1631 .Input("seq_lengths: Tlen")
1632 .Output("output: T")
1633 .Attr("seq_dim: int")
1634 .Attr("batch_dim: int = 0")
1635 .Attr("T: type")
1636 .Attr("Tlen: {int32, int64} = DT_INT64")
1637 .SetShapeFn([](InferenceContext* c) {
1638 ShapeHandle input = c->input(0);
1639 ShapeHandle seq_lens_shape;
1640 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1641
1642 int64_t seq_dim;
1643 TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1644 int64_t batch_dim;
1645 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1646
1647 if (!c->RankKnown(input)) {
1648 return shape_inference::UnknownShape(c);
1649 }
1650
1651 // Validate batch_dim and seq_dim against input.
1652 const int32_t input_rank = c->Rank(input);
1653 if (batch_dim >= input_rank) {
1654 return errors::InvalidArgument(
1655 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1656 }
1657
1658 if (seq_dim >= input_rank) {
1659 return errors::InvalidArgument(
1660 "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1661 }
1662
1663 // To prevent out of bound access when calling c->Dim(input, batch_dim),
1664 // batch_dim range [-1 * input rank, input rank) is allowed. However,
1665 // the op implementation has a stricter bound for batch_dim requiring >= 0
1666 // value. Thus, perform strict check here.
1667 if (batch_dim < 0) {
1668 return errors::InvalidArgument("batch_dim must be >=0, got ",
1669 batch_dim);
1670 }
1671
1672 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1673 TF_RETURN_IF_ERROR(
1674 c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1675
1676 // Replace batch_dim of input with batch_size
1677 ShapeHandle output_shape;
1678 TF_RETURN_IF_ERROR(
1679 c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1680 c->set_output(0, output_shape);
1681 return OkStatus();
1682 });
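// For example, for an input of shape [4, 8, 16] with batch_dim = 0 and
// seq_dim = 1, seq_lengths must be a vector of shape [4] and the output
// shape matches the input, [4, 8, 16].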
1683
1684// --------------------------------------------------------------------------
1685REGISTER_OP("Rank")
1686 .Input("input: T")
1687 .Output("output: int32")
1688 .Attr("T: type")
1689 .SetShapeFn(shape_inference::ScalarShape);
1690
1691// --------------------------------------------------------------------------
1692REGISTER_OP("Size")
1693 .Input("input: T")
1694 .Output("output: out_type")
1695 .Attr("T: type")
1696 .Attr("out_type: {int32, int64} = DT_INT32")
1697 .SetShapeFn(shape_inference::ScalarShape);
1698
1699// --------------------------------------------------------------------------
1700REGISTER_OP("Slice")
1701 .Input("input: T")
1702 .Input("begin: Index")
1703 .Input("size: Index")
1704 .Output("output: T")
1705 .Attr("T: type")
1706 .Attr("Index: {int32,int64}")
1707 .SetShapeFn(shape_inference::SliceShape);
1708
1709#ifdef INTEL_MKL
1710REGISTER_OP("_MklSlice")
1711 .Input("input: T")
1712 .Input("begin: Index")
1713 .Input("size: Index")
1714 .Input("mkl_input: uint8")
1715 .Input("mkl_begin: uint8")
1716 .Input("mkl_size: uint8")
1717 .Output("output: T")
1718 .Output("mkl_output: uint8")
1719 .Attr("T: type")
1720 .Attr("Index: {int32,int64}")
1721 .SetShapeFn(shape_inference::SliceShape);
1722#endif
1723
1724REGISTER_OP("StridedSlice")
1725 .Input("input: T")
1726 .Input("begin: Index")
1727 .Input("end: Index")
1728 .Input("strides: Index")
1729 .Output("output: T")
1730 .Attr("T: type")
1731 .Attr("Index: {int16, int32, int64}")
1732 .Attr("begin_mask: int = 0")
1733 .Attr("end_mask: int = 0")
1734 .Attr("ellipsis_mask: int = 0")
1735 .Attr("new_axis_mask: int = 0")
1736 .Attr("shrink_axis_mask: int = 0")
1737 .SetShapeFn([](InferenceContext* c) {
1738 ShapeHandle input = c->input(0);
1739 ShapeHandle begin_shape, end_shape, strides_shape;
1740 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1741 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1742 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1743 TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1744 TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1745 DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1746
1747 const Tensor* strides_value = c->input_tensor(3);
1748 // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1749 // more shape inference here (e.g. for x[3, ::T]).
1750 if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1751 strides_value == nullptr) {
1752 c->set_output(0, c->UnknownShape());
1753 return OkStatus();
1754 }
1755
1756 PartialTensorShape input_shape({});
1757 for (int i = 0; i < c->Rank(input); ++i) {
1758 auto dim = c->Dim(input, i);
1759 input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1760 }
1761
1762 int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1763 shrink_axis_mask;
1764 TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1765 TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1766 TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1767 TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1768 TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1769
1770 const Tensor* begin_value = c->input_tensor(1);
1771 const Tensor* end_value = c->input_tensor(2);
1772
1773 PartialTensorShape processing_shape, final_shape;
1774 bool is_identity, is_simple_slice, slice_dim0;
1775 gtl::InlinedVector<int64, 4> begin, end, strides;
1776 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1777 begin_value, end_value, *strides_value, input_shape, begin_mask,
1778 end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1779 &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1780 &slice_dim0, &begin, &end, &strides));
1781
1782 ShapeHandle out;
1783 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1784 c->set_output(0, out);
1785
1786 auto* shape_and_type = c->input_handle_shapes_and_types(0);
1787 if (shape_and_type) {
1788 c->set_output_handle_shapes_and_types(0, *shape_and_type);
1789 }
1790
1791 return OkStatus();
1792 });
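// For example, slicing a vector of shape [10] with begin = [1], end = [5],
// and strides = [2] (i.e. x[1:5:2]) selects indices 1 and 3, so the inferred
// output shape is [2]; a dimension flagged in shrink_axis_mask is removed
// from the output entirely.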
1793
1794REGISTER_OP("StridedSliceGrad")
1795 .Input("shape: Index")
1796 .Input("begin: Index")
1797 .Input("end: Index")
1798 .Input("strides: Index")
1799 .Input("dy: T")
1800 .Output("output: T")
1801 .Attr("T: type")
1802 .Attr("Index: {int32, int64}")
1803 .Attr("begin_mask: int = 0")
1804 .Attr("end_mask: int = 0")
1805 .Attr("ellipsis_mask: int = 0")
1806 .Attr("new_axis_mask: int = 0")
1807 .Attr("shrink_axis_mask: int = 0")
1808 .SetShapeFn([](InferenceContext* c) {
1809 ShapeHandle out;
1810 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1811 c->set_output(0, out);
1812 return OkStatus();
1813 });
1814
1815REGISTER_OP("StridedSliceAssign")
1816 .Input("ref: Ref(T)")
1817 .Input("begin: Index")
1818 .Input("end: Index")
1819 .Input("strides: Index")
1820 .Input("value: T")
1821 .Output("output_ref: Ref(T)")
1822 .Attr("T: type")
1823 .Attr("Index: {int32, int64}")
1824 .Attr("begin_mask: int = 0")
1825 .Attr("end_mask: int = 0")
1826 .Attr("ellipsis_mask: int = 0")
1827 .Attr("new_axis_mask: int = 0")
1828 .Attr("shrink_axis_mask: int = 0")
1829 .SetShapeFn(shape_inference::UnchangedShape);
// TODO(aselle): Fix this documentation once StridedSliceAssign supports
// broadcasting.
1832// --------------------------------------------------------------------------
1833
1834REGISTER_OP("ResourceStridedSliceAssign")
1835 .Input("ref: resource")
1836 .Input("begin: Index")
1837 .Input("end: Index")
1838 .Input("strides: Index")
1839 .Input("value: T")
1840 .Attr("T: type")
1841 .Attr("Index: {int32, int64}")
1842 .Attr("begin_mask: int = 0")
1843 .Attr("end_mask: int = 0")
1844 .Attr("ellipsis_mask: int = 0")
1845 .Attr("new_axis_mask: int = 0")
1846 .Attr("shrink_axis_mask: int = 0")
1847 .SetShapeFn(shape_inference::NoOutputs);
1848
1849REGISTER_OP("TensorStridedSliceUpdate")
1850 .Input("input: T")
1851 .Input("begin: Index")
1852 .Input("end: Index")
1853 .Input("strides: Index")
1854 .Input("value: T")
1855 .Output("output: T")
1856 .Attr("T: type")
1857 .Attr("Index: {int32, int64}")
1858 .Attr("begin_mask: int = 0")
1859 .Attr("end_mask: int = 0")
1860 .Attr("ellipsis_mask: int = 0")
1861 .Attr("new_axis_mask: int = 0")
1862 .Attr("shrink_axis_mask: int = 0")
1863 .SetShapeFn(shape_inference::UnchangedShape);
1864
1865REGISTER_OP("Tile")
1866 .Input("input: T")
1867 .Input("multiples: Tmultiples")
1868 .Output("output: T")
1869 .Attr("T: type")
1870 .Attr("Tmultiples: {int32, int64} = DT_INT32")
1871 .SetShapeFn([](InferenceContext* c) {
1872 ShapeHandle input = c->input(0);
1873 // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1874 // it is a vector of non-negative integers, and (ii) doing so allows
1875 // us to handle partially-known multiples.
1876 ShapeHandle multiples;
1877 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1878 if (c->RankKnown(input)) {
1879 TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1880 ShapeHandle dummy;
1881 TF_RETURN_IF_ERROR(
1882 c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1883 }
1884
1885 if (!c->RankKnown(multiples)) {
1886 return shape_inference::UnknownShape(c);
1887 }
1888
1889 int32_t rank = c->Rank(multiples);
1890 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1891 std::vector<DimensionHandle> dims(rank);
1892 for (int i = 0; i < rank; ++i) {
1893 TF_RETURN_IF_ERROR(
1894 c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1895 }
1896 c->set_output(0, c->MakeShape(dims));
1897 return OkStatus();
1898 });
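// For example, tiling an input of shape [2, 3] with multiples = [3, 2]
// yields an output of shape [2 * 3, 3 * 2] = [6, 6].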
1899
1900// --------------------------------------------------------------------------
1901REGISTER_OP("TileGrad")
1902 .Input("input: T")
1903 .Input("multiples: int32")
1904 .Output("output: T")
1905 .Attr("T: type")
1906 .Deprecated(3, "TileGrad has been replaced with reduce_sum")
1907 .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1908
1909// --------------------------------------------------------------------------
1910REGISTER_OP("Where")
1911 .Input("input: T")
1912 .Attr("T: {numbertype, bool} = DT_BOOL")
1913 .Output("index: int64")
1914 .SetShapeFn([](InferenceContext* c) {
1915 c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1916 return OkStatus();
1917 });
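// For example, for an input of shape [2, 3] the index output has shape
// [?, 2]: one row per true (or non-zero) element and one column per input
// dimension, with the row count known only at runtime.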
1918
1919// --------------------------------------------------------------------------
1920REGISTER_OP("BroadcastArgs")
1921 .Input("s0: T")
1922 .Input("s1: T")
1923 .Output("r0: T")
1924 .Attr("T: {int32, int64} = DT_INT32")
1925 .SetShapeFn([](InferenceContext* c) {
1926 ShapeHandle unused;
1927 ShapeHandle shape_x = c->input(0);
1928 ShapeHandle shape_y = c->input(1);
1929 TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1930 TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1931
1932 if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1933 !c->ValueKnown(c->Dim(shape_y, 0))) {
1934 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1935 return OkStatus();
1936 }
1937
1938 int64_t x_dim = c->Value(c->Dim(shape_x, 0));
1939 int64_t y_dim = c->Value(c->Dim(shape_y, 0));
1940
1941 // Broadcasted shape is going to be as large as the largest dimension.
1942 c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1943 return OkStatus();
1944 });
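// For example, for shape vectors s0 = [2, 3, 1] and s1 = [3, 5], the runtime
// broadcast shape is [2, 3, 5]; here shape inference only knows the output
// is a vector of length max(3, 2) = 3.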
1945
1946// --------------------------------------------------------------------------
1947REGISTER_OP("BroadcastGradientArgs")
1948 .Input("s0: T")
1949 .Input("s1: T")
1950 .Output("r0: T")
1951 .Output("r1: T")
1952 .Attr("T: {int32, int64} = DT_INT32")
1953 .SetShapeFn([](InferenceContext* c) {
1954 // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1955 ShapeHandle unused;
1956 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1957 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1958 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1959 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1960 return OkStatus();
1961 });
1962
1963// --------------------------------------------------------------------------
1964REGISTER_OP("Pad")
1965 .Input("input: T")
1966 .Input("paddings: Tpaddings")
1967 .Output("output: T")
1968 .Attr("T: type")
1969 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1970 .SetShapeFn(PadShapeFn);
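// For example, padding an input of shape [2, 3] with
// paddings = [[1, 1], [2, 2]] gives an output of shape
// [1 + 2 + 1, 2 + 3 + 2] = [4, 7].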
1971
1972// --------------------------------------------------------------------------
1973REGISTER_OP("PadV2")
1974 .Input("input: T")
1975 .Input("paddings: Tpaddings")
1976 .Input("constant_values: T")
1977 .Output("output: T")
1978 .Attr("T: type")
1979 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1980 .SetShapeFn(PadShapeFn);
1981
1982// --------------------------------------------------------------------------
1983REGISTER_OP("MirrorPad")
1984 .Input("input: T")
1985 .Input("paddings: Tpaddings")
1986 .Output("output: T")
1987 .Attr("T: type")
1988 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1989 .Attr(GetMirrorPadModeAttrString())
1990 .SetShapeFn(PadShapeFn);
1991
1992// --------------------------------------------------------------------------
1993namespace {
1994template <typename T>
1995Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1996 const Tensor* paddings_t, int64_t input_rank) {
1997 auto paddings_data = paddings_t->matrix<T>();
1998 std::vector<DimensionHandle> dims(input_rank);
1999 for (int64_t i = 0; i < input_rank; ++i) {
2000 const int64_t pad0 = static_cast<int64_t>(paddings_data(i, 0));
2001 const int64_t pad1 = static_cast<int64_t>(paddings_data(i, 1));
2002 if (pad0 < 0 || pad1 < 0) {
2003 return errors::InvalidArgument("Paddings must be non-negative");
2004 }
2005
2006 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
2007 }
2008 c->set_output(0, c->MakeShape(dims));
2009 return OkStatus();
2010}
2011
2012} // namespace
2013
2014REGISTER_OP("MirrorPadGrad")
2015 .Input("input: T")
2016 .Input("paddings: Tpaddings")
2017 .Output("output: T")
2018 .Attr("T: type")
2019 .Attr("Tpaddings: {int32, int64} = DT_INT32")
2020 .Attr(GetMirrorPadModeAttrString())
2021 .SetShapeFn([](InferenceContext* c) {
2022 ShapeHandle paddings;
2023 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
2024 DimensionHandle pad_0 = c->Dim(paddings, 0);
2025 if (!c->ValueKnown(pad_0)) {
2026 // We don't know the rank of the output since the first
2027 // padding dimension is unknown.
2028 c->set_output(0, c->UnknownShape());
2029 return OkStatus();
2030 }
2031
2032 int64_t input_rank = c->Value(pad_0);
2033 ShapeHandle input;
2034 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
2035 TF_RETURN_IF_ERROR(
2036 c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
2037
2038 const Tensor* paddings_t = c->input_tensor(1);
2039 if (paddings_t == nullptr) {
        // The values of 'paddings' are not available, but we know the
        // input rank, so return a shape of that rank with unknown
        // dimensions.
2043 c->set_output(0, c->UnknownShapeOfRank(input_rank));
2044 return OkStatus();
2045 }
2046
2047 if (paddings_t->dtype() == DT_INT32) {
2048 return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
2049 } else {
2050 return MirrorPadKnown<int64_t>(c, input, paddings_t, input_rank);
2051 }
2052 });
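// For example, for a MirrorPadGrad input of shape [4, 7] with
// paddings = [[1, 1], [2, 2]], MirrorPadKnown subtracts the paddings from
// each dimension and infers an output shape of [4 - 2, 7 - 4] = [2, 3].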
2053
2054// --------------------------------------------------------------------------
2055REGISTER_OP("Placeholder")
2056 .Output("output: dtype")
2057 .Attr("dtype: type")
2058 .Attr("shape: shape = { unknown_rank: true }")
2059 .SetShapeFn([](InferenceContext* c) {
2060 PartialTensorShape shape;
2061 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2062
2063 // Placeholder has legacy behavior where we cannot tell the difference
2064 // between a scalar shape attribute and 'unknown shape'. So if the shape
2065 // is a scalar, we return an unknown shape.
2066 if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
2067 return shape_inference::UnknownShape(c);
2068 }
2069
2070 ShapeHandle out;
2071 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2072 c->set_output(0, out);
2073 return OkStatus();
2074 });
2075
// Placeholder was modified in a backwards-compatible way to do what
// PlaceholderV2 did, so we have deprecated V2 (no one was really
// using it).
2079REGISTER_OP("PlaceholderV2")
2080 .Output("output: dtype")
2081 .Attr("dtype: type")
2082 .Attr("shape: shape")
2083 .SetShapeFn(shape_inference::ExplicitShape)
2084 .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
2085
2086// --------------------------------------------------------------------------
2087REGISTER_OP("PlaceholderWithDefault")
2088 .Input("input: dtype")
2089 .Output("output: dtype")
2090 .Attr("dtype: type")
2091 .Attr("shape: shape")
2092 .SetShapeFn([](InferenceContext* c) {
2093 ShapeHandle input = c->input(0);
2094 PartialTensorShape shape;
2095 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2096 ShapeHandle out;
2097 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2098
2099 // We merge for compatibility checking, but return the output,
2100 // since output_shape may be less precise than input_shape.
2101 ShapeHandle unused;
2102 TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
2103 c->set_output(0, out);
2104 return OkStatus();
2105 });
2106
2107// --------------------------------------------------------------------------
2108REGISTER_OP("ExpandDims")
2109 .Input("input: T")
2110 .Input("dim: Tdim")
2111 .Output("output: T")
2112 .Attr("T: type")
2113 .Attr("Tdim: {int32, int64} = DT_INT32")
2114 .SetShapeFn([](InferenceContext* c) {
2115 ShapeHandle input = c->input(0);
2116
2117 const Tensor* dim_t = c->input_tensor(1);
2118 if (dim_t != nullptr && dim_t->NumElements() != 1) {
2119 return errors::InvalidArgument(
2120 "'dim' input must be a tensor with a single value");
2121 }
2122 if (dim_t == nullptr || !c->RankKnown(input)) {
2123 c->set_output(0, c->UnknownShape());
2124 return OkStatus();
2125 }
2126
2127 int64_t dim;
2128 if (dim_t->dtype() == DT_INT32) {
2129 dim = static_cast<int64_t>(dim_t->flat<int32>()(0));
2130 } else {
2131 dim = dim_t->flat<int64_t>()(0);
2132 }
2133
2134 const int32_t rank = c->Rank(input);
2135 const int32_t min_dim = -1 * rank - 1;
2136 if (dim < min_dim || dim > rank) {
2137 return errors::InvalidArgument("dim ", dim, " not in the interval [",
2138 min_dim, ", ", rank, "].");
2139 }
2140
2141 if (dim < 0) {
2142 dim += rank + 1;
2143 }
2144
2145 ShapeHandle end;
2146 TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
2147
2148 // Build output as start + 1 + end.
2149 ShapeHandle output;
2150 TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
2151 TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
2152 TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
2153 c->set_output(0, output);
2154 return OkStatus();
2155 });
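// For example, for an input of shape [2, 3], dim = 0 gives an output shape
// of [1, 2, 3], dim = 1 gives [2, 1, 3], and dim = -1 gives [2, 3, 1].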
2156
2157// --------------------------------------------------------------------------
2158REGISTER_OP("Squeeze")
2159 .Input("input: T")
2160 .Output("output: T")
2161 .Attr("T: type")
2162 .Attr("squeeze_dims: list(int) >= 0 = []")
2163 .SetShapeFn([](InferenceContext* c) {
2164 ShapeHandle input = c->input(0);
2165 if (!c->RankKnown(input)) {
2166 // Input shape unknown.
2167 return shape_inference::UnknownShape(c);
2168 }
2169
2170 const int32_t input_rank = c->Rank(input);
2171
2172 // Validate and wrap squeeze dimensions.
2173 std::vector<int32> squeeze_dims;
2174 TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
2175 for (int i = 0; i < squeeze_dims.size(); ++i) {
2176 if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
2177 return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
2178 -input_rank, ",", input_rank, ").");
2179 }
2180
2181 if (squeeze_dims[i] < 0) {
2182 squeeze_dims[i] += input_rank;
2183 }
2184 }
2185
2186 std::vector<DimensionHandle> result_shape;
2187 for (int i = 0; i < input_rank; ++i) {
2188 // True if squeeze_dims contains an entry to squeeze this
2189 // dimension.
2190 bool is_explicit_match =
2191 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2192 squeeze_dims.end();
2193
2194 DimensionHandle dim = c->Dim(input, i);
2195
2196 if (!c->ValueKnown(dim)) {
2197 // Assume that the squeezed dimension will be 1 at runtime.
2198 if (is_explicit_match) continue;
2199
          // If squeezing all size-1 dimensions and we see an unknown value,
          // give up and return an unknown shape.
2202 if (squeeze_dims.empty()) {
2203 c->set_output(0, c->UnknownShape());
2204 return OkStatus();
2205 }
2206 } else if (c->Value(dim) == 1) {
2207 if (is_explicit_match || squeeze_dims.empty()) {
2208 // If explicitly squeezing, or squeezing all 1s, remove
2209 // this dimension.
2210 continue;
2211 }
2212 } else if (is_explicit_match) {
          return errors::InvalidArgument("Cannot squeeze dim[", i,
                                         "], expected a dimension of 1, got ",
                                         c->Value(dim));
2216 }
2217
2218 result_shape.emplace_back(dim);
2219 }
2220
2221 c->set_output(0, c->MakeShape(result_shape));
2222 return OkStatus();
2223 });
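// For example, for an input of shape [1, 2, 1, 3], an empty squeeze_dims
// removes every size-1 dimension and gives [2, 3], while squeeze_dims = [0]
// gives [2, 1, 3].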
2224
2225// --------------------------------------------------------------------------
2226REGISTER_OP("ListDiff")
2227 .Input("x: T")
2228 .Input("y: T")
2229 .Output("out: T")
2230 .Output("idx: out_idx")
2231 .Attr("T: type")
2232 .Attr("out_idx: {int32, int64} = DT_INT32")
2233 .SetShapeFn([](InferenceContext* c) {
2234 ShapeHandle unused;
2235 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2236 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2237 // TODO(mrry): Indicate that the length falls within an interval?
2238 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2239 c->set_output(0, out);
2240 c->set_output(1, out);
2241 return OkStatus();
2242 });
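// For example, for x = [1, 2, 3, 4] and y = [1, 3], out = [2, 4] and
// idx = [1, 3]; both lengths depend on the runtime values, so shape
// inference can only report vectors of unknown length.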
2243
2244namespace {
2245
2246// Converts Tensor to flat std::vector<int64_t>.
2247template <typename InputType>
2248std::vector<int64_t> GetFlatInt64(const Tensor& t) {
2249 std::vector<int64_t> output(t.shape().num_elements());
2250 if (t.shape().num_elements() > 0) {
2251 auto eigen_vec = t.flat<InputType>();
2252 std::copy_n(&eigen_vec(0), output.size(), output.begin());
2253 }
2254 return output;
2255}
2256
2257// Converts int32 or int64 Tensor to flat std::vector<int64_t>.
2258std::vector<int64_t> GetFlatInt64(const Tensor& t) {
2259 if (t.dtype() == DT_INT32) {
2260 return GetFlatInt64<int32>(t);
2261 } else {
2262 return GetFlatInt64<int64_t>(t);
2263 }
2264}
2265
2266Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2267 ShapeHandle block_shape_shape,
2268 const Tensor* block_shape_t,
2269 ShapeHandle paddings_shape,
2270 const Tensor* paddings_t) {
2271 if (c->Rank(block_shape_shape) != 1) {
2272 return errors::InvalidArgument("block_shape must have rank 1.");
2273 }
2274
2275 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2276 if (!c->ValueKnown(num_block_dims_handle)) {
2277 return errors::InvalidArgument("block_shape must have known size.");
2278 }
2279
2280 const int64_t num_block_dims = c->Value(num_block_dims_handle);
2281
2282 TF_RETURN_IF_ERROR(
2283 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2284
2285 TF_RETURN_IF_ERROR(
2286 c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2287
2288 DimensionHandle batch_size = c->Dim(input_shape, 0);
2289 std::vector<int64_t> block_shape_vec;
2290 if (block_shape_t && (block_shape_t->NumElements() > 0)) {
2291 block_shape_vec = GetFlatInt64(*block_shape_t);
2292 for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2293 const int64_t block_shape_value = block_shape_vec[dim];
2294 if (block_shape_value < 1) {
2295 return errors::InvalidArgument("block_shape must be positive");
2296 }
2297 if (c->ValueKnown(batch_size)) {
2298 TF_RETURN_IF_ERROR(
2299 c->Multiply(batch_size, block_shape_value, &batch_size));
2300 } else {
2301 batch_size = c->UnknownDim();
2302 }
2303 }
2304 } else if (num_block_dims > 0) {
2305 batch_size = c->UnknownDim();
2306 }
2307
2308 std::vector<DimensionHandle> output_dims{batch_size};
2309 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2310
2311 if (paddings_t && (paddings_t->NumElements() > 0)) {
2312 const std::vector<int64_t> paddings_vec = GetFlatInt64(*paddings_t);
2313 for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2314 const int64_t pad_start = paddings_vec[dim * 2],
2315 pad_end = paddings_vec[dim * 2 + 1];
2316 if (pad_start < 0 || pad_end < 0) {
2317 return errors::InvalidArgument("paddings cannot be negative");
2318 }
2319 if (block_shape_t) {
2320 DimensionHandle padded_size;
2321 TF_RETURN_IF_ERROR(
2322 c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2323 TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2324 TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2325 /*evenly_divisible=*/true,
2326 &output_dims[dim + 1]));
2327 }
2328 }
2329 }
2330
2331 ShapeHandle remaining_input_shape;
2332 TF_RETURN_IF_ERROR(
2333 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2334
2335 ShapeHandle result;
2336 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2337 remaining_input_shape, &result));
2338 c->set_output(0, result);
2339 return OkStatus();
2340}
2341
2342Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2343 ShapeHandle block_shape_shape,
2344 const Tensor* block_shape_t,
2345 ShapeHandle crops_shape, const Tensor* crops_t) {
2346 if (c->Rank(block_shape_shape) != 1) {
2347 return errors::InvalidArgument("block_shape must have rank 1.");
2348 }
2349
2350 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2351 if (!c->ValueKnown(num_block_dims_handle)) {
2352 return errors::InvalidArgument("block_shape must have known size.");
2353 }
2354
2355 const int64_t num_block_dims = c->Value(num_block_dims_handle);
2356
2357 TF_RETURN_IF_ERROR(
2358 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2359
2360 TF_RETURN_IF_ERROR(
2361 c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2362
2363 DimensionHandle batch_size = c->Dim(input_shape, 0);
2364 std::vector<int64_t> block_shape_vec;
2365 if (block_shape_t) {
2366 block_shape_vec = GetFlatInt64(*block_shape_t);
2367 for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2368 const int64_t block_shape_value = block_shape_vec[dim];
2369 if (block_shape_value < 1) {
2370 return errors::InvalidArgument("block_shape must be positive");
2371 }
2372 if (c->ValueKnown(batch_size)) {
2373 TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2374 /*evenly_divisible=*/true, &batch_size));
2375 } else {
2376 batch_size = c->UnknownDim();
2377 }
2378 }
2379 } else if (num_block_dims > 0) {
2380 batch_size = c->UnknownDim();
2381 }
2382
2383 std::vector<DimensionHandle> output_dims{batch_size};
2384 output_dims.resize(num_block_dims + 1, c->UnknownDim());
2385
2386 if (crops_t) {
2387 const std::vector<int64_t> crops_vec = GetFlatInt64(*crops_t);
2388 for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2389 const int64_t crop_start = crops_vec[dim * 2],
2390 crop_end = crops_vec[dim * 2 + 1];
2391 if (crop_start < 0 || crop_end < 0) {
2392 return errors::InvalidArgument("crops cannot be negative");
2393 }
2394 if (block_shape_t) {
2395 DimensionHandle cropped_size;
2396 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2397 block_shape_vec[dim], &cropped_size));
2398 TF_RETURN_IF_ERROR(
2399 c->Subtract(cropped_size, crop_start, &cropped_size));
2400 TF_RETURN_IF_ERROR(
2401 c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2402 }
2403 }
2404 }
2405
2406 ShapeHandle remaining_input_shape;
2407 TF_RETURN_IF_ERROR(
2408 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2409
2410 ShapeHandle result;
2411 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2412 remaining_input_shape, &result));
2413 c->set_output(0, result);
2414 return OkStatus();
2415}
2416
2417} // namespace
2418
2419// --------------------------------------------------------------------------
2420REGISTER_OP("SpaceToBatchND")
2421 .Input("input: T")
2422 .Input("block_shape: Tblock_shape")
2423 .Input("paddings: Tpaddings")
2424 .Output("output: T")
2425 .Attr("T: type")
2426 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2427 .Attr("Tpaddings: {int32, int64} = DT_INT32")
2428 .SetShapeFn([](InferenceContext* c) {
2429 return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2430 c->input_tensor(1), c->input(2),
2431 c->input_tensor(2));
2432 });
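// For example, for an input of shape [1, 4, 4, 3] with block_shape = [2, 2]
// and zero paddings, the batch dimension is multiplied by 2 * 2 and each
// spatial dimension is divided by 2, giving an output shape of [4, 2, 2, 3].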
2433
2434// --------------------------------------------------------------------------
2435REGISTER_OP("SpaceToBatch")
2436 .Input("input: T")
2437 .Input("paddings: Tpaddings")
2438 .Output("output: T")
2439 .Attr("T: type")
2440 .Attr("Tpaddings: {int32, int64} = DT_INT32")
2441 .Attr("block_size: int >= 2")
2442 .SetShapeFn([](InferenceContext* c) {
2443 ShapeHandle input_shape;
2444 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2445
2446 int32_t block_size;
2447 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2448
2449 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2450 auto block_shape_vec = block_shape.vec<int64_t>();
2451 block_shape_vec(0) = block_size;
2452 block_shape_vec(1) = block_size;
2453
2454 return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2455 &block_shape, c->input(1),
2456 c->input_tensor(1));
2457 });
2458
2459// --------------------------------------------------------------------------
2460REGISTER_OP("BatchToSpaceND")
2461 .Input("input: T")
2462 .Input("block_shape: Tblock_shape")
2463 .Input("crops: Tcrops")
2464 .Output("output: T")
2465 .Attr("T: type")
2466 .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2467 .Attr("Tcrops: {int32, int64} = DT_INT32")
2468 .SetShapeFn([](InferenceContext* c) {
2469 return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2470 c->input_tensor(1), c->input(2),
2471 c->input_tensor(2));
2472 });
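// For example, inverting the SpaceToBatchND example above: an input of shape
// [4, 2, 2, 3] with block_shape = [2, 2] and zero crops yields an output
// shape of [1, 4, 4, 3].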
2473
2474// --------------------------------------------------------------------------
2475REGISTER_OP("BatchToSpace")
2476 .Input("input: T")
2477 .Input("crops: Tidx")
2478 .Output("output: T")
2479 .Attr("T: type")
2480 .Attr("block_size: int >= 2")
2481 .Attr("Tidx: {int32, int64} = DT_INT32")
2482 .SetShapeFn([](InferenceContext* c) {
2483 ShapeHandle input_shape;
2484 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2485
2486 int32_t block_size;
2487 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2488
2489 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2490 auto block_shape_vec = block_shape.vec<int64_t>();
2491 block_shape_vec(0) = block_size;
2492 block_shape_vec(1) = block_size;
2493
2494 return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2495 &block_shape, c->input(1),
2496 c->input_tensor(1));
2497 });
2498
2499// --------------------------------------------------------------------------
2500REGISTER_OP("SpaceToDepth")
2501 .Input("input: T")
2502 .Output("output: T")
2503 .Attr("T: type")
2504 .Attr("block_size: int >= 2")
2505 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2506 // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
2507 .SetShapeFn([](InferenceContext* c) {
2508 string data_format_str;
2509 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2510 TensorFormat data_format;
2511 FormatFromString(data_format_str, &data_format);
2512
2513 constexpr int num_spatial_dims = 2;
2514 const int dims =
2515 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2516 ShapeHandle input;
2517 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2518
2519 int32_t block_size;
2520 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2521
2522 DimensionHandle batch_size =
2523 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2524 DimensionHandle input_height =
2525 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2526 DimensionHandle input_width =
2527 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2528 DimensionHandle input_depth =
2529 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2530
2531 DimensionHandle output_height;
2532 DimensionHandle output_width;
2533 DimensionHandle output_depth;
      // Returns an error if the input height or width is not evenly
      // divisible by block_size.
2535 TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2536 true /* evenly_divisible */,
2537 &output_height));
2538 TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2539 true /* evenly_divisible */, &output_width));
2540
2541 TF_RETURN_IF_ERROR(
2542 c->Multiply(input_depth, block_size * block_size, &output_depth));
2543
2544 ShapeHandle output_shape;
2545 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2546 {output_height, output_width},
2547 output_depth, &output_shape, c));
2548
2549 c->set_output(0, output_shape);
2550 return OkStatus();
2551 });
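// For example, for an NHWC input of shape [1, 4, 4, 3] with block_size = 2,
// SpaceToDepth halves each spatial dimension and multiplies the depth by
// 2 * 2, giving an output shape of [1, 2, 2, 12]; DepthToSpace below is the
// inverse transformation.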
2552
2553// --------------------------------------------------------------------------
2554REGISTER_OP("DepthToSpace")
2555 .Input("input: T")
2556 .Output("output: T")
2557 .Attr("T: type")
2558 .Attr("block_size: int >= 2")
2559 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2560 // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
2561 .SetShapeFn([](InferenceContext* c) {
2562 string data_format_str;
2563 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2564 TensorFormat data_format;
2565 FormatFromString(data_format_str, &data_format);
2566
2567 constexpr int num_spatial_dims = 2;
2568 const int dims =
2569 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2570
2571 ShapeHandle input;
2572 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2573
2574 int32_t block_size;
2575 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2576
2577 DimensionHandle batch_size =
2578 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2579 DimensionHandle input_height =
2580 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2581 DimensionHandle input_width =
2582 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2583 DimensionHandle input_depth =
2584 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2585
2586 DimensionHandle output_height;
2587 DimensionHandle output_width;
2588 DimensionHandle output_depth;
2589 TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2590 TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2591
      // Returns an error if input_depth is not evenly divisible by
      // block_size * block_size.
2593 TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2594 true /* evenly_divisible */, &output_depth));
2595
2596 ShapeHandle output_shape;
2597 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2598 {output_height, output_width},
2599 output_depth, &output_shape, c));
2600
2601 c->set_output(0, output_shape);
2602 return OkStatus();
2603 });
2604
2605// --------------------------------------------------------------------------
2606
2607REGISTER_OP("ExtractImagePatches")
2608 .Input("images: T")
2609 .Output("patches: T")
2610 .Attr("ksizes: list(int) >= 4")
2611 .Attr("strides: list(int) >= 4")
2612 .Attr("rates: list(int) >= 4")
2613 .Attr(
2614 "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
2615 "uint8, uint16, uint32, uint64, complex64, complex128, bool}")
2616 .Attr(GetPaddingAttrString())
2617 .SetShapeFn([](InferenceContext* c) {
2618 ShapeHandle input_shape;
2619 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2620
2621 std::vector<int32> ksizes;
2622 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2623 if (ksizes.size() != 4) {
2624 return errors::InvalidArgument(
2625 "ExtractImagePatches requires the ksizes attribute to contain 4 "
2626 "values, but got: ",
2627 ksizes.size());
2628 }
2629
2630 std::vector<int32> strides;
2631 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2632 if (strides.size() != 4) {
2633 return errors::InvalidArgument(
            "ExtractImagePatches requires the strides attribute to contain 4 "
2635 "values, but got: ",
2636 strides.size());
2637 }
2638
2639 std::vector<int32> rates;
2640 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2641 if (rates.size() != 4) {
2642 return errors::InvalidArgument(
2643 "ExtractImagePatches requires the rates attribute to contain 4 "
2644 "values, but got: ",
2645 rates.size());
2646 }
2647
2648 int32_t ksize_rows = ksizes[1];
2649 int32_t ksize_cols = ksizes[2];
2650
2651 int32_t stride_rows = strides[1];
2652 int32_t stride_cols = strides[2];
2653
2654 int32_t rate_rows = rates[1];
2655 int32_t rate_cols = rates[2];
2656
2657 int32_t ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2658 int32_t ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2659
2660 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2661 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2662 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2663 DimensionHandle output_depth_dim;
2664 TF_RETURN_IF_ERROR(c->Multiply(
2665 c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2666
2667 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2668 ShapeHandle output_shape =
2669 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2670 InferenceContext::kUnknownDim, output_depth_dim});
2671 c->set_output(0, output_shape);
2672 return OkStatus();
2673 }
2674 auto in_rows = c->Value(in_rows_dim);
2675 auto in_cols = c->Value(in_cols_dim);
2676
2677 Padding padding;
2678 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2679
2680 int64_t output_rows, output_cols;
2681 int64_t padding_before, padding_after;
2682 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2683 in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2684 &padding_before, &padding_after));
2685 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2686 in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2687 &padding_before, &padding_after));
2688 ShapeHandle output_shape = c->MakeShape(
2689 {batch_size_dim, output_rows, output_cols, output_depth_dim});
2690 c->set_output(0, output_shape);
2691 return OkStatus();
2692 });
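// For example, for images of shape [1, 10, 10, 3] with
// ksizes = [1, 3, 3, 1], strides = [1, 5, 5, 1], rates = [1, 1, 1, 1], and
// VALID padding, each spatial output dimension is (10 - 3) / 5 + 1 = 2 and
// the patch depth is 3 * 3 * 3 = 27, for an output shape of [1, 2, 2, 27].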
2693
2694// --------------------------------------------------------------------------
2695
2696// To enable rates, uncomment all lines commented below and use ksize_*_eff
2697// as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2698// of ksize_*.
2699REGISTER_OP("ExtractVolumePatches")
2700 .Input("input: T")
2701 .Output("patches: T")
2702 .Attr("ksizes: list(int) >= 5")
2703 .Attr("strides: list(int) >= 5")
2704 /* .Attr("rates: list(int) >= 5") */
2705 .Attr("T: realnumbertype")
2706 .Attr(GetPaddingAttrString())
2707 .SetShapeFn([](InferenceContext* c) {
2708 ShapeHandle input_shape;
2709 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2710
2711 std::vector<int32> ksizes;
2712 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2713 if (ksizes.size() != 5) {
2714 return errors::InvalidArgument(
2715 "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2716 "values, but got: ",
2717 ksizes.size());
2718 }
2719
2720 std::vector<int32> strides;
2721 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2722 if (strides.size() != 5) {
2723 return errors::InvalidArgument(
            "ExtractVolumePatches requires the strides attribute to contain 5 "
2725 "values, but got: ",
2726 strides.size());
2727 }
2728
2729 /*
2730 // TODO(hsgkim): Enable rates.
2731 // See extract_volume_patches_op.cc for why rates are disabled now.
2732
2733 std::vector<int32> rates;
2734 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2735 if (rates.size() != 5) {
2736 return errors::InvalidArgument(
2737 "ExtractVolumePatches requires the rates attribute to contain 5 "
2738 "values, but got: ",
2739 rates.size());
2740 }
2741 */
2742
2743 int32_t ksize_planes = ksizes[1];
2744 int32_t ksize_rows = ksizes[2];
2745 int32_t ksize_cols = ksizes[3];
2746
2747 int32_t stride_planes = strides[1];
2748 int32_t stride_rows = strides[2];
2749 int32_t stride_cols = strides[3];
2750
2751 /*
2752 int32 rate_planes = rates[1];
2753 int32 rate_rows = rates[2];
2754 int32 rate_cols = rates[3];
2755
2756 int32 ksize_planes_eff = ksize_planes +
2757 (ksize_planes - 1) * (rate_planes - 1);
2758 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2759 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2760 */
2761
2762 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2763 DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2764 DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2765 DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2766 DimensionHandle output_depth_dim;
2767 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2768 ksize_planes * ksize_rows * ksize_cols,
2769 &output_depth_dim));
2770
2771 if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2772 !c->ValueKnown(in_cols_dim)) {
2773 ShapeHandle output_shape =
2774 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2775 InferenceContext::kUnknownDim, output_depth_dim});
2776 c->set_output(0, output_shape);
2777 return OkStatus();
2778 }
2779 auto in_planes = c->Value(in_planes_dim);
2780 auto in_rows = c->Value(in_rows_dim);
2781 auto in_cols = c->Value(in_cols_dim);
2782
2783 Padding padding;
2784 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2785
2786 int64_t output_planes, output_rows, output_cols;
2787 int64_t padding_before, padding_after;
2788 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2789 in_planes, ksize_planes, stride_planes, padding, &output_planes,
2790 &padding_before, &padding_after));
2791 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2792 in_rows, ksize_rows, stride_rows, padding, &output_rows,
2793 &padding_before, &padding_after));
2794 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2795 in_cols, ksize_cols, stride_cols, padding, &output_cols,
2796 &padding_before, &padding_after));
2797 ShapeHandle output_shape =
2798 c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
2799 output_depth_dim});
2800 c->set_output(0, output_shape);
2801 return OkStatus();
2802 });
2803
2804// --------------------------------------------------------------------------
2805
2806REGISTER_OP("OneHot")
2807 .Input("indices: TI")
2808 .Input("depth: int32")
2809 .Input("on_value: T")
2810 .Input("off_value: T")
2811 .Attr("axis: int = -1")
2812 .Output("output: T")
2813 .Attr("T: type")
2814 .Attr("TI: {uint8, int8, int32, int64} = DT_INT64")
2815 .SetShapeFn([](InferenceContext* c) {
2816 int32_t axis;
2817 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2818 if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
2819
2820 DimensionHandle depth;
2821 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
2822
2823 ShapeHandle indices = c->input(0);
2824 if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
2825
2826 int32_t new_rank = c->Rank(indices) + 1;
      // We need to add new_rank to axis when axis is -1, because C++
      // returns negative values from % if the dividend is negative.
2829 int32_t depth_index = (axis + new_rank) % new_rank;
2830 // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
2831 ShapeHandle front;
2832 ShapeHandle back;
2833 ShapeHandle out;
2834 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
2835 TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
2836 TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
2837 TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
2838 c->set_output(0, out);
2839 return OkStatus();
2840 });
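// For example, for indices of shape [4] and depth = 3, axis = -1 gives an
// output shape of [4, 3], while axis = 0 gives [3, 4].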
2841
2842// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
2843REGISTER_OP("QuantizeAndDequantize")
2844 .Input("input: T")
2845 .Attr("signed_input: bool = true")
2846 .Attr("num_bits: int = 8")
2847 .Attr("range_given: bool = false")
2848 .Attr("input_min: float = 0")
2849 .Attr("input_max: float = 0")
2850 .Output("output: T")
2851 .Attr("T: {bfloat16, half, float, double}")
2852 .SetShapeFn(shape_inference::UnchangedShape)
2853 .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
2854
2855// TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
2856REGISTER_OP("QuantizeAndDequantizeV2")
2857 .Input("input: T")
2858 .Input("input_min: T")
2859 .Input("input_max: T")
2860 .Attr("signed_input: bool = true")
2861 .Attr("num_bits: int = 8")
2862 .Attr("range_given: bool = false")
2863 .Output("output: T")
2864 .Attr("T: {bfloat16, half, float, double}")
2865 .Attr(
2866 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2867 "'HALF_TO_EVEN'")
2868 .Attr("narrow_range: bool = false")
2869 .Attr("axis: int = -1")
2870 .SetShapeFn([](InferenceContext* c) {
2871 int axis;
2872 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2873 const int minmax_rank = (axis == -1) ? 0 : 1;
2874 ShapeHandle minmax;
2875 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2876 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2877 if (axis < -1) {
2878 return errors::InvalidArgument("axis should be at least -1, got ",
2879 axis);
2880 } else if (axis != -1) {
2881 ShapeHandle input;
2882 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2883 DimensionHandle depth;
2884 TF_RETURN_IF_ERROR(
2885 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2886 }
2887 c->set_output(0, c->input(0));
2888 return OkStatus();
2889 });
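// For example, with axis = -1 the input_min/input_max inputs must be
// scalars, while with axis = 1 and an input of shape [2, 4, 3] they must be
// vectors of shape [4], one quantization range per slice along that axis.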
2890
2891REGISTER_OP("QuantizeAndDequantizeV4")
2892 .Input("input: T")
2893 .Input("input_min: T")
2894 .Input("input_max: T")
2895 .Attr("signed_input: bool = true")
2896 .Attr("num_bits: int = 8")
2897 .Attr("range_given: bool = false")
2898 .Output("output: T")
2899 .Attr("T: {bfloat16, half, float, double}")
2900 .Attr(
2901 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2902 "'HALF_TO_EVEN'")
2903 .Attr("narrow_range: bool = false")
2904 .Attr("axis: int = -1")
2905 .SetShapeFn([](InferenceContext* c) {
2906 int axis;
2907 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2908 const int minmax_rank = (axis == -1) ? 0 : 1;
2909 ShapeHandle minmax;
2910 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2911 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2912 if (axis < -1) {
2913 return errors::InvalidArgument("axis should be at least -1, got ",
2914 axis);
2915 } else if (axis != -1) {
2916 ShapeHandle input;
2917 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2918 DimensionHandle depth;
2919 TF_RETURN_IF_ERROR(
2920 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2921 }
2922 c->set_output(0, c->input(0));
2923 return OkStatus();
2924 });
2925
2926REGISTER_OP("QuantizeAndDequantizeV4Grad")
2927 .Input("gradients: T")
2928 .Input("input: T")
2929 .Input("input_min: T")
2930 .Input("input_max: T")
2931 .Output("input_backprop: T")
2932 .Output("input_min_backprop: T")
2933 .Output("input_max_backprop: T")
2934 .Attr("T: {bfloat16, half, float, double}")
2935 .Attr("axis: int = -1")
2936 .SetShapeFn([](InferenceContext* c) {
2937 int axis;
2938 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2939 const int minmax_rank = (axis == -1) ? 0 : 1;
2940 ShapeHandle minmax;
2941 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2942 TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
2943 if (axis < -1) {
2944 return errors::InvalidArgument("axis should be at least -1, got ",
2945 axis);
2946 } else if (axis != -1) {
2947 ShapeHandle input;
2948 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2949 DimensionHandle depth;
2950 TF_RETURN_IF_ERROR(
2951 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2952 }
2953 ShapeHandle inputs;
2954 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
2955 c->set_output(0, inputs);
2956 c->set_output(1, minmax);
2957 c->set_output(2, minmax);
2958 return OkStatus();
2959 });
2960
2961REGISTER_OP("QuantizeAndDequantizeV3")
2962 .Input("input: T")
2963 .Input("input_min: T")
2964 .Input("input_max: T")
2965 .Input("num_bits: int32")
2966 .Attr("signed_input: bool = true")
2967 .Attr("range_given: bool = true")
2968 .Output("output: T")
2969 .Attr("T: {bfloat16, half, float, double}")
2970 .Attr("narrow_range: bool = false")
2971 .Attr("axis: int = -1")
2972 .SetShapeFn([](InferenceContext* c) {
2973 int axis;
2974 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2975 const int minmax_rank = (axis == -1) ? 0 : 1;
2976 ShapeHandle minmax;
2977 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2978 TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2979 if (axis < -1) {
2980 return errors::InvalidArgument("axis should be at least -1, got ",
2981 axis);
2982 } else if (axis != -1) {
2983 ShapeHandle input;
2984 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2985 DimensionHandle depth;
2986 TF_RETURN_IF_ERROR(
2987 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2988 }
2989 ShapeHandle unused;
2990 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2991 c->set_output(0, c->input(0));
2992 return OkStatus();
2993 });
2994
2995REGISTER_OP("QuantizeV2")
2996 .Input("input: float")
2997 .Input("min_range: float")
2998 .Input("max_range: float")
2999 .Output("output: T")
3000 .Output("output_min: float")
3001 .Output("output_max: float")
3002 .Attr("T: quantizedtype")
3003 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
3004 .Attr(
3005 "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
3006 "'HALF_AWAY_FROM_ZERO'")
3007 .Attr("narrow_range: bool = false")
3008 .Attr("axis: int = -1")
3009 .Attr("ensure_minimum_range: float = 0.01")
3010 .SetShapeFn(shape_inference::QuantizeV2Shape);
3011
3012REGISTER_OP("Dequantize")
3013 .Input("input: T")
3014 .Input("min_range: float")
3015 .Input("max_range: float")
3016 .Output("output: dtype")
3017 .Attr("T: quantizedtype")
3018 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
3019 .Attr("narrow_range: bool = false")
3020 .Attr("axis: int = -1")
3021 .Attr("dtype: {bfloat16, float} = DT_FLOAT")
3022 .SetShapeFn([](InferenceContext* c) {
3023 int axis = -1;
3024 Status s = c->GetAttr("axis", &axis);
3025 if (!s.ok() && s.code() != error::NOT_FOUND) {
3026 return s;
3027 }
3028 if (axis < -1) {
3029 return errors::InvalidArgument("axis should be at least -1, got ",
3030 axis);
3031 }
3032 auto input_dims = c->Rank(c->input(0));
3033 if (axis > input_dims) {
3034 return errors::InvalidArgument(
3035 "Axis must be less than input dimension(", input_dims, "), got ",
3036 axis);
3037 }
3038 const int minmax_rank = (axis == -1) ? 0 : 1;
3039 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3040 ShapeHandle minmax;
3041 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
3042 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
3043 if (axis != -1) {
3044 ShapeHandle input;
3045 if (axis >= kint32max) {
          // Check against the int32 max bound as a corner case, to prevent
          // integer overflow when the input actually has rank kint32max and
          // the bound check above is not triggered.
3049 return errors::InvalidArgument(
3050 "Axis cannot be >= kint32max value, got ", axis);
3051 }
3052 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
3053 DimensionHandle depth;
3054 TF_RETURN_IF_ERROR(
3055 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
3056 }
3057 return OkStatus();
3058 });
3059
3060REGISTER_OP("QuantizedConcat")
3061 .Input("concat_dim: int32")
3062 .Input("values: N * T")
3063 .Input("input_mins: N * float32")
3064 .Input("input_maxes: N * float32")
3065 .Output("output: T")
3066 .Output("output_min: float")
3067 .Output("output_max: float")
3068 .Attr("N: int >= 2")
3069 .Attr("T: type")
3070 .SetShapeFn([](InferenceContext* c) {
3071 const int n = (c->num_inputs() - 1) / 3;
3072 TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
3073 ShapeHandle unused;
3074 for (int i = n + 1; i < c->num_inputs(); ++i) {
3075 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
3076 }
3077 c->set_output(1, c->Scalar());
3078 c->set_output(2, c->Scalar());
3079 return OkStatus();
3080 });
3081
3082REGISTER_OP("QuantizedReshape")
3083 .Input("tensor: T")
3084 .Input("shape: Tshape")
3085 .Input("input_min: float")
3086 .Input("input_max: float")
3087 .Output("output: T")
3088 .Output("output_min: float")
3089 .Output("output_max: float")
3090 .Attr("T: type")
3091 .Attr("Tshape: {int32, int64} = DT_INT32")
3092 .SetShapeFn([](InferenceContext* c) {
3093 TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
3094 ShapeHandle unused;
3095 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3096 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3097 c->set_output(1, c->Scalar());
3098 c->set_output(2, c->Scalar());
3099 return OkStatus();
3100 });
3101
3102REGISTER_OP("QuantizedInstanceNorm")
3103 .Input("x: T")
3104 .Input("x_min: float")
3105 .Input("x_max: float")
3106 .Output("y: T")
3107 .Output("y_min: float")
3108 .Output("y_max: float")
3109 .Attr("T: quantizedtype")
3110 .Attr("output_range_given: bool = false")
3111 .Attr("given_y_min: float = 0")
3112 .Attr("given_y_max: float = 0")
3113 .Attr("variance_epsilon: float = 1e-5")
3114 .Attr("min_separation: float = 1e-3")
3115 .SetShapeFn([](shape_inference::InferenceContext* c) {
3116 shape_inference::ShapeHandle unused;
3117 // x should be a rank 4 tensor.
3118 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
3119 // Assert x_min and x_max are scalars (rank 0).
3120 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3121 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3122 // y has the same shape as x.
3123 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3124 // y_min and y_max are scalars.
3125 c->set_output(1, c->Scalar());
3126 c->set_output(2, c->Scalar());
3127 return OkStatus();
3128 });
3129
3130namespace {
3131
3132Status ScatterNdTensorShape(InferenceContext* c) {
3133 ShapeHandle output_shape;
3134 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
3135 ShapeHandle indices_shape;
3136 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
3137 ShapeHandle updates_shape;
3138 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 0, &updates_shape));
3139 return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
3140 output_shape);
3141}
3142
3143} // namespace
3144
3145REGISTER_OP("UpperBound")
3146 .Input("sorted_inputs: T")
3147 .Input("values: T")
3148 .Output("output: out_type")
3149 .Attr("T: type")
3150 .Attr("out_type: {int32, int64} = DT_INT32")
3151 .SetShapeFn([](InferenceContext* c) {
3152 ShapeHandle unused_shape;
3153 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3154 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3155 c->set_output(0, c->input(1));
3156 return OkStatus();
3157 });
3158
3159REGISTER_OP("LowerBound")
3160 .Input("sorted_inputs: T")
3161 .Input("values: T")
3162 .Output("output: out_type")
3163 .Attr("T: type")
3164 .Attr("out_type: {int32, int64} = DT_INT32")
3165 .SetShapeFn([](InferenceContext* c) {
3166 ShapeHandle unused_shape;
3167 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3168 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3169 c->set_output(0, c->input(1));
3170 return OkStatus();
3171 });
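// For example, for sorted_inputs of shape [2, 5] and values of shape [2, 3],
// both UpperBound and LowerBound return indices with the same shape as
// values, [2, 3].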
3172
3173REGISTER_OP("ScatterNd")
3174 .Input("indices: Tindices")
3175 .Input("updates: T")
3176 .Input("shape: Tindices")
3177 .Output("output: T")
3178 .Attr("T: type")
3179 .Attr("Tindices: {int16, int32, int64}")
3180 .SetShapeFn([](InferenceContext* c) {
3181 ShapeHandle indices_shape;
3182 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
3183 ShapeHandle updates_shape;
3184 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
3185 ShapeHandle output_shape;
3186 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
3187 return shape_inference::ScatterNdShapeHelper(c, indices_shape,
3188 updates_shape, output_shape);
3189 });
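// For example, for indices of shape [4, 1], updates of shape [4], and
// shape = [8], ScatterNd scatters four scalar updates into a vector of
// shape [8], so the inferred output shape is the requested [8].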
3190
3191REGISTER_OP("TensorScatterUpdate")
3192 .Input("tensor: T")
3193 .Input("indices: Tindices")
3194 .Input("updates: T")
3195 .Output("output: T")
3196 .Attr("T: type")
3197 .Attr("Tindices: {int16, int32, int64, uint16}")
3198 .SetShapeFn(ScatterNdTensorShape);
3199
3200REGISTER_OP("TensorScatterAdd")
3201 .Input("tensor: T")
3202 .Input("indices: Tindices")
3203 .Input("updates: T")
3204 .Output("output: T")
3205 .Attr("T: type")
3206 .Attr("Tindices: {int32, int64}")
3207 .SetShapeFn(ScatterNdTensorShape);
3208
3209REGISTER_OP("TensorScatterSub")
3210 .Input("tensor: T")
3211 .Input("indices: Tindices")
3212 .Input("updates: T")
3213 .Output("output: T")
3214 .Attr("T: type")
3215 .Attr("Tindices: {int32, int64}")
3216 .SetShapeFn(ScatterNdTensorShape);
3217
3218REGISTER_OP("TensorScatterMin")
3219 .Input("tensor: T")
3220 .Input("indices: Tindices")
3221 .Input("updates: T")
3222 .Output("output: T")
3223 .Attr("T: type")
3224 .Attr("Tindices: {int32, int64}")
3225 .SetShapeFn(ScatterNdTensorShape);
3226
3227REGISTER_OP("TensorScatterMax")
3228 .Input("tensor: T")
3229 .Input("indices: Tindices")
3230 .Input("updates: T")
3231 .Output("output: T")
3232 .Attr("T: type")
3233 .Attr("Tindices: {int32, int64}")
3234 .SetShapeFn(ScatterNdTensorShape);
3235
3236REGISTER_OP("ScatterNdNonAliasingAdd")
3237 .Input("input: T")
3238 .Input("indices: Tindices")
3239 .Input("updates: T")
3240 .Output("output: T")
3241 .Attr("T: {numbertype, bool}")
3242 .Attr("Tindices: {int32, int64}")
3243 .SetShapeFn(ScatterNdTensorShape);
3244
REGISTER_OP("FakeQuantWithMinMaxArgs")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Output("outputs: float")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
    .Attr("min: float = -6.0")
    .Attr("max: float = 6.0")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Output("backprops: float")
    .SetShapeFn(shape_inference::UnchangedShape);

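// Here the quantization range is supplied as scalar `min`/`max` input tensors
// rather than attrs, so the range can be learned.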
REGISTER_OP("FakeQuantWithMinMaxVars")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("outputs: float")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return OkStatus();
    });

REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("backprops_wrt_input: float")
    .Output("backprop_wrt_min: float")
    .Output("backprop_wrt_max: float")
    .SetShapeFn([](InferenceContext* c) {
      // gradients and inputs must have the same shape.
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));

      // min and max are scalars.
      ShapeHandle min_max;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));

      c->set_output(0, inputs);
      c->set_output(1, min_max);
      c->set_output(2, min_max);
      return OkStatus();
    });

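// Per-channel variant: `min` and `max` are rank-1 vectors whose length must
// equal the size of the last dimension of `inputs`.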
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("outputs: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input, min, max;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));

      c->set_output(0, input);
      return OkStatus();
    });

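// Gradient of the per-channel op: `gradients` must match `inputs` (rank 1 to
// 4), and the min/max backprops are vectors sized to the last input dimension.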
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
    .Attr("num_bits: int = 8")
    .Attr("narrow_range: bool = false")
    .Input("gradients: float")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("backprops_wrt_input: float")
    .Output("backprop_wrt_min: float")
    .Output("backprop_wrt_max: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
      TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));

      ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));

      ShapeHandle min_max;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
      TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));

      c->set_output(0, inputs);
      c->set_output(1, min_max);
      c->set_output(2, min_max);
      return OkStatus();
    });

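// Fingerprint hashes each batch element of `data` into a fixed-width byte
// string, producing a [batch, fingerprint_size] uint8 output. Only the
// "farmhash64" method (8-byte fingerprints) is recognized here.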
REGISTER_OP("Fingerprint")
    .Input("data: T")
    .Input("method: string")
    .Output("fingerprint: uint8")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

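      // c->input_tensor(1) is non-null only when `method` is a graph-time
      // constant; otherwise the fingerprint width cannot be inferred here.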
      DimensionHandle fingerprint_size;
      const Tensor* method = c->input_tensor(1);
      if (method == nullptr) {
        fingerprint_size = c->UnknownDim();
      } else {
        if (method->dims() != 0) {
          return errors::InvalidArgument("`method` must be rank 0: ",
                                         method->shape());
        }
        const string& method_string = method->scalar<tstring>()();
        if (method_string != "farmhash64") {
          return errors::InvalidArgument("Unsupported method: ", method_string);
        }
        fingerprint_size = c->MakeDim(sizeof(uint64));
      }

      DimensionHandle batch = c->Dim(c->input(0), 0);
      c->set_output(0, c->MakeShape({batch, fingerprint_size}));
      return OkStatus();
    });

#ifdef INTEL_MKL
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatShape(c, c->num_inputs() - 3);
    })
    .Doc(R"doc(
MKL version of the Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
expected to invoke this operator.
)doc");
#endif  // INTEL_MKL

// Deprecated op registrations:

// The following can be deleted after 10mar2017.
REGISTER_OP("BatchMatrixDiag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixSetDiag")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixDiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixDiagPart")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BatchMatrixBandPart")
    .Input("input: T")
    .Input("num_lower: int64")
    .Input("num_upper: int64")
    .Output("band: T")
    .Attr("T: type")
    .Deprecated(14, "Use MatrixBandPart")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow