1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/function.h" |
17 | #include "tensorflow/core/lib/core/errors.h" |
18 | #include "tensorflow/core/util/padding.h" |
19 | #include "tensorflow/core/util/tensor_format.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | typedef FunctionDefHelper FDH; |
24 | |
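// Each gradient below is expressed as a FunctionDef built from existing ops;
// REGISTER_OP_GRADIENT associates it with the corresponding forward op so the
// symbolic differentiator can expand it when a graph is differentiated.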
25 | Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { |
26 | // clang-format off |
27 | *g = FDH::Define( |
28 | "SoftmaxGrad" , |
29 | // Arg defs |
30 | {"x: T" , "grad_softmax: T" }, |
31 | // Ret val defs |
32 | {"grad_x: T" }, |
33 | // Attr defs |
34 | {{"T: {float, double, bfloat16}" }}, |
35 | // Nodes |
36 | // Based on _SoftmaxGrad in nn_grad.py. |
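      // grad_x = (grad_softmax - sum(grad_softmax * softmax(x), axis=-1,
      //                              keepdims=True)) * softmax(x)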
37 | { |
38 | {{"softmax" }, "Softmax" , {"x" }, {{"T" , "$T" }}}, |
39 | {{"n0" }, "Mul" , {"grad_softmax" , "softmax" }, {{"T" , "$T" }}}, |
40 | FDH::Const<int32>("indices" , {-1}), |
41 | {{"n1" }, "Sum" , {"n0" , "indices" }, {{"keep_dims" , true}, {"T" , "$T" }}}, |
42 | {{"n2" }, "Sub" , {"grad_softmax" , "n1" }, {{"T" , "$T" }}}, |
43 | {{"grad_x" }, "Mul" , {"n2" , "softmax" }, {{"T" , "$T" }}} |
44 | }); |
45 | // clang-format on |
46 | return OkStatus(); |
47 | } |
48 | REGISTER_OP_GRADIENT("Softmax" , SoftmaxGrad); |
49 | |
50 | Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { |
51 | // clang-format off |
52 | *g = FDH::Define( |
53 | "LogSoftmaxGrad" , |
54 | // Arg defs |
55 | {"x: T" , "grad_logsoftmax: T" }, |
56 | // Ret val defs |
57 | {"grad_x: T" }, |
58 | // Attr defs |
59 | {{"T: {float, double}" }}, |
60 | // Nodes |
61 | // Based on _LogSoftmaxGrad in nn_grad.py. |
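      // grad_x = grad_logsoftmax
      //          - sum(grad_logsoftmax, axis=-1, keepdims=True) * softmax(x)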
62 | { |
63 | {{"softmax" }, "Softmax" , {"x" }, {{"T" , "$T" }}}, |
64 | FDH::Const<int32>("indices" , {-1}), |
65 | {{"n0" }, "Sum" , {"grad_logsoftmax" , "indices" }, |
66 | {{"keep_dims" , true}, {"T" , "$T" }}}, |
67 | {{"n1" }, "Mul" , {"n0" , "softmax" }, {{"T" , "$T" }}}, |
68 | {{"grad_x" }, "Sub" , {"grad_logsoftmax" , "n1" }, {{"T" , "$T" }}} |
69 | }); |
70 | // clang-format on |
71 | return OkStatus(); |
72 | } |
73 | REGISTER_OP_GRADIENT("LogSoftmax" , LogSoftmaxGrad); |
74 | |
75 | Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { |
76 | // clang-format off |
77 | *g = FDH::Define( |
78 | // Arg defs |
79 | {"x: T" , "dy: T" }, |
80 | // Ret val defs |
81 | {"dx: T" }, |
82 | // Attr defs |
83 | {{"T: {float, double}" }}, |
84 | // Nodes |
85 | { |
86 | {{"dx" }, "ReluGrad" , {"dy" , "x" }, {{"T" , "$T" }}} |
87 | }); |
88 | // clang-format on |
89 | return OkStatus(); |
90 | } |
91 | REGISTER_OP_GRADIENT("Relu" , ReluGrad); |
92 | |
93 | Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) { |
94 | // clang-format off |
95 | *g = FDH::Define( |
96 | // Arg defs |
97 | {"x: T" , "dy: T" }, |
98 | // Ret val defs |
99 | {"dx: T" }, |
100 | // Attr defs |
101 | {{"T: {float, double}" }}, |
102 | // Nodes |
103 | { |
104 | {{"dx" }, "Relu6Grad" , {"dy" , "x" }, {{"T" , "$T" }}} |
105 | }); |
106 | // clang-format on |
107 | return OkStatus(); |
108 | } |
109 | REGISTER_OP_GRADIENT("Relu6" , Relu6Grad); |
110 | |
111 | Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { |
112 | // clang-format off |
113 | *g = FDH::Define( |
114 | // Arg defs |
115 | {"features: T" , "labels: T" , "dcost_dloss: T" , "donotcare: T" }, |
116 | // Ret val defs |
117 | {"dcost_dfeatures: T" , "dcost_dlabels: T" }, |
118 | // Attr defs |
119 | {{"T: {float, double}" }}, |
120 | // Nodes |
121 | { |
122 | // _, dloss_dfeatures = CrossEntropy(features, labels) |
123 | {{"donotcare_loss" , "dloss_dfeatures" }, "CrossEntropy" , |
124 | {"features" , "labels" }, {{"T" , "$T" }}}, |
125 | // dcost_dloss is of shape [batch_size]. |
126 | // dcost_dloss_mat is of shape [batch_size, 1]. |
127 | FDH::Const("neg1" , -1), |
128 | {{"dcost_dloss_mat" }, "ExpandDims" , {"dcost_dloss" , "neg1" }, |
129 | {{"T" , "$T" }}}, |
130 | // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures |
131 | {{"dcost_dfeatures" }, "Mul" , {"dcost_dloss_mat" , "dloss_dfeatures" }, |
132 | {{"T" , "$T" }}}, |
133 | {{"dcost_dlabels" }, "ZerosLike" , {"labels" }, {{"T" , "$T" }}}, |
134 | }); |
135 | // clang-format on |
136 | return OkStatus(); |
137 | } |
138 | REGISTER_OP_GRADIENT("CrossEntropy" , CrossEntropyGrad); |
139 | |
140 | Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) { |
141 | // clang-format off |
142 | *g = FDH::Define( |
143 | // Arg defs |
144 | {"input: T" , "filter: T" , "grad: T" }, |
145 | // Ret val defs |
146 | {"input_grad: T" , "filter_grad: T" }, |
147 | // Attr defs |
148 | {"T: {float, double}" , |
149 | "strides: list(int)" , |
150 | "use_cudnn_on_gpu: bool = true" , |
151 | GetPaddingAttrString(), |
152 | GetConvnetDataFormatAttrString()}, |
153 | // Nodes |
154 | { |
155 | {{"i_shape" }, "Shape" , {"input" }, {{"T" , "$T" }}}, |
156 | {{"input_grad" }, "Conv2DBackpropInput" , {"i_shape" , "filter" , "grad" }, |
157 | /*Attrs=*/{{"T" , "$T" }, |
158 | {"strides" , "$strides" }, |
159 | {"padding" , "$padding" }, |
160 | {"data_format" , "$data_format" }, |
161 | {"use_cudnn_on_gpu" , "$use_cudnn_on_gpu" }}}, |
162 | |
163 | {{"f_shape" }, "Shape" , {"filter" }, {{"T" , "$T" }}}, |
164 | {{"filter_grad" }, "Conv2DBackpropFilter" , {"input" , "f_shape" , "grad" }, |
165 | /*Attrs=*/{{"T" , "$T" }, |
166 | {"strides" , "$strides" }, |
167 | {"padding" , "$padding" }, |
168 | {"data_format" , "$data_format" }, |
169 | {"use_cudnn_on_gpu" , "$use_cudnn_on_gpu" }}}, |
170 | }); |
171 | // clang-format on |
172 | return OkStatus(); |
173 | } |
174 | REGISTER_OP_GRADIENT("Conv2D" , Conv2DGrad); |
175 | |
176 | Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { |
177 | // clang-format off |
178 | *g = FDH::Define( |
179 | // Arg defs |
180 | {"input: T" , "grad: T" }, |
181 | // Ret val defs |
182 | {"output: T" }, |
183 | // Attr defs |
184 | {"T: {float, half} = DT_FLOAT" , |
185 | "ksize: list(int) >= 4" , |
186 | "strides: list(int) >= 4" , |
187 | GetPaddingAttrString()}, |
188 | // Nodes |
189 | { |
190 | // Invoke MaxPool again to recompute the outputs (removed by CSE?). |
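        // MaxPoolGrad routes each incoming gradient value to the input
        // position that produced the corresponding forward maximum.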
191 | {{"maxpool" }, "MaxPool" , {"input" }, |
192 | /*Attrs=*/{{"T" , "$T" }, |
193 | {"ksize" , "$ksize" }, |
194 | {"strides" , "$strides" }, |
195 | {"padding" , "$padding" }}}, |
196 | {{"output" }, "MaxPoolGrad" , {"input" , "maxpool" , "grad" }, |
197 | /*Attrs=*/{{"T" , "$T" }, |
198 | {"ksize" , "$ksize" }, |
199 | {"strides" , "$strides" }, |
200 | {"padding" , "$padding" }}} |
201 | }); |
202 | // clang-format on |
203 | return OkStatus(); |
204 | } |
205 | REGISTER_OP_GRADIENT("MaxPool" , MaxPoolGrad); |
206 | |
207 | Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) { |
208 | // clang-format off |
209 | *g = FDH::Define( |
210 | // Arg defs |
211 | {"input: T" , "grad: T" }, |
212 | // Ret val defs |
213 | {"output: T" }, |
214 | // Attr defs |
215 | {"T: {float, half} = DT_FLOAT" , |
216 | "ksize: list(int) >= 4" , |
217 | "strides: list(int) >= 4" , |
218 | GetPaddingAttrString()}, |
219 | // Nodes |
220 | { |
221 | {{"i_shape" }, "Shape" , {"input" }, {{"T" , "$T" }}}, |
222 | {{"output" }, "AvgPoolGrad" , {"i_shape" , "grad" }, |
223 | /*Attrs=*/{{"T" , "$T" }, |
224 | {"ksize" , "$ksize" }, |
225 | {"strides" , "$strides" }, |
226 | {"padding" , "$padding" }}} |
227 | }); |
228 | // clang-format on |
229 | return OkStatus(); |
230 | } |
231 | REGISTER_OP_GRADIENT("AvgPool" , AvgPoolGrad); |
232 | |
233 | Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) { |
234 | // clang-format off |
235 | *g = FDH::Define( |
236 | // Arg defs |
237 | {"input: T" , "grad: T" }, |
238 | // Ret val defs |
239 | {"output: T" }, |
240 | // Attr defs |
241 | {"T: {float, half} = DT_FLOAT" , |
242 | "ksize: list(int) >= 4" , |
243 | "strides: list(int) >= 4" , |
244 | GetPaddingAttrString()}, |
245 | // Nodes |
246 | { |
247 | // Invoke MaxPool again to recompute the outputs (removed by CSE?). |
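        // MaxPoolGradGrad forwards the incoming (second-order) gradient at
        // the positions selected by the forward MaxPool.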
248 | {{"maxpool" }, "MaxPool" , {"input" }, |
249 | /*Attrs=*/{{"T" , "$T" }, |
250 | {"ksize" , "$ksize" }, |
251 | {"strides" , "$strides" }, |
252 | {"padding" , "$padding" }}}, |
253 | {{"output" }, "MaxPoolGradGrad" , {"input" , "maxpool" , "grad" }, |
254 | /*Attrs=*/{{"T" , "$T" }, |
255 | {"ksize" , "$ksize" }, |
256 | {"strides" , "$strides" }, |
257 | {"padding" , "$padding" }}} |
258 | }); |
259 | // clang-format on |
260 | return OkStatus(); |
261 | } |
262 | REGISTER_OP_GRADIENT("MaxPoolGrad" , MaxPoolGradGrad); |
263 | |
264 | Status BiasAddGrad(const AttrSlice& attrs, FunctionDef* g) { |
265 | // clang-format off |
266 | *g = FDH::Define( |
267 | // Arg defs |
268 | {"input: T" , "bias: T" , "grad: T" }, |
269 | // Ret val defs |
270 | {"grad: T" , "bias_grad: T" }, |
271 | // Attr defs |
272 | {{"T: {float, double}" }, |
273 | GetConvnetDataFormatAttrString()}, |
274 | // Nodes |
275 | { |
276 | {{"bias_grad" }, "BiasAddGrad" , {"grad" }, |
277 | /*Attrs=*/{{"T" , "$T" }, |
278 | {"data_format" , "$data_format" }}} |
279 | }); |
280 | // clang-format on |
281 | return OkStatus(); |
282 | } |
283 | REGISTER_OP_GRADIENT("BiasAdd" , BiasAddGrad); |
284 | |
285 | } // end namespace tensorflow |
286 | |