1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
22using shape_inference::DimensionHandle;
23using shape_inference::InferenceContext;
24using shape_inference::ShapeHandle;
25
26namespace {
27
28// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29// height and width come from the size_tensor.
30Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31 int size_input_idx, DimensionHandle channel_dim) {
32 // Verify shape of size input.
33 ShapeHandle size;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35 DimensionHandle unused;
36 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37
38 // Get size values from the size tensor.
39 const Tensor* size_tensor = c->input_tensor(size_input_idx);
40 DimensionHandle width;
41 DimensionHandle height;
42 if (size_tensor == nullptr) {
43 width = c->UnknownDim();
44 height = c->UnknownDim();
45 } else {
46 // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47 if (size_tensor->dtype() != DT_INT32) {
48 return errors::InvalidArgument(
49 "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50 "but got ",
51 DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52 " in ", c->DebugString());
53 }
54 auto vec = size_tensor->vec<int32>();
55 height = c->MakeDim(vec(0));
56 width = c->MakeDim(vec(1));
57 }
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59 return OkStatus();
60}
61
62Status ResizeShapeFn(InferenceContext* c) {
63 ShapeHandle input;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65 return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66 c->Dim(input, 3));
67}
68
69Status DecodeImageShapeFn(InferenceContext* c) {
70 ShapeHandle unused;
71 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72 DimensionHandle channels_dim;
73 int32_t channels;
74 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75 if (channels == 0) {
76 channels_dim = c->UnknownDim();
77 } else {
78 if (channels < 0) {
79 return errors::InvalidArgument("channels must be non-negative, got ",
80 channels);
81 }
82 channels_dim = c->MakeDim(channels);
83 }
84
85 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86 InferenceContext::kUnknownDim, channels_dim}));
87 return OkStatus();
88}
89
90Status DecodeImageV2ShapeFn(InferenceContext* c) {
91 ShapeHandle unused;
92 int32_t channels;
93 bool expand_animations;
94 DimensionHandle channels_dim;
95
96 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
97 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
98 TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
99
100 if (channels == 0) {
101 channels_dim = c->UnknownDim();
102 } else {
103 if (channels < 0) {
104 return errors::InvalidArgument("channels must be non-negative, got ",
105 channels);
106 }
107 channels_dim = c->MakeDim(channels);
108 }
109
110 // `expand_animations` set to true will return 4-D shapes for GIF. 3-D shapes
111 // will be returned for jpg, png, and bmp. `expand_animations` set to false
112 // will always return 3-D shapes for all (jpg, png, bmp, gif).
113 if (expand_animations) {
114 c->set_output(0, c->UnknownShape());
115 return OkStatus();
116 } else {
117 c->set_output(0,
118 c->MakeShape({InferenceContext::kUnknownDim,
119 InferenceContext::kUnknownDim, channels_dim}));
120 return OkStatus();
121 }
122}
123
124Status EncodeImageShapeFn(InferenceContext* c) {
125 ShapeHandle unused;
126 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
127 c->set_output(0, c->Scalar());
128 return OkStatus();
129}
130
131Status ColorspaceShapeFn(InferenceContext* c) {
132 ShapeHandle input;
133 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
134
135 // The last dimension value is always 3.
136 DimensionHandle last_dim;
137 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
138 ShapeHandle out;
139 TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
140 c->set_output(0, out);
141
142 return OkStatus();
143}
144
145Status NMSShapeFn(InferenceContext* c) {
146 // Get inputs and validate ranks.
147 ShapeHandle boxes;
148 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
149 ShapeHandle scores;
150 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
151 ShapeHandle max_output_size;
152 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
153 ShapeHandle iou_threshold;
154 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
155 ShapeHandle score_threshold;
156 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
157 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
158 DimensionHandle unused;
159 // The boxes[0] and scores[0] are both num_boxes.
160 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
161 // The boxes[1] is 4.
162 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
163
164 c->set_output(0, c->Vector(c->UnknownDim()));
165 return OkStatus();
166}
167
168Status SoftNMSShapeFn(InferenceContext* c) {
169 // Get inputs and validate ranks.
170 ShapeHandle boxes;
171 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
172 ShapeHandle scores;
173 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
174 ShapeHandle max_output_size;
175 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
176 ShapeHandle iou_threshold;
177 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
178 ShapeHandle score_threshold;
179 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
180 ShapeHandle soft_nms_sigma;
181 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &soft_nms_sigma));
182 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
183 DimensionHandle unused;
184 // The boxes[0] and scores[0] are both num_boxes.
185 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
186 // The boxes[1] is 4.
187 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
188
189 c->set_output(0, c->Vector(c->UnknownDim()));
190 c->set_output(1, c->Vector(c->UnknownDim()));
191 return OkStatus();
192}
193
194Status CombinedNMSShapeFn(InferenceContext* c) {
195 // Get inputs and validate ranks
196 ShapeHandle boxes;
197 // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4]
198 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &boxes));
199 ShapeHandle scores;
200 // scores is a tensor of Dimensions [batch_size, num_anchors, num_classes]
201 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &scores));
202 ShapeHandle max_output_size_per_class;
203 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size_per_class));
204 ShapeHandle max_total_size;
205 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_total_size));
206 ShapeHandle unused_shape;
207 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
208 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
209
210 DimensionHandle unused;
211 // boxes[0] and scores[0] are both batch_size
212 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
213 // boxes[1] and scores[1] are both num_anchors
214 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 1), c->Dim(scores, 1), &unused));
215 // The boxes[3] is 4.
216 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 3), 4, &unused));
217
218 DimensionHandle d = c->Dim(boxes, 2);
219 DimensionHandle class_dim = c->Dim(scores, 2);
220 if (c->ValueKnown(d) && c->ValueKnown(class_dim)) {
221 if (c->Value(d) != 1 && c->Value(d) != c->Value(class_dim)) {
222 return errors::InvalidArgument(
223 "third dimension of boxes must be either "
224 "1 or equal to the third dimension of scores");
225 }
226 }
227 DimensionHandle output_dim;
228 DimensionHandle batch_dim = c->Dim(boxes, 0);
229
230 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &output_dim));
231 if (c->ValueKnown(output_dim) && c->Value(output_dim) <= 0) {
232 return errors::InvalidArgument("max_total_size should be > 0 ");
233 }
234 DimensionHandle size_per_class;
235 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &size_per_class));
236
237 int64_t output_size;
238 bool pad_per_class;
239 TF_RETURN_IF_ERROR(c->GetAttr("pad_per_class", &pad_per_class));
240 if (!pad_per_class) {
241 output_size = c->Value(output_dim);
242 } else {
243 if (c->ValueKnown(size_per_class) && c->Value(size_per_class) <= 0) {
244 return errors::InvalidArgument(
245 "max_output_size_per_class must be > 0 "
246 "if pad_per_class is set to true ");
247 }
248 output_size = std::min(c->Value(output_dim),
249 c->Value(size_per_class) * c->Value(class_dim));
250 }
251 c->set_output(0, c->MakeShape({batch_dim, output_size, 4}));
252 c->set_output(1, c->MakeShape({batch_dim, output_size}));
253 c->set_output(2, c->MakeShape({batch_dim, output_size}));
254 c->set_output(3, c->Vector(batch_dim));
255 return OkStatus();
256}
257
258} // namespace
259
260// --------------------------------------------------------------------------
// Resizes a 4-D `images` batch to `size` (input 1, int32); the output dtype is
// always float. Output shape [batch, size[0], size[1], channels] comes from
// ResizeShapeFn.
REGISTER_OP("ResizeArea")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
270
271// --------------------------------------------------------------------------
// Resizes a 4-D `images` batch to `size` (input 1, int32); the output dtype is
// always float. Output shape [batch, size[0], size[1], channels] comes from
// ResizeShapeFn.
REGISTER_OP("ResizeBicubic")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
282
283// --------------------------------------------------------------------------
284REGISTER_OP("ResizeBicubicGrad")
285 .Input("grads: float")
286 .Input("original_image: T")
287 .Output("output: T")
288 .Attr("T: {float, double}")
289 .Attr("align_corners: bool = false")
290 .Attr("half_pixel_centers: bool = false")
291 .SetShapeFn([](InferenceContext* c) {
292 c->set_output(0, c->input(1));
293 return OkStatus();
294 });
295
296// --------------------------------------------------------------------------
// Resizes a 4-D `images` batch to `size` (input 1, int32); the output dtype is
// always float. Output shape [batch, size[0], size[1], channels] comes from
// ResizeShapeFn.
// NOTE(review): `bfloat16` appears twice in T's allowed list below. It is
// left as-is because editing the list presumably changes the registered
// OpDef (and any op-history checks against it). Confirm before cleaning up.
REGISTER_OP("ResizeBilinear")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double, bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
307
308// --------------------------------------------------------------------------
// Resizes a 4-D `images` batch to `size` (input 1, int32) with the given
// `scale` and `translation`; the output dtype is always float. Only the images
// and size inputs take part in shape inference (ResizeShapeFn). `scale` and
// `translation` shapes are not checked here.
REGISTER_OP("ScaleAndTranslate")
    .Input("images: T")
    .Input("size: int32")
    .Input("scale: float")
    .Input("translation: float")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double}")
    .Attr("kernel_type: string = 'lanczos3'")
    .Attr("antialias: bool = true")
    .SetShapeFn(ResizeShapeFn);
