1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::DimensionHandle; |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | namespace { |
27 | |
28 | // Sets output[0] to shape [batch_dim,height,width,channel_dim], where |
29 | // height and width come from the size_tensor. |
30 | Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, |
31 | int size_input_idx, DimensionHandle channel_dim) { |
32 | // Verify shape of size input. |
33 | ShapeHandle size; |
34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size)); |
35 | DimensionHandle unused; |
36 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused)); |
37 | |
38 | // Get size values from the size tensor. |
39 | const Tensor* size_tensor = c->input_tensor(size_input_idx); |
40 | DimensionHandle width; |
41 | DimensionHandle height; |
42 | if (size_tensor == nullptr) { |
43 | width = c->UnknownDim(); |
44 | height = c->UnknownDim(); |
45 | } else { |
46 | // TODO(petewarden) - Remove once we have constant evaluation in C++ only. |
47 | if (size_tensor->dtype() != DT_INT32) { |
48 | return errors::InvalidArgument( |
49 | "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 " |
50 | "but got " , |
51 | DataTypeString(size_tensor->dtype()), " for input #" , size_input_idx, |
52 | " in " , c->DebugString()); |
53 | } |
54 | auto vec = size_tensor->vec<int32>(); |
55 | height = c->MakeDim(vec(0)); |
56 | width = c->MakeDim(vec(1)); |
57 | } |
58 | c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim})); |
59 | return OkStatus(); |
60 | } |
61 | |
62 | Status ResizeShapeFn(InferenceContext* c) { |
63 | ShapeHandle input; |
64 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
65 | return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */, |
66 | c->Dim(input, 3)); |
67 | } |
68 | |
69 | Status DecodeImageShapeFn(InferenceContext* c) { |
70 | ShapeHandle unused; |
71 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
72 | DimensionHandle channels_dim; |
73 | int32_t channels; |
74 | TF_RETURN_IF_ERROR(c->GetAttr("channels" , &channels)); |
75 | if (channels == 0) { |
76 | channels_dim = c->UnknownDim(); |
77 | } else { |
78 | if (channels < 0) { |
79 | return errors::InvalidArgument("channels must be non-negative, got " , |
80 | channels); |
81 | } |
82 | channels_dim = c->MakeDim(channels); |
83 | } |
84 | |
85 | c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim, |
86 | InferenceContext::kUnknownDim, channels_dim})); |
87 | return OkStatus(); |
88 | } |
89 | |
90 | Status DecodeImageV2ShapeFn(InferenceContext* c) { |
91 | ShapeHandle unused; |
92 | int32_t channels; |
93 | bool expand_animations; |
94 | DimensionHandle channels_dim; |
95 | |
96 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
97 | TF_RETURN_IF_ERROR(c->GetAttr("channels" , &channels)); |
98 | TF_RETURN_IF_ERROR(c->GetAttr("expand_animations" , &expand_animations)); |
99 | |
100 | if (channels == 0) { |
101 | channels_dim = c->UnknownDim(); |
102 | } else { |
103 | if (channels < 0) { |
104 | return errors::InvalidArgument("channels must be non-negative, got " , |
105 | channels); |
106 | } |
107 | channels_dim = c->MakeDim(channels); |
108 | } |
109 | |
110 | // `expand_animations` set to true will return 4-D shapes for GIF. 3-D shapes |
111 | // will be returned for jpg, png, and bmp. `expand_animations` set to false |
112 | // will always return 3-D shapes for all (jpg, png, bmp, gif). |
113 | if (expand_animations) { |
114 | c->set_output(0, c->UnknownShape()); |
115 | return OkStatus(); |
116 | } else { |
117 | c->set_output(0, |
118 | c->MakeShape({InferenceContext::kUnknownDim, |
119 | InferenceContext::kUnknownDim, channels_dim})); |
120 | return OkStatus(); |
121 | } |
122 | } |
123 | |
124 | Status EncodeImageShapeFn(InferenceContext* c) { |
125 | ShapeHandle unused; |
126 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused)); |
127 | c->set_output(0, c->Scalar()); |
128 | return OkStatus(); |
129 | } |
130 | |
131 | Status ColorspaceShapeFn(InferenceContext* c) { |
132 | ShapeHandle input; |
133 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); |
134 | |
135 | // The last dimension value is always 3. |
136 | DimensionHandle last_dim; |
137 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim)); |
138 | ShapeHandle out; |
139 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out)); |
140 | c->set_output(0, out); |
141 | |
142 | return OkStatus(); |
143 | } |
144 | |
145 | Status NMSShapeFn(InferenceContext* c) { |
146 | // Get inputs and validate ranks. |
147 | ShapeHandle boxes; |
148 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); |
149 | ShapeHandle scores; |
150 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); |
151 | ShapeHandle max_output_size; |
152 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); |
153 | ShapeHandle iou_threshold; |
154 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); |
155 | ShapeHandle score_threshold; |
156 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); |
157 | // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. |
158 | DimensionHandle unused; |
159 | // The boxes[0] and scores[0] are both num_boxes. |
160 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); |
161 | // The boxes[1] is 4. |
162 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); |
163 | |
164 | c->set_output(0, c->Vector(c->UnknownDim())); |
165 | return OkStatus(); |
166 | } |
167 | |
168 | Status SoftNMSShapeFn(InferenceContext* c) { |
169 | // Get inputs and validate ranks. |
170 | ShapeHandle boxes; |
171 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); |
172 | ShapeHandle scores; |
173 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); |
174 | ShapeHandle max_output_size; |
175 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); |
176 | ShapeHandle iou_threshold; |
177 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); |
178 | ShapeHandle score_threshold; |
179 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); |
180 | ShapeHandle soft_nms_sigma; |
181 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &soft_nms_sigma)); |
182 | // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. |
183 | DimensionHandle unused; |
184 | // The boxes[0] and scores[0] are both num_boxes. |
185 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); |
186 | // The boxes[1] is 4. |
187 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); |
188 | |
189 | c->set_output(0, c->Vector(c->UnknownDim())); |
190 | c->set_output(1, c->Vector(c->UnknownDim())); |
191 | return OkStatus(); |
192 | } |
193 | |
194 | Status CombinedNMSShapeFn(InferenceContext* c) { |
195 | // Get inputs and validate ranks |
196 | ShapeHandle boxes; |
197 | // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4] |
198 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &boxes)); |
199 | ShapeHandle scores; |
200 | // scores is a tensor of Dimensions [batch_size, num_anchors, num_classes] |
201 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &scores)); |
202 | ShapeHandle max_output_size_per_class; |
203 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size_per_class)); |
204 | ShapeHandle max_total_size; |
205 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_total_size)); |
206 | ShapeHandle unused_shape; |
207 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); |
208 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); |
209 | |
210 | DimensionHandle unused; |
211 | // boxes[0] and scores[0] are both batch_size |
212 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); |
213 | // boxes[1] and scores[1] are both num_anchors |
214 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 1), c->Dim(scores, 1), &unused)); |
215 | // The boxes[3] is 4. |
216 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 3), 4, &unused)); |
217 | |
218 | DimensionHandle d = c->Dim(boxes, 2); |
219 | DimensionHandle class_dim = c->Dim(scores, 2); |
220 | if (c->ValueKnown(d) && c->ValueKnown(class_dim)) { |
221 | if (c->Value(d) != 1 && c->Value(d) != c->Value(class_dim)) { |
222 | return errors::InvalidArgument( |
223 | "third dimension of boxes must be either " |
224 | "1 or equal to the third dimension of scores" ); |
225 | } |
226 | } |
227 | DimensionHandle output_dim; |
228 | DimensionHandle batch_dim = c->Dim(boxes, 0); |
229 | |
230 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &output_dim)); |
231 | if (c->ValueKnown(output_dim) && c->Value(output_dim) <= 0) { |
232 | return errors::InvalidArgument("max_total_size should be > 0 " ); |
233 | } |
234 | DimensionHandle size_per_class; |
235 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &size_per_class)); |
236 | |
237 | int64_t output_size; |
238 | bool pad_per_class; |
239 | TF_RETURN_IF_ERROR(c->GetAttr("pad_per_class" , &pad_per_class)); |
240 | if (!pad_per_class) { |
241 | output_size = c->Value(output_dim); |
242 | } else { |
243 | if (c->ValueKnown(size_per_class) && c->Value(size_per_class) <= 0) { |
244 | return errors::InvalidArgument( |
245 | "max_output_size_per_class must be > 0 " |
246 | "if pad_per_class is set to true " ); |
247 | } |
248 | output_size = std::min(c->Value(output_dim), |
249 | c->Value(size_per_class) * c->Value(class_dim)); |
250 | } |
251 | c->set_output(0, c->MakeShape({batch_dim, output_size, 4})); |
252 | c->set_output(1, c->MakeShape({batch_dim, output_size})); |
253 | c->set_output(2, c->MakeShape({batch_dim, output_size})); |
254 | c->set_output(3, c->Vector(batch_dim)); |
255 | return OkStatus(); |
256 | } |
257 | |
258 | } // namespace |
259 | |
260 | // -------------------------------------------------------------------------- |
// ResizeArea: takes a 4-D `images` batch and a 2-element `size` (height,
// width). The output dtype is always float, regardless of `T`.
// Shape inference (ResizeShapeFn) keeps the batch and channel dimensions.
REGISTER_OP("ResizeArea")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);
270 | |
271 | // -------------------------------------------------------------------------- |
// ResizeBicubic: takes a 4-D `images` batch and a 2-element `size` (height,
// width). The output dtype is always float, regardless of `T`.
// Shape inference (ResizeShapeFn) keeps the batch and channel dimensions.
REGISTER_OP("ResizeBicubic")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
282 | |
283 | // -------------------------------------------------------------------------- |
// ResizeBicubicGrad: the output takes both its dtype `T` and its shape from
// `original_image` (input 1). The shape of `grads` is not checked here.
REGISTER_OP("ResizeBicubicGrad")
    .Input("grads: float")
    .Input("original_image: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      // Output shape == shape of original_image.
      c->set_output(0, c->input(1));
      return OkStatus();
    });
