1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | #include "tensorflow/cc/framework/grad_op_registry.h" |
18 | #include "tensorflow/cc/framework/gradients.h" |
19 | #include "tensorflow/cc/ops/image_ops_internal.h" |
20 | #include "tensorflow/cc/ops/standard_ops.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace ops { |
24 | namespace { |
25 | |
26 | REGISTER_NO_GRADIENT_OP("NonMaxSuppression" ); |
27 | REGISTER_NO_GRADIENT_OP("NonMaxSuppressionV2" ); |
28 | REGISTER_NO_GRADIENT_OP("NonMaxSuppressionV3" ); |
29 | REGISTER_NO_GRADIENT_OP("NonMaxSuppressionV4" ); |
30 | REGISTER_NO_GRADIENT_OP("NonMaxSuppressionV5" ); |
31 | |
32 | Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, |
33 | const std::vector<Output>& grad_inputs, |
34 | std::vector<Output>* grad_outputs) { |
35 | bool align_corners; |
36 | TF_RETURN_IF_ERROR( |
37 | GetNodeAttr(op.node()->attrs(), "align_corners" , &align_corners)); |
38 | bool half_pixel_centers; |
39 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers" , |
40 | &half_pixel_centers)); |
41 | // The internal gradient implementation needs the shape of the input image. |
42 | // x_shape = shape(x)[1:3] |
43 | // = slice(shape(x), {1}, {3 - 1}) |
44 | auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); |
45 | grad_outputs->push_back(internal::ResizeNearestNeighborGrad( |
46 | scope, grad_inputs[0], x_shape, |
47 | internal::ResizeNearestNeighborGrad::AlignCorners(align_corners) |
48 | .HalfPixelCenters(half_pixel_centers))); |
49 | grad_outputs->push_back(NoGradient()); |
50 | return scope.status(); |
51 | } |
52 | REGISTER_GRADIENT_OP("ResizeNearestNeighbor" , ResizeNearestNeighborGradHelper); |
53 | |
54 | Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, |
55 | const std::vector<Output>& grad_inputs, |
56 | std::vector<Output>* grad_outputs) { |
57 | bool align_corners; |
58 | TF_RETURN_IF_ERROR( |
59 | GetNodeAttr(op.node()->attrs(), "align_corners" , &align_corners)); |
60 | bool half_pixel_centers; |
61 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers" , |
62 | &half_pixel_centers)); |
63 | grad_outputs->push_back(internal::ResizeBilinearGrad( |
64 | scope, grad_inputs[0], op.input(0), |
65 | internal::ResizeBilinearGrad::AlignCorners(align_corners) |
66 | .HalfPixelCenters(half_pixel_centers))); |
67 | grad_outputs->push_back(NoGradient()); |
68 | return scope.status(); |
69 | } |
70 | REGISTER_GRADIENT_OP("ResizeBilinear" , ResizeBilinearGradHelper); |
71 | |
72 | Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, |
73 | const std::vector<Output>& grad_inputs, |
74 | std::vector<Output>* grad_outputs) { |
75 | bool align_corners; |
76 | TF_RETURN_IF_ERROR( |
77 | GetNodeAttr(op.node()->attrs(), "align_corners" , &align_corners)); |
78 | bool half_pixel_centers; |
79 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers" , |
80 | &half_pixel_centers)); |
81 | |
82 | grad_outputs->push_back(internal::ResizeBicubicGrad( |
83 | scope, grad_inputs[0], op.input(0), |
84 | internal::ResizeBicubicGrad::AlignCorners(align_corners) |
85 | .HalfPixelCenters(half_pixel_centers))); |
86 | grad_outputs->push_back(NoGradient()); |
87 | return scope.status(); |
88 | } |
89 | REGISTER_GRADIENT_OP("ResizeBicubic" , ResizeBicubicGradHelper); |
90 | |
91 | Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, |
92 | const std::vector<Output>& grad_inputs, |
93 | std::vector<Output>* grad_outputs) { |
94 | string kernel_type; |
95 | TF_RETURN_IF_ERROR( |
96 | GetNodeAttr(op.node()->attrs(), "kernel_type" , &kernel_type)); |
97 | bool antialias; |
98 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias" , &antialias)); |
99 | grad_outputs->push_back(internal::ScaleAndTranslateGrad( |
100 | scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), |
101 | internal::ScaleAndTranslateGrad::KernelType(kernel_type) |
102 | .Antialias(antialias))); |
103 | |
104 | grad_outputs->push_back(NoGradient()); |
105 | grad_outputs->push_back(NoGradient()); |
106 | grad_outputs->push_back(NoGradient()); |
107 | return scope.status(); |
108 | } |
109 | |
110 | REGISTER_GRADIENT_OP("ScaleAndTranslate" , ScaleAndTranslateGradHelper); |
111 | |
112 | Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, |
113 | const std::vector<Output>& grad_inputs, |
114 | std::vector<Output>* grad_outputs) { |
115 | DataType input_type; |
116 | string method; |
117 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method" , &method)); |
118 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T" , &input_type)); |
119 | auto image_shape = Shape(scope, op.input(0)); |
120 | grad_outputs->push_back(CropAndResizeGradImage( |
121 | scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type, |
122 | CropAndResizeGradImage::Method(method))); |
123 | grad_outputs->push_back(CropAndResizeGradBoxes( |
124 | scope, grad_inputs[0], op.input(0), op.input(1), op.input(2))); |
125 | grad_outputs->push_back(NoGradient()); |
126 | grad_outputs->push_back(NoGradient()); |
127 | return scope.status(); |
128 | } |
129 | |
130 | REGISTER_GRADIENT_OP("CropAndResize" , CropAndResizeGradHelper); |
131 | } // anonymous namespace |
132 | } // namespace ops |
133 | } // namespace tensorflow |
134 | |