321
322// --------------------------------------------------------------------------
323REGISTER_OP("QuantizedResizeBilinear")
324 .Input("images: T")
325 .Input("size: int32")
326 .Input("min: float")
327 .Input("max: float")
328 .Output("resized_images: T")
329 .Output("out_min: float")
330 .Output("out_max: float")
331 .Attr("T: {quint8, qint32, float}")
332 .Attr("align_corners: bool = false")
333 .Attr("half_pixel_centers: bool = false")
334 .SetShapeFn([](InferenceContext* c) {
335 TF_RETURN_IF_ERROR(ResizeShapeFn(c));
336 ShapeHandle min_shape;
337 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
338 ShapeHandle max_shape;
339 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
340 c->set_output(1, c->MakeShape({}));
341 c->set_output(2, c->MakeShape({}));
342 return OkStatus();
343 });
344
345// --------------------------------------------------------------------------
346REGISTER_OP("ResizeBilinearGrad")
347 .Input("grads: float")
348 .Input("original_image: T")
349 .Output("output: T")
350 .Attr("T: {float, bfloat16, half, double}")
351 .Attr("align_corners: bool = false")
352 .Attr("half_pixel_centers: bool = false")
353 .SetShapeFn([](InferenceContext* c) {
354 c->set_output(0, c->input(1));
355 return OkStatus();
356 });
357
358// --------------------------------------------------------------------------
359REGISTER_OP("ScaleAndTranslateGrad")
360 .Input("grads: T")
361 .Input("original_image: T")
362 .Input("scale: float")
363 .Input("translation: float")
364 .Output("output: T")
365 .Attr("T: {float}")
366 .Attr("kernel_type: string = 'lanczos3'")
367 .Attr("antialias: bool = true")
368 .SetShapeFn([](InferenceContext* c) {
369 c->set_output(0, c->input(1));
370 return OkStatus();
371 });
372
373// --------------------------------------------------------------------------
// Resizes a 4-D `images` batch to `size` (input 1, int32). Unlike the other
// resize ops, the output keeps the input dtype T. Output shape
// [batch, size[0], size[1], channels] comes from ResizeShapeFn.
REGISTER_OP("ResizeNearestNeighbor")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: T")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float,"
        "double, bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