295 | |
296 | // -------------------------------------------------------------------------- |
297 | REGISTER_OP("ResizeBilinear" ) |
298 | .Input("images: T" ) |
299 | .Input("size: int32" ) |
300 | .Output("resized_images: float" ) |
301 | .Attr( |
302 | "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, " |
303 | "float, double, bfloat16}" ) |
304 | .Attr("align_corners: bool = false" ) |
305 | .Attr("half_pixel_centers: bool = false" ) |
306 | .SetShapeFn(ResizeShapeFn); |
307 | |
308 | // -------------------------------------------------------------------------- |
// ScaleAndTranslate: output height/width come from `size`, as for the resize
// ops (ResizeShapeFn). The shape function only validates `images` (4-D) and
// `size` (length-2 vector); `scale` and `translation` are not checked there.
// The output dtype is always float.
REGISTER_OP("ScaleAndTranslate")
    .Input("images: T")
    .Input("size: int32")
    .Input("scale: float")
    .Input("translation: float")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double}")
    .Attr("kernel_type: string = 'lanczos3'")
    .Attr("antialias: bool = true")
    .SetShapeFn(ResizeShapeFn);
321 | |
322 | // -------------------------------------------------------------------------- |
323 | REGISTER_OP("QuantizedResizeBilinear" ) |
324 | .Input("images: T" ) |
325 | .Input("size: int32" ) |
326 | .Input("min: float" ) |
327 | .Input("max: float" ) |
328 | .Output("resized_images: T" ) |
329 | .Output("out_min: float" ) |
330 | .Output("out_max: float" ) |
331 | .Attr("T: {quint8, qint32, float}" ) |
332 | .Attr("align_corners: bool = false" ) |
333 | .Attr("half_pixel_centers: bool = false" ) |
334 | .SetShapeFn([](InferenceContext* c) { |
335 | TF_RETURN_IF_ERROR(ResizeShapeFn(c)); |
336 | ShapeHandle min_shape; |
337 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape)); |
338 | ShapeHandle max_shape; |
339 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape)); |
340 | c->set_output(1, c->MakeShape({})); |
341 | c->set_output(2, c->MakeShape({})); |
342 | return OkStatus(); |
343 | }); |
344 | |
345 | // -------------------------------------------------------------------------- |
// ResizeBilinearGrad: the output takes both its dtype `T` and its shape from
// `original_image` (input 1). The shape of `grads` is not checked here.
REGISTER_OP("ResizeBilinearGrad")
    .Input("grads: float")
    .Input("original_image: T")
    .Output("output: T")
    .Attr("T: {float, bfloat16, half, double}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      // Output shape == shape of original_image.
      c->set_output(0, c->input(1));
      return OkStatus();
    });
