1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/ops/nn_ops.h" |
17 | #include "tensorflow/cc/ops/nn_ops_internal.h" |
18 | #include "tensorflow/cc/ops/standard_ops.h" |
19 | |
20 | #include "tensorflow/cc/framework/grad_op_registry.h" |
21 | #include "tensorflow/cc/framework/gradients.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace ops { |
25 | namespace { |
26 | |
27 | Status SoftmaxGrad(const Scope& scope, const Operation& op, |
28 | const std::vector<Output>& grad_inputs, |
29 | std::vector<Output>* grad_outputs) { |
30 | // Softmax gradient function. |
31 | // p = softmax(x) maps from [batch, n] to [batch, m] |
32 | // dp/dx = [dp0/dx0 ... dp0/dxn-1 ] |
33 | // [ ... ... ] |
34 | // [dpm-1/dx0 ... dpm-1/dxn-1] |
35 | // dL/dx = dp/dx * dL/dy |
36 | // |
37 | // Using alternative formula: |
38 | // dL/dx = dL/dy * y - sum(dL/dy * y) * y |
39 | // = (dL/dy - sum(dL/dy * y)) * y |
40 | auto y = op.output(0); |
41 | auto dyy = Mul(scope, grad_inputs[0], y); |
42 | auto sum = Sum(scope, dyy, /*axis=*/-1, Sum::KeepDims(true)); |
43 | auto sub = Sub(scope, grad_inputs[0], sum); |
44 | auto dx = Mul(scope, sub, y); |
45 | grad_outputs->push_back(dx); |
46 | return scope.status(); |
47 | } |
48 | REGISTER_GRADIENT_OP("Softmax" , SoftmaxGrad); |
49 | |
50 | bool IsZero(const Scope& scope, const Output& grad) { |
51 | string op_type_name = grad.op().node()->type_string(); |
52 | if (op_type_name == "ZerosLike" || op_type_name == "Zeros" ) { |
53 | return true; |
54 | } |
55 | // The Operation we were provided is not named something obvious so |
56 | // we need to actually look at its contents. |
57 | // The original python code did this by calling a utility function called |
58 | // tensor_util.constant_value. |
59 | // There is no C++ equivalent to tensor_util.constant_value so we do nothing |
60 | // for the moment. |
61 | return false; |
62 | } |
63 | |
64 | // Multiply after broadcasting vec to match dimensions of mat. |
65 | // Args: |
66 | // vec: A 1-D tensor of dimension [D0] |
67 | // mat: A 2-D tensor of dimension [D0, D1] |
68 | // |
69 | // Returns: |
70 | // A tensor of dimension [D0, D1], the result for vec * mat. |
71 | Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) { |
72 | auto reshaped = ExpandDims(scope, vec, -1); |
73 | return Multiply(scope, reshaped, mat); |
74 | } |
75 | |
76 | Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope, |
77 | const Operation& op, |
78 | const std::vector<Output>& grad_inputs, |
79 | std::vector<Output>* grad_outputs) { |
80 | // Softmax gradient with cross entropy logits function. |
81 | // We multiply the backprop for cost with the gradients - op.output[1]. |
82 | // There is no gradient for labels. |
83 | |
84 | // The outputs of the network are at input index 0. |
85 | auto logits = op.input(0); |
86 | // The "truth" labels are at index 1. |
87 | auto softmax_grad = op.output(1); |
88 | |
89 | // The loss is the output at index 0, and backprop is the output at index 1. |
90 | auto grad_loss = grad_inputs[0]; |
91 | auto grad_grad = grad_inputs[1]; |
92 | |
93 | auto grad = BroadcastMul(scope, grad_loss, softmax_grad); |
94 | if (!IsZero(scope, grad_grad)) { |
95 | std::vector<int> axis; |
96 | auto logits_softmax = Softmax(scope, logits); |
97 | |
98 | auto grad_grad_expand = ExpandDims(scope, grad_grad, 1); |
99 | auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2); |
100 | auto matmul_result = |
101 | BatchMatMul(scope, grad_grad_expand, logits_softmax_expand); |
102 | axis.push_back(1); |
103 | auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis)); |
104 | auto subtraction_result = Subtract(scope, grad_grad, squeeze_result); |
105 | auto multiply_result = Multiply(scope, subtraction_result, logits_softmax); |
106 | grad = Add(scope, grad, multiply_result); |
107 | } |
108 | auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f); |
109 | grad_outputs->push_back(grad); |
110 | grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax)); |
111 | return scope.status(); |
112 | } |
113 | REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits" , |
114 | SoftmaxCrossEntropyWithLogitsGrad); |
115 | |
116 | Status LogSoftmaxGrad(const Scope& scope, const Operation& op, |
117 | const std::vector<Output>& grad_inputs, |
118 | std::vector<Output>* grad_outputs) { |
119 | auto softmax = Exp(scope, op.output(0)); |
120 | auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); |
121 | auto mul = Mul(scope, sum, softmax); |
122 | auto dx = Sub(scope, grad_inputs[0], mul); |
123 | grad_outputs->push_back(dx); |
124 | return scope.status(); |
125 | } |
126 | REGISTER_GRADIENT_OP("LogSoftmax" , LogSoftmaxGrad); |
127 | |
128 | Status ReluGradHelper(const Scope& scope, const Operation& op, |
129 | const std::vector<Output>& grad_inputs, |
130 | std::vector<Output>* grad_outputs) { |
131 | auto dx = internal::ReluGrad(scope, grad_inputs[0], op.input(0)); |
132 | grad_outputs->push_back(dx); |
133 | return scope.status(); |
134 | } |
135 | REGISTER_GRADIENT_OP("Relu" , ReluGradHelper); |
136 | |
137 | Status Relu6GradHelper(const Scope& scope, const Operation& op, |
138 | const std::vector<Output>& grad_inputs, |
139 | std::vector<Output>* grad_outputs) { |
140 | auto dx = internal::Relu6Grad(scope, grad_inputs[0], op.input(0)); |
141 | grad_outputs->push_back(dx); |
142 | return scope.status(); |
143 | } |
144 | REGISTER_GRADIENT_OP("Relu6" , Relu6GradHelper); |
145 | |
146 | Status LeakyReluGradHelper(const Scope& scope, const Operation& op, |
147 | const std::vector<Output>& grad_inputs, |
148 | std::vector<Output>* grad_outputs) { |
149 | float alpha; |
150 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha" , &alpha)); |
151 | internal::LeakyReluGrad::Attrs attrs; |
152 | auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), |
153 | attrs.Alpha(alpha)); |
154 | grad_outputs->push_back(dx); |
155 | return scope.status(); |
156 | } |
157 | REGISTER_GRADIENT_OP("LeakyRelu" , LeakyReluGradHelper); |
158 | |
159 | Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, |
160 | const std::vector<Output>& grad_inputs, |
161 | std::vector<Output>* grad_outputs) { |
162 | float alpha; |
163 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha" , &alpha)); |
164 | internal::LeakyReluGrad::Attrs attrs; |
165 | auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), |
166 | attrs.Alpha(alpha)); |
167 | grad_outputs->push_back(dx); |
168 | grad_outputs->push_back(NoGradient()); |
169 | return scope.status(); |
170 | } |
171 | REGISTER_GRADIENT_OP("LeakyReluGrad" , LeakyReluGradGradHelper); |
172 | |
173 | Status EluGradHelper(const Scope& scope, const Operation& op, |
174 | const std::vector<Output>& grad_inputs, |
175 | std::vector<Output>* grad_outputs) { |
176 | auto dx = internal::EluGrad(scope, grad_inputs[0], op.output(0)); |
177 | grad_outputs->push_back(dx); |
178 | return scope.status(); |
179 | } |
180 | REGISTER_GRADIENT_OP("Elu" , EluGradHelper); |
181 | |
182 | Status SeluGradHelper(const Scope& scope, const Operation& op, |
183 | const std::vector<Output>& grad_inputs, |
184 | std::vector<Output>* grad_outputs) { |
185 | auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0)); |
186 | grad_outputs->push_back(dx); |
187 | return scope.status(); |
188 | } |
189 | REGISTER_GRADIENT_OP("Selu" , SeluGradHelper); |
190 | |
191 | Status L2LossGrad(const Scope& scope, const Operation& op, |
192 | const std::vector<Output>& grad_inputs, |
193 | std::vector<Output>* grad_outputs) { |
194 | grad_outputs->push_back(Mul(scope, op.input(0), grad_inputs[0])); |
195 | return scope.status(); |
196 | } |
197 | REGISTER_GRADIENT_OP("L2Loss" , L2LossGrad); |
198 | |
199 | Status BiasAddGradHelper(const Scope& scope, const Operation& op, |
200 | const std::vector<Output>& grad_inputs, |
201 | std::vector<Output>* grad_outputs) { |
202 | string data_format; |
203 | TF_RETURN_IF_ERROR( |
204 | GetNodeAttr(op.output(0).node()->attrs(), "data_format" , &data_format)); |
205 | auto dx_1 = |
206 | BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format)); |
207 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
208 | grad_outputs->push_back(dx_1); |
209 | return scope.status(); |
210 | } |
211 | REGISTER_GRADIENT_OP("BiasAdd" , BiasAddGradHelper); |
212 | |
213 | Status Conv2DGrad(const Scope& scope, const Operation& op, |
214 | const std::vector<Output>& grad_inputs, |
215 | std::vector<Output>* grad_outputs) { |
216 | string data_format; |
217 | string padding; |
218 | std::vector<int32> strides; |
219 | bool use_cudnn_on_gpu; |
220 | auto attrs = op.output(0).node()->attrs(); |
221 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
222 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
223 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides" , &strides)); |
224 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu" , &use_cudnn_on_gpu)); |
225 | auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1), |
226 | grad_inputs[0], strides, padding, |
227 | Conv2DBackpropInput::DataFormat(data_format) |
228 | .UseCudnnOnGpu(use_cudnn_on_gpu)); |
229 | grad_outputs->push_back(dx_1); |
230 | auto dx_2 = |
231 | Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)), |
232 | grad_inputs[0], strides, padding, |
233 | Conv2DBackpropFilter::DataFormat(data_format) |
234 | .UseCudnnOnGpu(use_cudnn_on_gpu)); |
235 | grad_outputs->push_back(dx_2); |
236 | return scope.status(); |
237 | } |
238 | REGISTER_GRADIENT_OP("Conv2D" , Conv2DGrad); |
239 | |
240 | Status MaxPoolGradHelper(const Scope& scope, const Operation& op, |
241 | const std::vector<Output>& grad_inputs, |
242 | std::vector<Output>* grad_outputs) { |
243 | string data_format; |
244 | string padding; |
245 | std::vector<int32> strides; |
246 | std::vector<int32> ksize; |
247 | auto attrs = op.output(0).node()->attrs(); |
248 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
249 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize" , &ksize)); |
250 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
251 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides" , &strides)); |
252 | auto dx = internal::MaxPoolGrad( |
253 | scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding, |
254 | internal::MaxPoolGrad::DataFormat(data_format)); |
255 | grad_outputs->push_back(dx); |
256 | return scope.status(); |
257 | } |
258 | REGISTER_GRADIENT_OP("MaxPool" , MaxPoolGradHelper); |
259 | |
260 | Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, |
261 | const std::vector<Output>& grad_inputs, |
262 | std::vector<Output>* grad_outputs) { |
263 | string data_format; |
264 | string padding; |
265 | auto attrs = op.output(0).node()->attrs(); |
266 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
267 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
268 | auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0], |
269 | op.input(1), op.input(2), padding, |
270 | MaxPoolGradV2::DataFormat(data_format)); |
271 | grad_outputs->push_back(dx); |
272 | grad_outputs->push_back(NoGradient()); |
273 | grad_outputs->push_back(NoGradient()); |
274 | return scope.status(); |
275 | } |
276 | REGISTER_GRADIENT_OP("MaxPoolV2" , MaxPoolGradV2Helper); |
277 | |
278 | Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, |
279 | const std::vector<Output>& grad_inputs, |
280 | std::vector<Output>* grad_outputs) { |
281 | std::vector<int32> ksize; |
282 | std::vector<int32> strides; |
283 | string padding; |
284 | string data_format; |
285 | auto attrs = op.output(0).node()->attrs(); |
286 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize" , &ksize)); |
287 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides" , &strides)); |
288 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
289 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
290 | MaxPool3DGrad::Attrs grad_attrs; |
291 | auto dx = |
292 | MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize, |
293 | strides, padding, grad_attrs.DataFormat(data_format)); |
294 | grad_outputs->push_back(dx); |
295 | return scope.status(); |
296 | } |
297 | REGISTER_GRADIENT_OP("MaxPool3D" , MaxPool3DGradHelper); |
298 | |
299 | Status AvgPoolGradHelper(const Scope& scope, const Operation& op, |
300 | const std::vector<Output>& grad_inputs, |
301 | std::vector<Output>* grad_outputs) { |
302 | std::vector<int32> ksize; |
303 | std::vector<int32> strides; |
304 | string padding; |
305 | string data_format; |
306 | auto attrs = op.output(0).node()->attrs(); |
307 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize" , &ksize)); |
308 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides" , &strides)); |
309 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
310 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
311 | internal::AvgPoolGrad::Attrs grad_attrs; |
312 | auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), |
313 | grad_inputs[0], ksize, strides, padding, |
314 | grad_attrs.DataFormat(data_format)); |
315 | grad_outputs->push_back(dx); |
316 | return scope.status(); |
317 | } |
318 | REGISTER_GRADIENT_OP("AvgPool" , AvgPoolGradHelper); |
319 | |
320 | Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, |
321 | const std::vector<Output>& grad_inputs, |
322 | std::vector<Output>* grad_outputs) { |
323 | std::vector<int32> ksize; |
324 | std::vector<int32> strides; |
325 | string padding; |
326 | string data_format; |
327 | auto attrs = op.output(0).node()->attrs(); |
328 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize" , &ksize)); |
329 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides" , &strides)); |
330 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding" , &padding)); |
331 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format" , &data_format)); |
332 | AvgPool3DGrad::Attrs grad_attrs; |
333 | auto dx = |
334 | AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize, |
335 | strides, padding, grad_attrs.DataFormat(data_format)); |
336 | grad_outputs->push_back(dx); |
337 | return scope.status(); |
338 | } |
339 | REGISTER_GRADIENT_OP("AvgPool3D" , AvgPool3DGradHelper); |
340 | |
341 | Status LRNGradHelper(const Scope& scope, const Operation& op, |
342 | const std::vector<Output>& grad_inputs, |
343 | std::vector<Output>* grad_outputs) { |
344 | auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0)); |
345 | grad_outputs->push_back(dx); |
346 | return scope.status(); |
347 | } |
348 | REGISTER_GRADIENT_OP("LRN" , LRNGradHelper); |
349 | |
350 | Status SoftplusGradHelper(const Scope& scope, const Operation& op, |
351 | const std::vector<Output>& grad_inputs, |
352 | std::vector<Output>* grad_outputs) { |
353 | auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); |
354 | grad_outputs->push_back(dx); |
355 | return scope.status(); |
356 | } |
357 | REGISTER_GRADIENT_OP("Softplus" , SoftplusGradHelper); |
358 | |
359 | Status SoftsignGradHelper(const Scope& scope, const Operation& op, |
360 | const std::vector<Output>& grad_inputs, |
361 | std::vector<Output>* grad_outputs) { |
362 | auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); |
363 | grad_outputs->push_back(dx); |
364 | return scope.status(); |
365 | } |
366 | REGISTER_GRADIENT_OP("Softsign" , SoftsignGradHelper); |
367 | |
368 | Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, |
369 | const std::vector<Output>& grad_inputs, |
370 | std::vector<Output>* grad_outputs) { |
371 | bool overlapping; |
372 | TF_RETURN_IF_ERROR( |
373 | GetNodeAttr(op.output(0).node()->attrs(), "overlapping" , &overlapping)); |
374 | auto dx = internal::FractionalAvgPoolGrad( |
375 | scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), |
376 | grad_inputs[0], op.output(1), op.output(2), |
377 | internal::FractionalAvgPoolGrad::Overlapping(overlapping)); |
378 | grad_outputs->push_back(dx); |
379 | return scope.status(); |
380 | } |
381 | REGISTER_GRADIENT_OP("FractionalAvgPool" , FractionalAvgPoolGradHelper); |
382 | |
383 | Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, |
384 | const std::vector<Output>& grad_inputs, |
385 | std::vector<Output>* grad_outputs) { |
386 | bool overlapping; |
387 | TF_RETURN_IF_ERROR( |
388 | GetNodeAttr(op.output(0).node()->attrs(), "overlapping" , &overlapping)); |
389 | auto dx = internal::FractionalMaxPoolGrad( |
390 | scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), |
391 | op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); |
392 | grad_outputs->push_back(dx); |
393 | return scope.status(); |
394 | } |
395 | REGISTER_GRADIENT_OP("FractionalMaxPool" , FractionalMaxPoolGradHelper); |
396 | |
397 | // Templated constructor for FusedBatchNormGrad[..]::Attrs. |
398 | template <typename T> |
399 | T FusedBatchNormGradAttrs(float epsilon, StringPiece data_format, |
400 | bool is_training) { |
401 | T result; |
402 | result.epsilon_ = epsilon; |
403 | result.data_format_ = data_format; |
404 | result.is_training_ = is_training; |
405 | return result; |
406 | } |
407 | |
// Signature shared by the per-version FusedBatchNorm gradient builders.
// BaseFusedBatchNormGrad handles attr plumbing and layout fixups, then
// delegates construction of the actual gradient op to one of these.
using BatchNormGradFn =
    std::function<Status(const Scope&, Output x, Output grad_y, Output scale,
                         const std::vector<Output>& reserve_spaces,
                         float epsilon, StringPiece data_format,
                         bool is_training, std::vector<Output>* grad_outputs)>;