384
385// --------------------------------------------------------------------------
386REGISTER_OP("ResizeNearestNeighborGrad")
387 .Input("grads: T")
388 .Input("size: int32")
389 .Output("output: T")
390 .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}")
391 .Attr("align_corners: bool = false")
392 .Attr("half_pixel_centers: bool = false")
393 .SetShapeFn([](InferenceContext* c) {
394 ShapeHandle input;
395 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
396 ShapeHandle unused;
397 DimensionHandle unused_dim;
398 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
399 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
400 const Tensor* size = c->input_tensor(1);
401 if (size == nullptr) {
402 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
403 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
404 } else {
405 auto size_vec = size->vec<int32>();
406 TF_RETURN_IF_ERROR(
407 c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
408 TF_RETURN_IF_ERROR(
409 c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
410 }
411 c->set_output(0, input);
412 return OkStatus();
413 });
414
415// --------------------------------------------------------------------------
416REGISTER_OP("RandomCrop")
417 .Input("image: T")
418 .Input("size: int64")
419 .Output("output: T")
420 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
421 .Attr("seed: int = 0")
422 .Attr("seed2: int = 0")
423 .SetIsStateful()
424 .Deprecated(8, "Random crop is now pure Python")
425 .SetShapeFn([](InferenceContext* c) {
426 ShapeHandle image;
427 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
428 DimensionHandle channels = c->Dim(image, -1);
429
430 ShapeHandle unused;
431 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
432
433 const Tensor* size = c->input_tensor(1);
434 DimensionHandle h;
435 DimensionHandle w;
436 if (size == nullptr) {
437 h = c->UnknownDim();
438 w = c->UnknownDim();
439 } else {
440 auto size_vec = size->vec<int64_t>();
441 h = c->MakeDim(size_vec(0));
442 w = c->MakeDim(size_vec(1));
443 }
444 c->set_output(0, c->MakeShape({h, w, channels}));
445 return OkStatus();
446 });
447// TODO(shlens): Support variable rank in RandomCrop.
448
449// --------------------------------------------------------------------------
// Generic image decoder; the output dtype is chosen by the `dtype` attr.
// Shape inference is DecodeImageV2ShapeFn:
// - fully unknown shape when `expand_animations` is true;
// - otherwise 3-D [?, ?, channels].
REGISTER_OP("DecodeImage")
    .Input("contents: string")
    // Setting `channels` to 0 means using the inherent number of channels in
    // the image.
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
    .Output("image: dtype")
    .Attr("expand_animations: bool = true")
    .SetShapeFn(DecodeImageV2ShapeFn);
459
460// --------------------------------------------------------------------------
// Decodes a scalar string `contents` into a uint8 image. Output shape is
// [?, ?, channels] via DecodeImageShapeFn (channels unknown when the attr is 0).
REGISTER_OP("DecodeJpeg")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("ratio: int = 1")
    .Attr("fancy_upscaling: bool = true")
    .Attr("try_recover_truncated: bool = false")
    .Attr("acceptable_fraction: float = 1.0")
    .Attr("dct_method: string = ''")
    .Output("image: uint8")
    .SetShapeFn(DecodeImageShapeFn);
471
472// --------------------------------------------------------------------------
473REGISTER_OP("DecodeAndCropJpeg")
474 .Input("contents: string")
475 .Input("crop_window: int32")
476 .Attr("channels: int = 0")
477 .Attr("ratio: int = 1")
478 .Attr("fancy_upscaling: bool = true")
479 .Attr("try_recover_truncated: bool = false")
480 .Attr("acceptable_fraction: float = 1.0")
481 .Attr("dct_method: string = ''")
482 .Output("image: uint8")
483 .SetShapeFn([](InferenceContext* c) {
484 ShapeHandle unused;
485 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
486 DimensionHandle channels_dim = c->UnknownDim();
487 DimensionHandle h = c->UnknownDim();
488 DimensionHandle w = c->UnknownDim();
489
490 int32_t channels;
491 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
492 if (channels != 0) {
493 if (channels < 0) {
494 return errors::InvalidArgument("channels must be non-negative, got ",
495 channels);
496 }
497 channels_dim = c->MakeDim(channels);
498 }
499
500 DimensionHandle unused_dim;
501 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
502 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
503
504 const Tensor* crop_window = c->input_tensor(1);
505 if (crop_window != nullptr) {
506 auto crop_window_vec = crop_window->vec<int32>();
507 h = c->MakeDim(crop_window_vec(2));
508 w = c->MakeDim(crop_window_vec(3));
509 }
510 c->set_output(0, c->MakeShape({h, w, channels_dim}));
511 return OkStatus();
512 });
513
514// --------------------------------------------------------------------------
// Encodes a 3-D uint8 image into a scalar string (EncodeImageShapeFn).
REGISTER_OP("EncodeJpeg")
    .Input("image: uint8")
    .Attr("format: {'', 'grayscale', 'rgb'} = ''")
    .Attr("quality: int = 95")
    .Attr("progressive: bool = false")
    .Attr("optimize_size: bool = false")
    .Attr("chroma_downsampling: bool = true")
    .Attr("density_unit: {'in', 'cm'} = 'in'")
    .Attr("x_density: int = 300")
    .Attr("y_density: int = 300")
    .Attr("xmp_metadata: string = ''")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
528
529// --------------------------------------------------------------------------
// Encodes a 3-D uint8 image with quality given as an int32 input rather than
// an attr. Output is a scalar string (EncodeImageShapeFn). The shape of the
// `quality` input is not checked here.
REGISTER_OP("EncodeJpegVariableQuality")
    .Input("images: uint8")
    .Input("quality: int32")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
535
536// --------------------------------------------------------------------------
537REGISTER_OP("ExtractJpegShape")
538 .Input("contents: string")
539 .Output("image_shape: output_type")
540 .Attr("output_type: {int32, int64} = DT_INT32")
541 .SetShapeFn([](InferenceContext* c) {
542 ShapeHandle unused;
543 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
544 c->set_output(0, c->Vector(3));
545 return OkStatus();
546 });
547
548// --------------------------------------------------------------------------
549REGISTER_OP("AdjustContrast")
550 .Input("images: T")
551 .Input("contrast_factor: float")
552 .Input("min_value: float")
553 .Input("max_value: float")
554 .Output("output: float")
555 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
556 .Deprecated(2, "Use AdjustContrastv2 instead")
557 .SetShapeFn([](InferenceContext* c) {
558 // The contrast_factor, min_value, max_value should be scalar only.
559 ShapeHandle unused;
560 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
561 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
562 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
563 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
564 });
565
566// --------------------------------------------------------------------------
567REGISTER_OP("AdjustContrastv2")
568 .Input("images: T")
569 .Input("contrast_factor: float")
570 .Output("output: T")
571 .Attr("T: {half, float} = DT_FLOAT")
572 .SetShapeFn([](InferenceContext* c) {
573 // The contrast_factor should be scalar only.
574 ShapeHandle unused;
575 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
576 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
577 });
578
579// --------------------------------------------------------------------------
580REGISTER_OP("AdjustHue")
581 .Input("images: T")
582 .Input("delta: float")
583 .Output("output: T")
584 .Attr("T: {half, float} = DT_FLOAT")
585 .SetShapeFn([](InferenceContext* c) {
586 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
587 });
588
589// --------------------------------------------------------------------------
590REGISTER_OP("AdjustSaturation")
591 .Input("images: T")
592 .Input("scale: float")
593 .Output("output: T")
594 .Attr("T: {half, float} = DT_FLOAT")
595 .SetShapeFn([](InferenceContext* c) {
596 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
597 });
598
599// --------------------------------------------------------------------------
// Decodes a scalar string into a uint8 or uint16 image (per `dtype`). Output
// shape is [?, ?, channels] via DecodeImageShapeFn.
REGISTER_OP("DecodePng")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16} = DT_UINT8")
    .Output("image: dtype")
    .SetShapeFn(DecodeImageShapeFn);
606
607// --------------------------------------------------------------------------
// Encodes a 3-D uint8/uint16 image into a scalar string (EncodeImageShapeFn).
REGISTER_OP("EncodePng")
    .Attr("compression: int = -1")
    .Attr("T: {uint8, uint16} = DT_UINT8")
    .Input("image: T")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
614
615// --------------------------------------------------------------------------
// Decodes a scalar string into a uint8 image. Output shape is [?, ?, channels]
// via DecodeImageShapeFn.
REGISTER_OP("DecodeBmp")
    .Input("contents: string")
    .Output("image: uint8")
    .Attr("channels: int = 0")
    .SetShapeFn(DecodeImageShapeFn);
621
622// --------------------------------------------------------------------------
623REGISTER_OP("DecodeGif")
624 .Input("contents: string")
625 .Output("image: uint8")
626 .SetShapeFn([](InferenceContext* c) {
627 ShapeHandle unused;
628 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
629 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
630 InferenceContext::kUnknownDim,
631 InferenceContext::kUnknownDim, 3}));
632 return OkStatus();
633 });
634
635// --------------------------------------------------------------------------
// Converts images whose last dimension is 3; the output shape equals the input
// shape (ColorspaceShapeFn).
REGISTER_OP("RGBToHSV")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
641
642// --------------------------------------------------------------------------
// Converts images whose last dimension is 3; the output shape equals the input
// shape (ColorspaceShapeFn).
REGISTER_OP("HSVToRGB")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
648
649// --------------------------------------------------------------------------
650REGISTER_OP("DrawBoundingBoxes")
651 .Input("images: T")
652 .Input("boxes: float")
653 .Output("output: T")
654 .Attr("T: {float, half} = DT_FLOAT")
655 .SetShapeFn([](InferenceContext* c) {
656 // The rank of images should be 4.
657 ShapeHandle images;
658 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
659 // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
660 if (c->ValueKnown(c->Dim(images, 3))) {
661 int64_t depth = c->Value(c->Dim(images, 3));
662 if (!(depth == 1 || depth == 3 || depth == 4)) {
663 return errors::InvalidArgument(
664 "Channel depth should be either 1 (GRY), "
665 "3 (RGB), or 4 (RGBA)");
666 }
667 }
668
669 // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
670 ShapeHandle boxes;
671 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
672 // The last value of boxes shape is 4.
673 DimensionHandle unused;
674 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
675
676 // The rank of the input image (rank = 4) has already been restricted
677 // above, and the output is of the same shape as the input.
678 return shape_inference::UnchangedShape(c);
679 });
680
681// --------------------------------------------------------------------------
682REGISTER_OP("DrawBoundingBoxesV2")
683 .Input("images: T")
684 .Input("boxes: float")
685 .Input("colors: float")
686 .Output("output: T")
687 .Attr("T: {float, half} = DT_FLOAT")
688 .SetShapeFn([](InferenceContext* c) {
689 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
690 });
691
692// --------------------------------------------------------------------------
693REGISTER_OP("SampleDistortedBoundingBox")
694 .Input("image_size: T")
695 .Input("bounding_boxes: float")
696 .Output("begin: T")
697 .Output("size: T")
698 .Output("bboxes: float")
699 .Attr("T: {uint8, int8, int16, int32, int64}")
700 .Attr("seed: int = 0")
701 .Attr("seed2: int = 0")
702 .Attr("min_object_covered: float = 0.1")
703 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
704 .Attr("area_range: list(float) = [0.05, 1.0]")
705 .Attr("max_attempts: int = 100")
706 .Attr("use_image_if_no_bounding_boxes: bool = false")
707 .SetIsStateful()
708 .SetShapeFn([](InferenceContext* c) {
709 // Get inputs and validate ranks.
710 ShapeHandle image_size;
711 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
712 ShapeHandle bounding_boxes;
713 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
714 // image_size: 1-D with [height, width, channels]
715 // bounding_boxes: 3-D with shape [batch, N, 4]
716 DimensionHandle unused;
717 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
718 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
719
720 c->set_output(0, c->Vector(3));
721 c->set_output(1, c->Vector(3));
722 c->set_output(2, c->MakeShape({1, 1, 4}));
723 return OkStatus();
724 });
725
726REGISTER_OP("SampleDistortedBoundingBoxV2")
727 .Input("image_size: T")
728 .Input("bounding_boxes: float")
729 .Input("min_object_covered: float")
730 .Output("begin: T")
731 .Output("size: T")
732 .Output("bboxes: float")
733 .Attr("T: {uint8, int8, int16, int32, int64}")
734 .Attr("seed: int = 0")
735 .Attr("seed2: int = 0")
736 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
737 .Attr("area_range: list(float) = [0.05, 1.0]")
738 .Attr("max_attempts: int = 100")
739 .Attr("use_image_if_no_bounding_boxes: bool = false")
740 .SetIsStateful()
741 .SetShapeFn([](InferenceContext* c) {
742 // Get inputs and validate ranks.
743 ShapeHandle image_size;
744 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
745 ShapeHandle bounding_boxes;
746 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
747 ShapeHandle min_object_covered;
748 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
749 // image_size: 1-D with [height, width, channels]
750 // bounding_boxes: 3-D with shape [batch, N, 4]
751 DimensionHandle unused;
752 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
753 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
754
755 c->set_output(0, c->Vector(3));
756 c->set_output(1, c->Vector(3));
757 c->set_output(2, c->MakeShape({1, 1, 4}));
758 return OkStatus();
759 });
760
761REGISTER_OP("StatelessSampleDistortedBoundingBox")
762 .Input("image_size: T")
763 .Input("bounding_boxes: float")
764 .Input("min_object_covered: float")
765 .Input("seed: Tseed")
766 .Output("begin: T")
767 .Output("size: T")
768 .Output("bboxes: float")
769 .Attr("T: {uint8, int8, int16, int32, int64}")
770 .Attr("Tseed: {int32, int64}")
771 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
772 .Attr("area_range: list(float) = [0.05, 1.0]")
773 .Attr("max_attempts: int = 100")
774 .Attr("use_image_if_no_bounding_boxes: bool = false")
775 .SetShapeFn([](InferenceContext* c) {
776 // Get inputs and validate ranks.
777 ShapeHandle image_size;
778 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
779 ShapeHandle bounding_boxes;
780 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
781 ShapeHandle min_object_covered;
782 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
783 ShapeHandle seed;
784 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &seed));
785 // image_size: 1-D with [height, width, channels]
786 // bounding_boxes: 3-D with shape [batch, N, 4]
787 DimensionHandle unused;
788 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
789 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
790 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
791
792 c->set_output(0, c->Vector(3));
793 c->set_output(1, c->Vector(3));
794 c->set_output(2, c->MakeShape({1, 1, 4}));
795
796 return OkStatus();
797 });
798
799// --------------------------------------------------------------------------
800
801// glimpse = extract_glimpse(input, size, offsets) extract the glimpse
802// of size `size` centered at location `offsets` from the input tensor
803// `input`.
804//
805// REQUIRES: input.dims() == 4
806//
807REGISTER_OP("ExtractGlimpse")
808 .Input("input: float")
809 .Input("size: int32")
810 .Input("offsets: float")
811 .Output("glimpse: float")
812 .Attr("centered: bool = true")
813 .Attr("normalized: bool = true")
814 .Attr("uniform_noise: bool = true")
815 .Attr("noise: string = 'uniform'")
816 .SetShapeFn([](InferenceContext* c) {
817 ShapeHandle input;
818 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
819 ShapeHandle offsets;
820 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
821
822 DimensionHandle batch_dim;
823 TF_RETURN_IF_ERROR(
824 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
825 DimensionHandle unused;
826 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
827
828 bool uniform_noise = false;
829 TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
830 string noise;
831 TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
832 if (uniform_noise && (!noise.empty() && noise != "uniform")) {
833 return errors::InvalidArgument(
834 "The uniform_noise and noise should not be specified at the same "
835 "time");
836 }
837
838 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
839 c->Dim(input, 3));
840 });
841
842REGISTER_OP("ExtractGlimpseV2")
843 .Input("input: float")
844 .Input("size: int32")
845 .Input("offsets: float")
846 .Output("glimpse: float")
847 .Attr("centered: bool = true")
848 .Attr("normalized: bool = true")
849 .Attr("uniform_noise: bool = true")
850 .Attr("noise: string = 'uniform'")
851 .SetShapeFn([](InferenceContext* c) {
852 ShapeHandle input;
853 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
854 ShapeHandle offsets;
855 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
856
857 DimensionHandle batch_dim;
858 TF_RETURN_IF_ERROR(
859 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
860 DimensionHandle unused;
861 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
862
863 bool uniform_noise = false;
864 TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
865 string noise;
866 TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
867 if (uniform_noise && (!noise.empty() && noise != "uniform")) {
868 return errors::InvalidArgument(
869 "The uniform_noise and noise should not be specified at the same "
870 "time");
871 }
872
873 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
874 c->Dim(input, 3));
875 });
876
877// --------------------------------------------------------------------------
878
879REGISTER_OP("CropAndResize")
880 .Input("image: T")
881 .Input("boxes: float")
882 .Input("box_ind: int32")
883 .Input("crop_size: int32")
884 .Output("crops: float")
885 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
886 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
887 .Attr("extrapolation_value: float = 0")
888 .SetShapeFn([](InferenceContext* c) {
889 // Get inputs and validate ranks.
890 ShapeHandle input;
891 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
892 ShapeHandle boxes;
893 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
894 ShapeHandle box_ind;
895 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
896
897 // boxes[0] and box_ind[0] are both num_boxes.
898 DimensionHandle num_boxes_dim;
899 TF_RETURN_IF_ERROR(
900 c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
901
902 // boxes.dim(1) is 4.
903 DimensionHandle unused;
904 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
905
906 return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
907 c->Dim(input, 3));
908 });
909
910REGISTER_OP("CropAndResizeGradImage")
911 .Input("grads: float")
912 .Input("boxes: float")
913 .Input("box_ind: int32")
914 .Input("image_size: int32")
915 .Output("output: T")
916 .Attr("T: {float, half, double}")
917 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
918 .SetShapeFn([](InferenceContext* c) {
919 ShapeHandle out;
920 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
921 TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
922 c->set_output(0, out);
923 return OkStatus();
924 });
925
926REGISTER_OP("CropAndResizeGradBoxes")
927 .Input("grads: float")
928 .Input("image: T")
929 .Input("boxes: float")
930 .Input("box_ind: int32")
931 .Output("output: float")
932 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
933 .Attr("method: {'bilinear'} = 'bilinear'")
934 .SetShapeFn([](InferenceContext* c) {
935 c->set_output(0, c->input(2));
936 return OkStatus();
937 });
938
939// --------------------------------------------------------------------------
940
941REGISTER_OP("NonMaxSuppression")
942 .Input("boxes: float")
943 .Input("scores: float")
944 .Input("max_output_size: int32")
945 .Output("selected_indices: int32")
946 .Attr("iou_threshold: float = 0.5")
947 .SetShapeFn([](InferenceContext* c) {
948 // Get inputs and validate ranks.
949 ShapeHandle boxes;
950 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
951 ShapeHandle scores;
952 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
953 ShapeHandle max_output_size;
954 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
955 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
956 DimensionHandle unused;
957 // The boxes[0] and scores[0] are both num_boxes.
958 TF_RETURN_IF_ERROR(
959 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
960 // The boxes[1] is 4.
961 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
962
963 c->set_output(0, c->Vector(c->UnknownDim()));
964 return OkStatus();
965 });
966
967REGISTER_OP("NonMaxSuppressionV2")
968 .Input("boxes: T")
969 .Input("scores: T")
970 .Input("max_output_size: int32")
971 .Input("iou_threshold: T_threshold")
972 .Output("selected_indices: int32")
973 .Attr("T: {half, float} = DT_FLOAT")
974 .Attr("T_threshold: {half, float} = DT_FLOAT")
975 .SetShapeFn([](InferenceContext* c) {
976 // Get inputs and validate ranks.
977 ShapeHandle boxes;
978 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
979 ShapeHandle scores;
980 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
981 ShapeHandle max_output_size;
982 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
983 ShapeHandle iou_threshold;
984 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
985 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
986 DimensionHandle unused;
987 // The boxes[0] and scores[0] are both num_boxes.
988 TF_RETURN_IF_ERROR(
989 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
990 // The boxes[1] is 4.
991 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
992
993 c->set_output(0, c->Vector(c->UnknownDim()));
994 return OkStatus();
995 });
996
// Same inputs as NonMaxSuppressionV2 plus a trailing `score_threshold` of
// type T_threshold. Shape inference is delegated to the shared NMSShapeFn
// helper (defined outside this view). The order of the .Input() calls below
// defines the op signature and must not change.
REGISTER_OP("NonMaxSuppressionV3")
    .Input("boxes: T")
    .Input("scores: T")
    .Input("max_output_size: int32")
    .Input("iou_threshold: T_threshold")
    .Input("score_threshold: T_threshold")
    .Output("selected_indices: int32")
    .Attr("T: {half, float} = DT_FLOAT")
    .Attr("T_threshold: {half, float} = DT_FLOAT")
    .SetShapeFn(NMSShapeFn);
1007
1008REGISTER_OP("NonMaxSuppressionV4")
1009 .Input("boxes: T")
1010 .Input("scores: T")
1011 .Input("max_output_size: int32")
1012 .Input("iou_threshold: T_threshold")
1013 .Input("score_threshold: T_threshold")
1014 .Output("selected_indices: int32")
1015 .Output("valid_outputs: int32")
1016 .Attr("T: {half, float} = DT_FLOAT")
1017 .Attr("T_threshold: {half, float} = DT_FLOAT")
1018 .Attr("pad_to_max_output_size: bool = false")
1019 .SetShapeFn([](InferenceContext* c) {
1020 TF_RETURN_IF_ERROR(NMSShapeFn(c));
1021
1022 bool pad_to_max;
1023 TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1024 if (pad_to_max) {
1025 // If padded, overwrite the shape of the output to be static.
1026 DimensionHandle output_dim;
1027 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1028 c->set_output(0, c->MakeShape({output_dim}));
1029 }
1030 c->set_output(1, c->MakeShape({}));
1031 return OkStatus();
1032 });
1033
1034REGISTER_OP("NonMaxSuppressionV5")
1035 .Input("boxes: T")
1036 .Input("scores: T")
1037 .Input("max_output_size: int32")
1038 .Input("iou_threshold: T")
1039 .Input("score_threshold: T")
1040 .Input("soft_nms_sigma: T")
1041 .Output("selected_indices: int32")
1042 .Output("selected_scores: T")
1043 .Output("valid_outputs: int32")
1044 .Attr("T: {half, float} = DT_FLOAT")
1045 .Attr("pad_to_max_output_size: bool = false")
1046 .SetShapeFn([](InferenceContext* c) {
1047 TF_RETURN_IF_ERROR(SoftNMSShapeFn(c));
1048
1049 bool pad_to_max;
1050 TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1051 if (pad_to_max) {
1052 // If padded, overwrite the shape of the output to be static.
1053 DimensionHandle output_dim;
1054 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1055 c->set_output(0, c->MakeShape({output_dim}));
1056 c->set_output(1, c->MakeShape({output_dim}));
1057 }
1058
1059 c->set_output(2, c->MakeShape({}));
1060 return OkStatus();
1061 });
1062
1063REGISTER_OP("NonMaxSuppressionWithOverlaps")
1064 .Input("overlaps: float")
1065 .Input("scores: float")
1066 .Input("max_output_size: int32")
1067 .Input("overlap_threshold: float")
1068 .Input("score_threshold: float")
1069 .Output("selected_indices: int32")
1070 .SetShapeFn([](InferenceContext* c) {
1071 // Get inputs and validate ranks.
1072 ShapeHandle overlaps;
1073 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
1074 ShapeHandle scores;
1075 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
1076 ShapeHandle max_output_size;
1077 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
1078 ShapeHandle overlap_threshold;
1079 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
1080 ShapeHandle score_threshold;
1081 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
1082 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
1083 DimensionHandle unused;
1084 // The boxes[0] and scores[0] are both num_boxes.
1085 TF_RETURN_IF_ERROR(
1086 c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
1087 // The boxes[1] is 4.
1088 TF_RETURN_IF_ERROR(
1089 c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
1090
1091 c->set_output(0, c->Vector(c->UnknownDim()));
1092 return OkStatus();
1093 });
1094
// Registers CombinedNonMaxSuppression, which produces four outputs:
// nmsed_boxes, nmsed_scores, nmsed_classes (float) and valid_detections
// (int32). Shape inference is delegated to the shared CombinedNMSShapeFn
// helper (defined outside this view). The order of the .Input()/.Output()
// calls defines the op signature and must not change.
REGISTER_OP("CombinedNonMaxSuppression")
    .Input("boxes: float")
    .Input("scores: float")
    .Input("max_output_size_per_class: int32")
    .Input("max_total_size: int32")
    .Input("iou_threshold: float")
    .Input("score_threshold: float")
    .Output("nmsed_boxes: float")
    .Output("nmsed_scores: float")
    .Output("nmsed_classes: float")
    .Output("valid_detections: int32")
    .Attr("pad_per_class: bool = false")
    .Attr("clip_boxes: bool = true")
    .SetShapeFn(CombinedNMSShapeFn);
1109
1110REGISTER_OP("GenerateBoundingBoxProposals")
1111 .Input("scores: float")
1112 .Input("bbox_deltas: float")
1113 .Input("image_info: float")
1114 .Input("anchors: float")
1115 .Input("nms_threshold: float")
1116 .Input("pre_nms_topn: int32")
1117 .Input("min_size: float")
1118 .Output("rois: float")
1119 .Output("roi_probabilities: float")
1120 .Attr("post_nms_topn: int = 300")
1121 .SetShapeFn([](InferenceContext* c) -> Status {
1122 // make sure input tensors have are correct rank
1123 ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold,
1124 n_pre_nms, min_box_size;
1125 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores)); //(N, H, W, A)
1126 TF_RETURN_IF_ERROR(
1127 c->WithRank(c->input(1), 4, &bounding_boxes)); //(N,H,W,A4)
1128 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &images)); // (N,5)
1129 auto im_info = c->Dim(images, 1);
1130 TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info));
1131 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors)); // (A4)
1132 // check scalar tensors
1133 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &nms_threshold));
1134 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &n_pre_nms));
1135 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &min_box_size));
1136
1137 // TODO(skama): verify that the inputs are compatible
1138 int post_nms_top_n;
1139 TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn", &post_nms_top_n));
1140 auto roi_shape = c->MakeShape(
1141 {c->Dim(scores, 0), post_nms_top_n, 4}); //(N,post_nms_top_n,4)
1142 auto prob_shape = c->MakeShape(
1143 {c->Dim(scores, 0), post_nms_top_n}); // (N,post_nms_top_n)
1144 c->set_output(0, roi_shape);
1145 c->set_output(1, prob_shape);
1146 return OkStatus();
1147 });
1148
1149// V3 op supports fill_value.
1150// V2 op supports output_shape.
1151// V1 op is in contrib.
1152REGISTER_OP("ImageProjectiveTransformV2")
1153 .Input("images: dtype")
1154 .Input("transforms: float32")
1155 .Input("output_shape: int32")
1156 .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1157 .Attr("interpolation: string")
1158 .Attr("fill_mode: string = 'CONSTANT'")
1159 .Output("transformed_images: dtype")
1160 .SetShapeFn([](InferenceContext* c) {
1161 ShapeHandle input;
1162 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1163 return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1164 c->Dim(input, 3));
1165 });
1166
1167REGISTER_OP("ImageProjectiveTransformV3")
1168 .Input("images: dtype")
1169 .Input("transforms: float32")
1170 .Input("output_shape: int32")
1171 .Input("fill_value: float32")
1172 .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1173 .Attr("interpolation: string")
1174 .Attr("fill_mode: string = 'CONSTANT'")
1175 .Output("transformed_images: dtype")
1176 .SetShapeFn([](InferenceContext* c) {
1177 ShapeHandle input;
1178 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1179 return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1180 c->Dim(input, 3));
1181 });
1182
1183} // namespace tensorflow
1184