357 | |
358 | // -------------------------------------------------------------------------- |
// ScaleAndTranslateGrad: the output has the shape of `original_image`
// (input 1). No other input shapes are validated by the shape function.
REGISTER_OP("ScaleAndTranslateGrad")
    .Input("grads: T")
    .Input("original_image: T")
    .Input("scale: float")
    .Input("translation: float")
    .Output("output: T")
    .Attr("T: {float}")
    .Attr("kernel_type: string = 'lanczos3'")
    .Attr("antialias: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      // Output shape == shape of original_image.
      c->set_output(0, c->input(1));
      return OkStatus();
    });
372 | |
373 | // -------------------------------------------------------------------------- |
// ResizeNearestNeighbor: unlike the area/bilinear/bicubic resize ops, the
// output keeps the input dtype `T`. The shape comes from ResizeShapeFn
// (batch and channels preserved, height/width from `size`).
REGISTER_OP("ResizeNearestNeighbor")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: T")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float,"
        "double, bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
384 | |
385 | // -------------------------------------------------------------------------- |
386 | REGISTER_OP("ResizeNearestNeighborGrad" ) |
387 | .Input("grads: T" ) |
388 | .Input("size: int32" ) |
389 | .Output("output: T" ) |
390 | .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}" ) |
391 | .Attr("align_corners: bool = false" ) |
392 | .Attr("half_pixel_centers: bool = false" ) |
393 | .SetShapeFn([](InferenceContext* c) { |
394 | ShapeHandle input; |
395 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
396 | ShapeHandle unused; |
397 | DimensionHandle unused_dim; |
398 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
399 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim)); |
400 | const Tensor* size = c->input_tensor(1); |
401 | if (size == nullptr) { |
402 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input)); |
403 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input)); |
404 | } else { |
405 | auto size_vec = size->vec<int32>(); |
406 | TF_RETURN_IF_ERROR( |
407 | c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input)); |
408 | TF_RETURN_IF_ERROR( |
409 | c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input)); |
410 | } |
411 | c->set_output(0, input); |
412 | return OkStatus(); |
413 | }); |
414 | |
415 | // -------------------------------------------------------------------------- |
416 | REGISTER_OP("RandomCrop" ) |
417 | .Input("image: T" ) |
418 | .Input("size: int64" ) |
419 | .Output("output: T" ) |
420 | .Attr("T: {uint8, int8, int16, int32, int64, float, double}" ) |
421 | .Attr("seed: int = 0" ) |
422 | .Attr("seed2: int = 0" ) |
423 | .SetIsStateful() |
424 | .Deprecated(8, "Random crop is now pure Python" ) |
425 | .SetShapeFn([](InferenceContext* c) { |
426 | ShapeHandle image; |
427 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image)); |
428 | DimensionHandle channels = c->Dim(image, -1); |
429 | |
430 | ShapeHandle unused; |
431 | TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused)); |
432 | |
433 | const Tensor* size = c->input_tensor(1); |
434 | DimensionHandle h; |
435 | DimensionHandle w; |
436 | if (size == nullptr) { |
437 | h = c->UnknownDim(); |
438 | w = c->UnknownDim(); |
439 | } else { |
440 | auto size_vec = size->vec<int64_t>(); |
441 | h = c->MakeDim(size_vec(0)); |
442 | w = c->MakeDim(size_vec(1)); |
443 | } |
444 | c->set_output(0, c->MakeShape({h, w, channels})); |
445 | return OkStatus(); |
446 | }); |
447 | // TODO(shlens): Support variable rank in RandomCrop. |
448 | |
449 | // -------------------------------------------------------------------------- |
// DecodeImage: decodes a scalar `contents` string into an image of `dtype`.
// The shape function is DecodeImageV2ShapeFn:
// - with `expand_animations` true, the output shape is unknown;
// - otherwise it is 3-D [?, ?, channels].
REGISTER_OP("DecodeImage")
    .Input("contents: string")
    // Setting `channels` to 0 means using the inherent number of channels in
    // the image.
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
    .Output("image: dtype")
    .Attr("expand_animations: bool = true")
    .SetShapeFn(DecodeImageV2ShapeFn);
459 | |
460 | // -------------------------------------------------------------------------- |
// DecodeJpeg: scalar `contents` in, uint8 image out. DecodeImageShapeFn makes
// the output 3-D [?, ?, channels], with channels from the `channels` attr
// (0 = unknown).
REGISTER_OP("DecodeJpeg")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("ratio: int = 1")
    .Attr("fancy_upscaling: bool = true")
    .Attr("try_recover_truncated: bool = false")
    .Attr("acceptable_fraction: float = 1.0")
    .Attr("dct_method: string = ''")
    .Output("image: uint8")
    .SetShapeFn(DecodeImageShapeFn);