413 | |
// Shared driver for FusedBatchNorm gradients. Validates the op's arity,
// extracts the epsilon/data_format/is_training attrs, assembles the reserve
// spaces, and calls `grad_fn` to build the version-specific gradient op.
// In eval mode (is_training == false) the gradient kernel only supports
// channel-last layouts, so NCHW/NCDHW data is transposed to NHWC/NDHWC
// around the call and the x-gradient is transposed back afterwards.
Status BaseFusedBatchNormGrad(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              BatchNormGradFn grad_fn,
                              std::vector<Output>* grad_outputs) {
  // Forward outputs: y, batch_mean, batch_var, reserve_space_1,
  // reserve_space_2[, reserve_space_3 for V3] — at least 5 in total.
  if (op.num_outputs() < 5) {
    return errors::InvalidArgument(
        "FusedBatchNorm requires at least 5 outputs");
  }
  if (grad_inputs.empty()) {
    return errors::InvalidArgument("FusedBatchNorm grad requires 1 grad input");
  }
  // Forward inputs: x, scale, offset[, mean, variance].
  if (op.num_inputs() < 3) {
    return errors::InvalidArgument("FusedBatchNorm has too few inputs");
  }

  Output x = op.input(0);
  Output grad_y = grad_inputs[0];
  Output scale = op.input(1);
  float epsilon;
  std::string data_format;
  bool is_training;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "epsilon", &epsilon));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "data_format", &data_format));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "is_training", &is_training));

  // Reserve spaces are the extra forward outputs the gradient kernel needs;
  // V3 ops emit a sixth output that is forwarded as a third reserve space.
  std::vector<Output> reserve_spaces;
  reserve_spaces.push_back(op.output(3));
  reserve_spaces.push_back(op.output(4));
  if (op.num_outputs() > 5) {
    reserve_spaces.push_back(op.output(5));
  }

  if (is_training) {
    // Training mode: the gradient kernel handles all layouts directly.
    return grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
                   data_format, is_training, grad_outputs);
  } else {
    // Eval mode additionally needs the population statistics as inputs.
    if (op.num_inputs() < 5) {
      return errors::InvalidArgument(
          "FusedBatchNorm requires 5 inputs in eval mode");
    }

    // In eval mode the kernel normalizes with the population statistics,
    // which arrive as forward inputs rather than forward outputs.
    reserve_spaces[0] = op.input(3);  // pop_mean
    reserve_spaces[1] = op.input(4);  // pop_var
    // The eval-mode gradient kernel only supports channel-last layouts,
    // so move the channel dimension to the end before calling it.
    if (data_format == "NCHW") {
      x = Transpose(scope, x, {0, 2, 3, 1});
      grad_y = Transpose(scope, grad_y, {0, 2, 3, 1});
    } else if (data_format == "NCDHW") {
      x = Transpose(scope, x, {0, 2, 3, 4, 1});
      grad_y = Transpose(scope, grad_y, {0, 2, 3, 4, 1});
    }

    // Pick the channel-last equivalent of the requested layout.
    StringPiece target_data_format;
    if (data_format == "NCHW" || data_format == "NHWC") {
      target_data_format = "NHWC";
    } else {
      target_data_format = "NDHWC";
    }

    TF_RETURN_IF_ERROR(grad_fn(scope, x, grad_y, scale, reserve_spaces, epsilon,
                               target_data_format, is_training, grad_outputs));
    // Undo the transpose on the x gradient so it matches the caller's layout.
    if (data_format == "NCHW") {
      (*grad_outputs)[0] = Transpose(scope, (*grad_outputs)[0], {0, 3, 1, 2});
    } else if (data_format == "NCDHW") {
      (*grad_outputs)[0] =
          Transpose(scope, (*grad_outputs)[0], {0, 4, 1, 2, 3});
    }
    return scope.status();
  }
}
485 | |
486 | Status FusedBatchNormV3Grad(const Scope& scope, const Operation& op, |
487 | const std::vector<Output>& grad_inputs, |
488 | std::vector<Output>* grad_outputs) { |
489 | return BaseFusedBatchNormGrad( |
490 | scope, op, grad_inputs, |
491 | [](const Scope& scope, Output x, Output grad_y, Output scale, |
492 | const std::vector<Output>& reserve_spaces, float epsilon, |
493 | StringPiece data_format, bool is_training, |
494 | std::vector<Output>* grad_outputs) { |
495 | FusedBatchNormGradV3 grad( |
496 | scope, grad_y, x, scale, reserve_spaces[0], reserve_spaces[1], |
497 | reserve_spaces[2], |
498 | FusedBatchNormGradAttrs<FusedBatchNormGradV3::Attrs>( |
499 | epsilon, data_format, is_training)); |
500 | grad_outputs->push_back(grad.x_backprop); |
501 | grad_outputs->push_back(grad.scale_backprop); |
502 | grad_outputs->push_back(grad.offset_backprop); |
503 | grad_outputs->push_back(NoGradient()); |
504 | grad_outputs->push_back(NoGradient()); |
505 | return scope.status(); |
506 | }, |
507 | grad_outputs); |
508 | } |
509 | |
510 | REGISTER_GRADIENT_OP("FusedBatchNormV3" , FusedBatchNormV3Grad); |
511 | |
512 | Status Conv2DBackpropInputGrad(const Scope& scope, const Operation& op, |
513 | const std::vector<Output>& grad_inputs, |
514 | std::vector<Output>* grad_outputs) { |
515 | if (op.num_inputs() != 3) { |
516 | return errors::InvalidArgument("Conv2DBackpropInput requires 3 inputs." ); |
517 | } |
518 | if (grad_inputs.empty()) { |
519 | return errors::InvalidArgument( |
520 | "Conv2DBackpropInput grad requires 1 grad input" ); |
521 | } |
522 | |
523 | std::vector<int> dilations, strides, explicit_paddings; |
524 | bool use_cudnn_on_gpu; |
525 | std::string data_format, padding; |
526 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "dilations" , &dilations)); |
527 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "strides" , &strides)); |
528 | TF_RETURN_IF_ERROR( |
529 | GetNodeAttr(op.node()->attrs(), "explicit_paddings" , &explicit_paddings)); |
530 | TF_RETURN_IF_ERROR( |
531 | GetNodeAttr(op.node()->attrs(), "use_cudnn_on_gpu" , &use_cudnn_on_gpu)); |
532 | TF_RETURN_IF_ERROR( |
533 | GetNodeAttr(op.node()->attrs(), "data_format" , &data_format)); |
534 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "padding" , &padding)); |
535 | |
536 | grad_outputs->push_back(NoGradient()); |
537 | |
538 | Conv2DBackpropFilter::Attrs filter_attrs; |
539 | filter_attrs.use_cudnn_on_gpu_ = use_cudnn_on_gpu; |
540 | filter_attrs.explicit_paddings_ = explicit_paddings; |
541 | filter_attrs.data_format_ = data_format; |
542 | filter_attrs.dilations_ = dilations; |
543 | grad_outputs->push_back( |
544 | Conv2DBackpropFilter(scope, grad_inputs[0], Shape(scope, op.input(1)), |
545 | op.input(2), strides, padding, filter_attrs)); |
546 | |
547 | Conv2D::Attrs conv_attrs; |
548 | conv_attrs.use_cudnn_on_gpu_ = use_cudnn_on_gpu; |
549 | conv_attrs.explicit_paddings_ = explicit_paddings; |
550 | conv_attrs.data_format_ = data_format; |
551 | conv_attrs.dilations_ = dilations; |
552 | grad_outputs->push_back( |
553 | Conv2D(scope, grad_inputs[0], op.input(1), strides, padding, conv_attrs)); |
554 | return scope.status(); |
555 | } |
556 | REGISTER_GRADIENT_OP("Conv2DBackpropInput" , Conv2DBackpropInputGrad); |
557 | |
558 | Status DepthwiseConv2dNativeGrad(const Scope& scope, const Operation& op, |
559 | const std::vector<Output>& grad_inputs, |
560 | std::vector<Output>* grad_outputs) { |
561 | if (op.num_inputs() != 2) { |
562 | return errors::InvalidArgument("DepthwiseConv2dNative requires 2 inputs." ); |
563 | } |
564 | if (grad_inputs.empty()) { |
565 | return errors::InvalidArgument( |
566 | "DepthwiseConv2dNative grad requires 1 grad input" ); |
567 | } |
568 | |
569 | std::vector<int> dilations, strides, explicit_paddings; |
570 | std::string data_format, padding; |
571 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "dilations" , &dilations)); |
572 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "strides" , &strides)); |
573 | TF_RETURN_IF_ERROR( |
574 | GetNodeAttr(op.node()->attrs(), "explicit_paddings" , &explicit_paddings)); |
575 | TF_RETURN_IF_ERROR( |
576 | GetNodeAttr(op.node()->attrs(), "data_format" , &data_format)); |
577 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "padding" , &padding)); |
578 | |
579 | DepthwiseConv2dNativeBackpropInput::Attrs input_attrs; |
580 | input_attrs.explicit_paddings_ = explicit_paddings; |
581 | input_attrs.data_format_ = data_format; |
582 | input_attrs.dilations_ = dilations; |
583 | grad_outputs->push_back(DepthwiseConv2dNativeBackpropInput( |
584 | scope, Shape(scope, op.input(0)), op.input(1), grad_inputs[0], strides, |
585 | padding, input_attrs)); |
586 | |
587 | DepthwiseConv2dNativeBackpropFilter::Attrs filter_attrs; |
588 | filter_attrs.explicit_paddings_ = explicit_paddings; |
589 | filter_attrs.data_format_ = data_format; |
590 | filter_attrs.dilations_ = dilations; |
591 | grad_outputs->push_back(DepthwiseConv2dNativeBackpropFilter( |
592 | scope, op.input(0), Shape(scope, op.input(1)), grad_inputs[0], strides, |
593 | padding, filter_attrs)); |
594 | return scope.status(); |
595 | } |
596 | REGISTER_GRADIENT_OP("DepthwiseConv2dNative" , DepthwiseConv2dNativeGrad); |
597 | |
598 | } // anonymous namespace |
599 | } // namespace ops |
600 | } // namespace tensorflow |
601 | |