471 | |
472 | // -------------------------------------------------------------------------- |
473 | REGISTER_OP("DecodeAndCropJpeg" ) |
474 | .Input("contents: string" ) |
475 | .Input("crop_window: int32" ) |
476 | .Attr("channels: int = 0" ) |
477 | .Attr("ratio: int = 1" ) |
478 | .Attr("fancy_upscaling: bool = true" ) |
479 | .Attr("try_recover_truncated: bool = false" ) |
480 | .Attr("acceptable_fraction: float = 1.0" ) |
481 | .Attr("dct_method: string = ''" ) |
482 | .Output("image: uint8" ) |
483 | .SetShapeFn([](InferenceContext* c) { |
484 | ShapeHandle unused; |
485 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
486 | DimensionHandle channels_dim = c->UnknownDim(); |
487 | DimensionHandle h = c->UnknownDim(); |
488 | DimensionHandle w = c->UnknownDim(); |
489 | |
490 | int32_t channels; |
491 | TF_RETURN_IF_ERROR(c->GetAttr("channels" , &channels)); |
492 | if (channels != 0) { |
493 | if (channels < 0) { |
494 | return errors::InvalidArgument("channels must be non-negative, got " , |
495 | channels); |
496 | } |
497 | channels_dim = c->MakeDim(channels); |
498 | } |
499 | |
500 | DimensionHandle unused_dim; |
501 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
502 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim)); |
503 | |
504 | const Tensor* crop_window = c->input_tensor(1); |
505 | if (crop_window != nullptr) { |
506 | auto crop_window_vec = crop_window->vec<int32>(); |
507 | h = c->MakeDim(crop_window_vec(2)); |
508 | w = c->MakeDim(crop_window_vec(3)); |
509 | } |
510 | c->set_output(0, c->MakeShape({h, w, channels_dim})); |
511 | return OkStatus(); |
512 | }); |
513 | |
514 | // -------------------------------------------------------------------------- |
// EncodeJpeg: `image` must be 3-D; the encoded `contents` is a scalar string
// (EncodeImageShapeFn).
REGISTER_OP("EncodeJpeg")
    .Input("image: uint8")
    .Attr("format: {'', 'grayscale', 'rgb'} = ''")
    .Attr("quality: int = 95")
    .Attr("progressive: bool = false")
    .Attr("optimize_size: bool = false")
    .Attr("chroma_downsampling: bool = true")
    .Attr("density_unit: {'in', 'cm'} = 'in'")
    .Attr("x_density: int = 300")
    .Attr("y_density: int = 300")
    .Attr("xmp_metadata: string = ''")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
528 | |
529 | // -------------------------------------------------------------------------- |
530 | REGISTER_OP("EncodeJpegVariableQuality" ) |
531 | .Input("images: uint8" ) |
532 | .Input("quality: int32" ) |
533 | .Output("contents: string" ) |
534 | .SetShapeFn(EncodeImageShapeFn); |
535 | |
536 | // -------------------------------------------------------------------------- |
537 | REGISTER_OP("ExtractJpegShape" ) |
538 | .Input("contents: string" ) |
539 | .Output("image_shape: output_type" ) |
540 | .Attr("output_type: {int32, int64} = DT_INT32" ) |
541 | .SetShapeFn([](InferenceContext* c) { |
542 | ShapeHandle unused; |
543 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
544 | c->set_output(0, c->Vector(3)); |
545 | return OkStatus(); |
546 | }); |
547 | |
548 | // -------------------------------------------------------------------------- |
549 | REGISTER_OP("AdjustContrast" ) |
550 | .Input("images: T" ) |
551 | .Input("contrast_factor: float" ) |
552 | .Input("min_value: float" ) |
553 | .Input("max_value: float" ) |
554 | .Output("output: float" ) |
555 | .Attr("T: {uint8, int8, int16, int32, int64, float, double}" ) |
556 | .Deprecated(2, "Use AdjustContrastv2 instead" ) |
557 | .SetShapeFn([](InferenceContext* c) { |
558 | // The contrast_factor, min_value, max_value should be scalar only. |
559 | ShapeHandle unused; |
560 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
561 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
562 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
563 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
564 | }); |
565 | |
566 | // -------------------------------------------------------------------------- |
567 | REGISTER_OP("AdjustContrastv2" ) |
568 | .Input("images: T" ) |
569 | .Input("contrast_factor: float" ) |
570 | .Output("output: T" ) |
571 | .Attr("T: {half, float} = DT_FLOAT" ) |
572 | .SetShapeFn([](InferenceContext* c) { |
573 | // The contrast_factor should be scalar only. |
574 | ShapeHandle unused; |
575 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
576 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
577 | }); |
578 | |
579 | // -------------------------------------------------------------------------- |
580 | REGISTER_OP("AdjustHue" ) |
581 | .Input("images: T" ) |
582 | .Input("delta: float" ) |
583 | .Output("output: T" ) |
584 | .Attr("T: {half, float} = DT_FLOAT" ) |
585 | .SetShapeFn([](InferenceContext* c) { |
586 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
587 | }); |
588 | |
589 | // -------------------------------------------------------------------------- |
590 | REGISTER_OP("AdjustSaturation" ) |
591 | .Input("images: T" ) |
592 | .Input("scale: float" ) |
593 | .Output("output: T" ) |
594 | .Attr("T: {half, float} = DT_FLOAT" ) |
595 | .SetShapeFn([](InferenceContext* c) { |
596 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
597 | }); |
598 | |
599 | // -------------------------------------------------------------------------- |
// DecodePng: scalar `contents` in, `dtype` image out. DecodeImageShapeFn makes
// the output 3-D [?, ?, channels], with channels from the `channels` attr
// (0 = unknown).
REGISTER_OP("DecodePng")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16} = DT_UINT8")
    .Output("image: dtype")
    .SetShapeFn(DecodeImageShapeFn);
606 | |
607 | // -------------------------------------------------------------------------- |
// EncodePng: `image` must be 3-D; the encoded `contents` is a scalar string
// (EncodeImageShapeFn).
REGISTER_OP("EncodePng")
    .Attr("compression: int = -1")
    .Attr("T: {uint8, uint16} = DT_UINT8")
    .Input("image: T")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
614 | |
615 | // -------------------------------------------------------------------------- |
// DecodeBmp: scalar `contents` in, uint8 image out. DecodeImageShapeFn makes
// the output 3-D [?, ?, channels], with channels from the `channels` attr
// (0 = unknown).
REGISTER_OP("DecodeBmp")
    .Input("contents: string")
    .Output("image: uint8")
    .Attr("channels: int = 0")
    .SetShapeFn(DecodeImageShapeFn);
621 | |
622 | // -------------------------------------------------------------------------- |
623 | REGISTER_OP("DecodeGif" ) |
624 | .Input("contents: string" ) |
625 | .Output("image: uint8" ) |
626 | .SetShapeFn([](InferenceContext* c) { |
627 | ShapeHandle unused; |
628 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
629 | c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim, |
630 | InferenceContext::kUnknownDim, |
631 | InferenceContext::kUnknownDim, 3})); |
632 | return OkStatus(); |
633 | }); |
634 | |
635 | // -------------------------------------------------------------------------- |
// RGBToHSV: the output has the input's shape; the last dimension must be 3
// (ColorspaceShapeFn).
REGISTER_OP("RGBToHSV")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
641 | |
642 | // -------------------------------------------------------------------------- |
// HSVToRGB: the output has the input's shape; the last dimension must be 3
// (ColorspaceShapeFn).
REGISTER_OP("HSVToRGB")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
648 | |
649 | // -------------------------------------------------------------------------- |
650 | REGISTER_OP("DrawBoundingBoxes" ) |
651 | .Input("images: T" ) |
652 | .Input("boxes: float" ) |
653 | .Output("output: T" ) |
654 | .Attr("T: {float, half} = DT_FLOAT" ) |
655 | .SetShapeFn([](InferenceContext* c) { |
656 | // The rank of images should be 4. |
657 | ShapeHandle images; |
658 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images)); |
659 | // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA). |
660 | if (c->ValueKnown(c->Dim(images, 3))) { |
661 | int64_t depth = c->Value(c->Dim(images, 3)); |
662 | if (!(depth == 1 || depth == 3 || depth == 4)) { |
663 | return errors::InvalidArgument( |
664 | "Channel depth should be either 1 (GRY), " |
665 | "3 (RGB), or 4 (RGBA)" ); |
666 | } |
667 | } |
668 | |
669 | // The rank of boxes is 3: [batch, num_bounding_boxes, 4]. |
670 | ShapeHandle boxes; |
671 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes)); |
672 | // The last value of boxes shape is 4. |
673 | DimensionHandle unused; |
674 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused)); |
675 | |
676 | // The rank of the input image (rank = 4) has already been restricted |
677 | // above, and the output is of the same shape as the input. |
678 | return shape_inference::UnchangedShape(c); |
679 | }); |
680 | |
681 | // -------------------------------------------------------------------------- |
682 | REGISTER_OP("DrawBoundingBoxesV2" ) |
683 | .Input("images: T" ) |
684 | .Input("boxes: float" ) |
685 | .Input("colors: float" ) |
686 | .Output("output: T" ) |
687 | .Attr("T: {float, half} = DT_FLOAT" ) |
688 | .SetShapeFn([](InferenceContext* c) { |
689 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
690 | }); |
691 | |
692 | // -------------------------------------------------------------------------- |
693 | REGISTER_OP("SampleDistortedBoundingBox" ) |
694 | .Input("image_size: T" ) |
695 | .Input("bounding_boxes: float" ) |
696 | .Output("begin: T" ) |
697 | .Output("size: T" ) |
698 | .Output("bboxes: float" ) |
699 | .Attr("T: {uint8, int8, int16, int32, int64}" ) |
700 | .Attr("seed: int = 0" ) |
701 | .Attr("seed2: int = 0" ) |
702 | .Attr("min_object_covered: float = 0.1" ) |
703 | .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]" ) |
704 | .Attr("area_range: list(float) = [0.05, 1.0]" ) |
705 | .Attr("max_attempts: int = 100" ) |
706 | .Attr("use_image_if_no_bounding_boxes: bool = false" ) |
707 | .SetIsStateful() |
708 | .SetShapeFn([](InferenceContext* c) { |
709 | // Get inputs and validate ranks. |
710 | ShapeHandle image_size; |
711 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size)); |
712 | ShapeHandle bounding_boxes; |
713 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes)); |
714 | // image_size: 1-D with [height, width, channels] |
715 | // bounding_boxes: 3-D with shape [batch, N, 4] |
716 | DimensionHandle unused; |
717 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused)); |
718 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused)); |
719 | |
720 | c->set_output(0, c->Vector(3)); |
721 | c->set_output(1, c->Vector(3)); |
722 | c->set_output(2, c->MakeShape({1, 1, 4})); |
723 | return OkStatus(); |
724 | }); |
725 | |
726 | REGISTER_OP("SampleDistortedBoundingBoxV2" ) |
727 | .Input("image_size: T" ) |
728 | .Input("bounding_boxes: float" ) |
729 | .Input("min_object_covered: float" ) |
730 | .Output("begin: T" ) |
731 | .Output("size: T" ) |
732 | .Output("bboxes: float" ) |
733 | .Attr("T: {uint8, int8, int16, int32, int64}" ) |
734 | .Attr("seed: int = 0" ) |
735 | .Attr("seed2: int = 0" ) |
736 | .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]" ) |
737 | .Attr("area_range: list(float) = [0.05, 1.0]" ) |
738 | .Attr("max_attempts: int = 100" ) |
739 | .Attr("use_image_if_no_bounding_boxes: bool = false" ) |
740 | .SetIsStateful() |
741 | .SetShapeFn([](InferenceContext* c) { |
742 | // Get inputs and validate ranks. |
743 | ShapeHandle image_size; |
744 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size)); |
745 | ShapeHandle bounding_boxes; |
746 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes)); |
747 | ShapeHandle min_object_covered; |
748 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered)); |
749 | // image_size: 1-D with [height, width, channels] |
750 | // bounding_boxes: 3-D with shape [batch, N, 4] |
751 | DimensionHandle unused; |
752 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused)); |
753 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused)); |
754 | |
755 | c->set_output(0, c->Vector(3)); |
756 | c->set_output(1, c->Vector(3)); |
757 | c->set_output(2, c->MakeShape({1, 1, 4})); |
758 | return OkStatus(); |
759 | }); |
760 | |
761 | REGISTER_OP("StatelessSampleDistortedBoundingBox" ) |
762 | .Input("image_size: T" ) |
763 | .Input("bounding_boxes: float" ) |
764 | .Input("min_object_covered: float" ) |
765 | .Input("seed: Tseed" ) |
766 | .Output("begin: T" ) |
767 | .Output("size: T" ) |
768 | .Output("bboxes: float" ) |
769 | .Attr("T: {uint8, int8, int16, int32, int64}" ) |
770 | .Attr("Tseed: {int32, int64}" ) |
771 | .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]" ) |
772 | .Attr("area_range: list(float) = [0.05, 1.0]" ) |
773 | .Attr("max_attempts: int = 100" ) |
774 | .Attr("use_image_if_no_bounding_boxes: bool = false" ) |
775 | .SetShapeFn([](InferenceContext* c) { |
776 | // Get inputs and validate ranks. |
777 | ShapeHandle image_size; |
778 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size)); |
779 | ShapeHandle bounding_boxes; |
780 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes)); |
781 | ShapeHandle min_object_covered; |
782 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered)); |
783 | ShapeHandle seed; |
784 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &seed)); |
785 | // image_size: 1-D with [height, width, channels] |
786 | // bounding_boxes: 3-D with shape [batch, N, 4] |
787 | DimensionHandle unused; |
788 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused)); |
789 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused)); |
790 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); |
791 | |
792 | c->set_output(0, c->Vector(3)); |
793 | c->set_output(1, c->Vector(3)); |
794 | c->set_output(2, c->MakeShape({1, 1, 4})); |
795 | |
796 | return OkStatus(); |
797 | }); |
798 | |
799 | // -------------------------------------------------------------------------- |
800 | |
801 | // glimpse = extract_glimpse(input, size, offsets) extract the glimpse |
802 | // of size `size` centered at location `offsets` from the input tensor |
803 | // `input`. |
804 | // |
805 | // REQUIRES: input.dims() == 4 |
806 | // |
807 | REGISTER_OP("ExtractGlimpse" ) |
808 | .Input("input: float" ) |
809 | .Input("size: int32" ) |
810 | .Input("offsets: float" ) |
811 | .Output("glimpse: float" ) |
812 | .Attr("centered: bool = true" ) |
813 | .Attr("normalized: bool = true" ) |
814 | .Attr("uniform_noise: bool = true" ) |
815 | .Attr("noise: string = 'uniform'" ) |
816 | .SetShapeFn([](InferenceContext* c) { |
817 | ShapeHandle input; |
818 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
819 | ShapeHandle offsets; |
820 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets)); |
821 | |
822 | DimensionHandle batch_dim; |
823 | TF_RETURN_IF_ERROR( |
824 | c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim)); |
825 | DimensionHandle unused; |
826 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused)); |
827 | |
828 | bool uniform_noise = false; |
829 | TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise" , &uniform_noise)); |
830 | string noise; |
831 | TF_RETURN_IF_ERROR(c->GetAttr("noise" , &noise)); |
832 | if (uniform_noise && (!noise.empty() && noise != "uniform" )) { |
833 | return errors::InvalidArgument( |
834 | "The uniform_noise and noise should not be specified at the same " |
835 | "time" ); |
836 | } |
837 | |
838 | return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */, |
839 | c->Dim(input, 3)); |
840 | }); |
841 | |
842 | REGISTER_OP("ExtractGlimpseV2" ) |
843 | .Input("input: float" ) |
844 | .Input("size: int32" ) |
845 | .Input("offsets: float" ) |
846 | .Output("glimpse: float" ) |
847 | .Attr("centered: bool = true" ) |
848 | .Attr("normalized: bool = true" ) |
849 | .Attr("uniform_noise: bool = true" ) |
850 | .Attr("noise: string = 'uniform'" ) |
851 | .SetShapeFn([](InferenceContext* c) { |
852 | ShapeHandle input; |
853 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
854 | ShapeHandle offsets; |
855 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets)); |
856 | |
857 | DimensionHandle batch_dim; |
858 | TF_RETURN_IF_ERROR( |
859 | c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim)); |
860 | DimensionHandle unused; |
861 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused)); |
862 | |
863 | bool uniform_noise = false; |
864 | TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise" , &uniform_noise)); |
865 | string noise; |
866 | TF_RETURN_IF_ERROR(c->GetAttr("noise" , &noise)); |
867 | if (uniform_noise && (!noise.empty() && noise != "uniform" )) { |
868 | return errors::InvalidArgument( |
869 | "The uniform_noise and noise should not be specified at the same " |
870 | "time" ); |
871 | } |
872 | |
873 | return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */, |
874 | c->Dim(input, 3)); |
875 | }); |
876 | |
877 | // -------------------------------------------------------------------------- |
878 | |
879 | REGISTER_OP("CropAndResize" ) |
880 | .Input("image: T" ) |
881 | .Input("boxes: float" ) |
882 | .Input("box_ind: int32" ) |
883 | .Input("crop_size: int32" ) |
884 | .Output("crops: float" ) |
885 | .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}" ) |
886 | .Attr("method: {'bilinear', 'nearest'} = 'bilinear'" ) |
887 | .Attr("extrapolation_value: float = 0" ) |
888 | .SetShapeFn([](InferenceContext* c) { |
889 | // Get inputs and validate ranks. |
890 | ShapeHandle input; |
891 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
892 | ShapeHandle boxes; |
893 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes)); |
894 | ShapeHandle box_ind; |
895 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind)); |
896 | |
897 | // boxes[0] and box_ind[0] are both num_boxes. |
898 | DimensionHandle num_boxes_dim; |
899 | TF_RETURN_IF_ERROR( |
900 | c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim)); |
901 | |
902 | // boxes.dim(1) is 4. |
903 | DimensionHandle unused; |
904 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); |
905 | |
906 | return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */, |
907 | c->Dim(input, 3)); |
908 | }); |
909 | |
910 | REGISTER_OP("CropAndResizeGradImage" ) |
911 | .Input("grads: float" ) |
912 | .Input("boxes: float" ) |
913 | .Input("box_ind: int32" ) |
914 | .Input("image_size: int32" ) |
915 | .Output("output: T" ) |
916 | .Attr("T: {float, half, double}" ) |
917 | .Attr("method: {'bilinear', 'nearest'} = 'bilinear'" ) |
918 | .SetShapeFn([](InferenceContext* c) { |
919 | ShapeHandle out; |
920 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out)); |
921 | TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out)); |
922 | c->set_output(0, out); |
923 | return OkStatus(); |
924 | }); |
925 | |
926 | REGISTER_OP("CropAndResizeGradBoxes" ) |
927 | .Input("grads: float" ) |
928 | .Input("image: T" ) |
929 | .Input("boxes: float" ) |
930 | .Input("box_ind: int32" ) |
931 | .Output("output: float" ) |
932 | .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}" ) |
933 | .Attr("method: {'bilinear'} = 'bilinear'" ) |
934 | .SetShapeFn([](InferenceContext* c) { |
935 | c->set_output(0, c->input(2)); |
936 | return OkStatus(); |
937 | }); |
938 | |
939 | // -------------------------------------------------------------------------- |
940 | |
941 | REGISTER_OP("NonMaxSuppression" ) |
942 | .Input("boxes: float" ) |
943 | .Input("scores: float" ) |
944 | .Input("max_output_size: int32" ) |
945 | .Output("selected_indices: int32" ) |
946 | .Attr("iou_threshold: float = 0.5" ) |
947 | .SetShapeFn([](InferenceContext* c) { |
948 | // Get inputs and validate ranks. |
949 | ShapeHandle boxes; |
950 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); |
951 | ShapeHandle scores; |
952 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); |
953 | ShapeHandle max_output_size; |
954 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); |
955 | // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. |
956 | DimensionHandle unused; |
957 | // The boxes[0] and scores[0] are both num_boxes. |
958 | TF_RETURN_IF_ERROR( |
959 | c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); |
960 | // The boxes[1] is 4. |
961 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); |
962 | |
963 | c->set_output(0, c->Vector(c->UnknownDim())); |
964 | return OkStatus(); |
965 | }); |
966 | |
967 | REGISTER_OP("NonMaxSuppressionV2" ) |
968 | .Input("boxes: T" ) |
969 | .Input("scores: T" ) |
970 | .Input("max_output_size: int32" ) |
971 | .Input("iou_threshold: T_threshold" ) |
972 | .Output("selected_indices: int32" ) |
973 | .Attr("T: {half, float} = DT_FLOAT" ) |
974 | .Attr("T_threshold: {half, float} = DT_FLOAT" ) |
975 | .SetShapeFn([](InferenceContext* c) { |
976 | // Get inputs and validate ranks. |
977 | ShapeHandle boxes; |
978 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); |
979 | ShapeHandle scores; |
980 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); |
981 | ShapeHandle max_output_size; |
982 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); |
983 | ShapeHandle iou_threshold; |
984 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); |
985 | // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. |
986 | DimensionHandle unused; |
987 | // The boxes[0] and scores[0] are both num_boxes. |
988 | TF_RETURN_IF_ERROR( |
989 | c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); |
990 | // The boxes[1] is 4. |
991 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); |
992 | |
993 | c->set_output(0, c->Vector(c->UnknownDim())); |
994 | return OkStatus(); |
995 | }); |
996 | |
// Same inputs as NonMaxSuppressionV2 plus a `score_threshold` input of type
// T_threshold. Shape inference is shared with NonMaxSuppressionV4 through
// NMSShapeFn, which is declared outside this section of the file.
REGISTER_OP("NonMaxSuppressionV3" )
    .Input("boxes: T" )
    .Input("scores: T" )
    .Input("max_output_size: int32" )
    .Input("iou_threshold: T_threshold" )
    .Input("score_threshold: T_threshold" )
    .Output("selected_indices: int32" )
    .Attr("T: {half, float} = DT_FLOAT" )
    .Attr("T_threshold: {half, float} = DT_FLOAT" )
    .SetShapeFn(NMSShapeFn);
1007 | |
1008 | REGISTER_OP("NonMaxSuppressionV4" ) |
1009 | .Input("boxes: T" ) |
1010 | .Input("scores: T" ) |
1011 | .Input("max_output_size: int32" ) |
1012 | .Input("iou_threshold: T_threshold" ) |
1013 | .Input("score_threshold: T_threshold" ) |
1014 | .Output("selected_indices: int32" ) |
1015 | .Output("valid_outputs: int32" ) |
1016 | .Attr("T: {half, float} = DT_FLOAT" ) |
1017 | .Attr("T_threshold: {half, float} = DT_FLOAT" ) |
1018 | .Attr("pad_to_max_output_size: bool = false" ) |
1019 | .SetShapeFn([](InferenceContext* c) { |
1020 | TF_RETURN_IF_ERROR(NMSShapeFn(c)); |
1021 | |
1022 | bool pad_to_max; |
1023 | TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size" , &pad_to_max)); |
1024 | if (pad_to_max) { |
1025 | // If padded, overwrite the shape of the output to be static. |
1026 | DimensionHandle output_dim; |
1027 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim)); |
1028 | c->set_output(0, c->MakeShape({output_dim})); |
1029 | } |
1030 | c->set_output(1, c->MakeShape({})); |
1031 | return OkStatus(); |
1032 | }); |
1033 | |
1034 | REGISTER_OP("NonMaxSuppressionV5" ) |
1035 | .Input("boxes: T" ) |
1036 | .Input("scores: T" ) |
1037 | .Input("max_output_size: int32" ) |
1038 | .Input("iou_threshold: T" ) |
1039 | .Input("score_threshold: T" ) |
1040 | .Input("soft_nms_sigma: T" ) |
1041 | .Output("selected_indices: int32" ) |
1042 | .Output("selected_scores: T" ) |
1043 | .Output("valid_outputs: int32" ) |
1044 | .Attr("T: {half, float} = DT_FLOAT" ) |
1045 | .Attr("pad_to_max_output_size: bool = false" ) |
1046 | .SetShapeFn([](InferenceContext* c) { |
1047 | TF_RETURN_IF_ERROR(SoftNMSShapeFn(c)); |
1048 | |
1049 | bool pad_to_max; |
1050 | TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size" , &pad_to_max)); |
1051 | if (pad_to_max) { |
1052 | // If padded, overwrite the shape of the output to be static. |
1053 | DimensionHandle output_dim; |
1054 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim)); |
1055 | c->set_output(0, c->MakeShape({output_dim})); |
1056 | c->set_output(1, c->MakeShape({output_dim})); |
1057 | } |
1058 | |
1059 | c->set_output(2, c->MakeShape({})); |
1060 | return OkStatus(); |
1061 | }); |
1062 | |
1063 | REGISTER_OP("NonMaxSuppressionWithOverlaps" ) |
1064 | .Input("overlaps: float" ) |
1065 | .Input("scores: float" ) |
1066 | .Input("max_output_size: int32" ) |
1067 | .Input("overlap_threshold: float" ) |
1068 | .Input("score_threshold: float" ) |
1069 | .Output("selected_indices: int32" ) |
1070 | .SetShapeFn([](InferenceContext* c) { |
1071 | // Get inputs and validate ranks. |
1072 | ShapeHandle overlaps; |
1073 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps)); |
1074 | ShapeHandle scores; |
1075 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); |
1076 | ShapeHandle max_output_size; |
1077 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); |
1078 | ShapeHandle overlap_threshold; |
1079 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold)); |
1080 | ShapeHandle score_threshold; |
1081 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); |
1082 | // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. |
1083 | DimensionHandle unused; |
1084 | // The boxes[0] and scores[0] are both num_boxes. |
1085 | TF_RETURN_IF_ERROR( |
1086 | c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused)); |
1087 | // The boxes[1] is 4. |
1088 | TF_RETURN_IF_ERROR( |
1089 | c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused)); |
1090 | |
1091 | c->set_output(0, c->Vector(c->UnknownDim())); |
1092 | return OkStatus(); |
1093 | }); |
1094 | |
// NMS variant with four outputs: nmsed_boxes, nmsed_scores and nmsed_classes
// (float) plus valid_detections (int32). Takes both a per-class limit
// (max_output_size_per_class) and an overall limit (max_total_size).
// Attributes: pad_per_class (default false), clip_boxes (default true).
// Shape inference is delegated to CombinedNMSShapeFn, declared outside this
// section of the file.
// NOTE(review): how the attributes affect the output shapes is decided inside
// CombinedNMSShapeFn and is not visible here.
REGISTER_OP("CombinedNonMaxSuppression" )
    .Input("boxes: float" )
    .Input("scores: float" )
    .Input("max_output_size_per_class: int32" )
    .Input("max_total_size: int32" )
    .Input("iou_threshold: float" )
    .Input("score_threshold: float" )
    .Output("nmsed_boxes: float" )
    .Output("nmsed_scores: float" )
    .Output("nmsed_classes: float" )
    .Output("valid_detections: int32" )
    .Attr("pad_per_class: bool = false" )
    .Attr("clip_boxes: bool = true" )
    .SetShapeFn(CombinedNMSShapeFn);
1109 | |
1110 | REGISTER_OP("GenerateBoundingBoxProposals" ) |
1111 | .Input("scores: float" ) |
1112 | .Input("bbox_deltas: float" ) |
1113 | .Input("image_info: float" ) |
1114 | .Input("anchors: float" ) |
1115 | .Input("nms_threshold: float" ) |
1116 | .Input("pre_nms_topn: int32" ) |
1117 | .Input("min_size: float" ) |
1118 | .Output("rois: float" ) |
1119 | .Output("roi_probabilities: float" ) |
1120 | .Attr("post_nms_topn: int = 300" ) |
1121 | .SetShapeFn([](InferenceContext* c) -> Status { |
1122 | // make sure input tensors have are correct rank |
1123 | ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold, |
1124 | n_pre_nms, min_box_size; |
1125 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores)); //(N, H, W, A) |
1126 | TF_RETURN_IF_ERROR( |
1127 | c->WithRank(c->input(1), 4, &bounding_boxes)); //(N,H,W,A4) |
1128 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &images)); // (N,5) |
1129 | auto im_info = c->Dim(images, 1); |
1130 | TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info)); |
1131 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors)); // (A4) |
1132 | // check scalar tensors |
1133 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &nms_threshold)); |
1134 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &n_pre_nms)); |
1135 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &min_box_size)); |
1136 | |
1137 | // TODO(skama): verify that the inputs are compatible |
1138 | int post_nms_top_n; |
1139 | TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn" , &post_nms_top_n)); |
1140 | auto roi_shape = c->MakeShape( |
1141 | {c->Dim(scores, 0), post_nms_top_n, 4}); //(N,post_nms_top_n,4) |
1142 | auto prob_shape = c->MakeShape( |
1143 | {c->Dim(scores, 0), post_nms_top_n}); // (N,post_nms_top_n) |
1144 | c->set_output(0, roi_shape); |
1145 | c->set_output(1, prob_shape); |
1146 | return OkStatus(); |
1147 | }); |
1148 | |
1149 | // V3 op supports fill_value. |
1150 | // V2 op supports output_shape. |
1151 | // V1 op is in contrib. |
1152 | REGISTER_OP("ImageProjectiveTransformV2" ) |
1153 | .Input("images: dtype" ) |
1154 | .Input("transforms: float32" ) |
1155 | .Input("output_shape: int32" ) |
1156 | .Attr("dtype: {uint8, int32, int64, float16, float32, float64}" ) |
1157 | .Attr("interpolation: string" ) |
1158 | .Attr("fill_mode: string = 'CONSTANT'" ) |
1159 | .Output("transformed_images: dtype" ) |
1160 | .SetShapeFn([](InferenceContext* c) { |
1161 | ShapeHandle input; |
1162 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
1163 | return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */, |
1164 | c->Dim(input, 3)); |
1165 | }); |
1166 | |
1167 | REGISTER_OP("ImageProjectiveTransformV3" ) |
1168 | .Input("images: dtype" ) |
1169 | .Input("transforms: float32" ) |
1170 | .Input("output_shape: int32" ) |
1171 | .Input("fill_value: float32" ) |
1172 | .Attr("dtype: {uint8, int32, int64, float16, float32, float64}" ) |
1173 | .Attr("interpolation: string" ) |
1174 | .Attr("fill_mode: string = 'CONSTANT'" ) |
1175 | .Output("transformed_images: dtype" ) |
1176 | .SetShapeFn([](InferenceContext* c) { |
1177 | ShapeHandle input; |
1178 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
1179 | return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */, |
1180 | c->Dim(input, 3)); |
1181 | }); |
1182 | |
1183 | } // namespace tensorflow |
1184 | |