1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | |
18 | #include "tensorflow/cc/framework/grad_op_registry.h" |
19 | #include "tensorflow/cc/framework/gradients.h" |
20 | #include "tensorflow/cc/ops/array_ops_internal.h" |
21 | #include "tensorflow/cc/ops/standard_ops.h" |
22 | #include "tensorflow/core/lib/strings/strcat.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace ops { |
26 | namespace { |
27 | |
// These ops produce shape/metadata or integer-valued results and are not
// differentiable; registering them as no-gradient makes gradient
// construction emit NoGradient() instead of failing.
REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");
40 | |
41 | Status PackGrad(const Scope& scope, const Operation& op, |
42 | const std::vector<Output>& grad_inputs, |
43 | std::vector<Output>* grad_outputs) { |
44 | int N; |
45 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N" , &N)); |
46 | int axis; |
47 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis" , &axis)); |
48 | |
49 | grad_outputs->reserve(N); |
50 | auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); |
51 | for (const Output& o : grad_op.output) { |
52 | grad_outputs->emplace_back(o); |
53 | } |
54 | return scope.status(); |
55 | } |
56 | REGISTER_GRADIENT_OP("Pack" , PackGrad); |
57 | |
58 | Status UnpackGrad(const Scope& scope, const Operation& op, |
59 | const std::vector<Output>& grad_inputs, |
60 | std::vector<Output>* grad_outputs) { |
61 | int axis; |
62 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis" , &axis)); |
63 | grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); |
64 | return scope.status(); |
65 | } |
66 | REGISTER_GRADIENT_OP("Unpack" , UnpackGrad); |
67 | |
68 | Status IdentityGrad(const Scope& scope, const Operation& op, |
69 | const std::vector<Output>& grad_inputs, |
70 | std::vector<Output>* grad_outputs) { |
71 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
72 | return scope.status(); |
73 | } |
74 | REGISTER_GRADIENT_OP("Identity" , IdentityGrad); |
75 | |
76 | Status RefIdentityGrad(const Scope& scope, const Operation& op, |
77 | const std::vector<Output>& grad_inputs, |
78 | std::vector<Output>* grad_outputs) { |
79 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
80 | return scope.status(); |
81 | } |
82 | REGISTER_GRADIENT_OP("RefIdentity" , RefIdentityGrad); |
83 | |
84 | Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, |
85 | const std::vector<Output>& grad_inputs, |
86 | std::vector<Output>* grad_outputs) { |
87 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
88 | return scope.status(); |
89 | } |
90 | REGISTER_GRADIENT_OP("QuantizeAndDequantize" , QuantizeAndDequantizeGrad); |
91 | |
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  // Delegates to the dedicated QuantizeAndDequantizeV4Grad op, which emits
  // gradients for the input as well as for the input_min/input_max range
  // tensors.
  // NOTE(review): the value passed as the kernel's `input` argument below is
  // Shape(op.input(0)), not op.input(0) itself — confirm this matches the
  // kernel's expected argument (the Python gradient passes the input tensor).
  Input input = Shape(scope, op.input(0));
  Input input_min = op.input(1);
  Input input_max = op.input(2);
  int64_t axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
      scope, grad_inputs[0], input, input_min, input_max,
      QuantizeAndDequantizeV4Grad::Axis(axis));
  // Gradients for: input, input_min, input_max (in op-input order).
  grad_outputs->push_back(qdq_v4_grad.input_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
                     QuantizeAndDequantizeV4GradHelper);
111 | |
112 | Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, |
113 | const std::vector<Output>& grad_inputs, |
114 | std::vector<Output>* grad_outputs) { |
115 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
116 | grad_outputs->push_back(NoGradient()); |
117 | grad_outputs->push_back(NoGradient()); |
118 | grad_outputs->push_back(NoGradient()); |
119 | return scope.status(); |
120 | } |
121 | REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3" , QuantizeAndDequantizeV3Grad); |
122 | |
123 | Status SplitGrad(const Scope& scope, const Operation& op, |
124 | const std::vector<Output>& grad_inputs, |
125 | std::vector<Output>* grad_outputs) { |
126 | grad_outputs->push_back(NoGradient()); |
127 | grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0))); |
128 | return scope.status(); |
129 | } |
130 | REGISTER_GRADIENT_OP("Split" , SplitGrad); |
131 | |
132 | Status SplitVGrad(const Scope& scope, const Operation& op, |
133 | const std::vector<Output>& grad_inputs, |
134 | std::vector<Output>* grad_outputs) { |
135 | if (op.num_inputs() < 3) { |
136 | return errors::InvalidArgument("SplitV requires 3 arguments" ); |
137 | } |
138 | grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2))); |
139 | for (int i = 0; i < op.num_inputs() - 1; ++i) { |
140 | grad_outputs->push_back(NoGradient()); |
141 | } |
142 | return scope.status(); |
143 | } |
144 | REGISTER_GRADIENT_OP("SplitV" , SplitVGrad); |
145 | |
146 | Status FillGrad(const Scope& scope, const Operation& op, |
147 | const std::vector<Output>& grad_inputs, |
148 | std::vector<Output>* grad_outputs) { |
149 | // y = fill(fill_shape, x) |
150 | // No gradient returned for the fill_shape argument. |
151 | grad_outputs->push_back(NoGradient()); |
152 | // The gradient for x (which must be a scalar) is just the sum of |
153 | // all the gradients from the shape it fills. |
154 | // We use ReduceSum to implement this, which needs an argument providing |
155 | // the indices of all the dimensions of the incoming gradient. |
156 | // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) |
157 | auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), |
158 | Const(scope, 1)); |
159 | grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); |
160 | return scope.status(); |
161 | } |
162 | REGISTER_GRADIENT_OP("Fill" , FillGrad); |
163 | |
164 | Status DiagGrad(const Scope& scope, const Operation& op, |
165 | const std::vector<Output>& grad_inputs, |
166 | std::vector<Output>* grad_outputs) { |
167 | grad_outputs->push_back(DiagPart(scope, grad_inputs[0])); |
168 | return scope.status(); |
169 | } |
170 | REGISTER_GRADIENT_OP("Diag" , DiagGrad); |
171 | |
172 | Status DiagPartGrad(const Scope& scope, const Operation& op, |
173 | const std::vector<Output>& grad_inputs, |
174 | std::vector<Output>* grad_outputs) { |
175 | grad_outputs->push_back(Diag(scope, grad_inputs[0])); |
176 | return scope.status(); |
177 | } |
178 | REGISTER_GRADIENT_OP("DiagPart" , DiagPartGrad); |
179 | |
180 | Status MatrixDiagGrad(const Scope& scope, const Operation& op, |
181 | const std::vector<Output>& grad_inputs, |
182 | std::vector<Output>* grad_outputs) { |
183 | grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0])); |
184 | return scope.status(); |
185 | } |
186 | REGISTER_GRADIENT_OP("MatrixDiag" , MatrixDiagGrad); |
187 | |
188 | Status MatrixBandPartGrad(const Scope& scope, const Operation& op, |
189 | const std::vector<Output>& grad_inputs, |
190 | std::vector<Output>* grad_outputs) { |
191 | auto num_lower = op.input(1); |
192 | auto num_upper = op.input(2); |
193 | grad_outputs->push_back( |
194 | MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper)); |
195 | grad_outputs->push_back(NoGradient()); |
196 | grad_outputs->push_back(NoGradient()); |
197 | return scope.status(); |
198 | } |
199 | REGISTER_GRADIENT_OP("MatrixBandPart" , MatrixBandPartGrad); |
200 | |
201 | Status GatherNdGrad(const Scope& scope, const Operation& op, |
202 | const std::vector<Output>& grad_inputs, |
203 | std::vector<Output>* grad_outputs) { |
204 | auto ref = op.input(0); |
205 | auto ref_shape = Shape(scope, ref); |
206 | auto indices = op.input(1); |
207 | grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape)); |
208 | grad_outputs->push_back(NoGradient()); |
209 | return scope.status(); |
210 | } |
211 | REGISTER_GRADIENT_OP("GatherNd" , GatherNdGrad); |
212 | |
213 | Status CheckNumericsGrad(const Scope& scope, const Operation& op, |
214 | const std::vector<Output>& grad_inputs, |
215 | std::vector<Output>* grad_outputs) { |
216 | string message; |
217 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message" , &message)); |
218 | string err_msg = strings::StrCat( |
219 | "Not a number (NaN) or infinity (Inf) values detected in gradient. " , |
220 | message); |
221 | grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); |
222 | return scope.status(); |
223 | } |
224 | REGISTER_GRADIENT_OP("CheckNumerics" , CheckNumericsGrad); |
225 | |
226 | Status ReshapeGrad(const Scope& scope, const Operation& op, |
227 | const std::vector<Output>& grad_inputs, |
228 | std::vector<Output>* grad_outputs) { |
229 | auto input_shape = Shape(scope, op.input(0)); |
230 | grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); |
231 | grad_outputs->push_back(NoGradient()); |
232 | return scope.status(); |
233 | } |
234 | REGISTER_GRADIENT_OP("Reshape" , ReshapeGrad); |
235 | |
236 | Status ExpandDimsGrad(const Scope& scope, const Operation& op, |
237 | const std::vector<Output>& grad_inputs, |
238 | std::vector<Output>* grad_outputs) { |
239 | auto input_shape = Shape(scope, op.input(0)); |
240 | grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); |
241 | grad_outputs->push_back(NoGradient()); |
242 | return scope.status(); |
243 | } |
244 | REGISTER_GRADIENT_OP("ExpandDims" , ExpandDimsGrad); |
245 | |
246 | Status SqueezeGrad(const Scope& scope, const Operation& op, |
247 | const std::vector<Output>& grad_inputs, |
248 | std::vector<Output>* grad_outputs) { |
249 | auto input_shape = Shape(scope, op.input(0)); |
250 | grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); |
251 | return scope.status(); |
252 | } |
253 | REGISTER_GRADIENT_OP("Squeeze" , SqueezeGrad); |
254 | |
255 | Status TransposeGrad(const Scope& scope, const Operation& op, |
256 | const std::vector<Output>& grad_inputs, |
257 | std::vector<Output>* grad_outputs) { |
258 | auto inverted_perm = InvertPermutation(scope, op.input(1)); |
259 | grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm)); |
260 | grad_outputs->push_back(NoGradient()); |
261 | return scope.status(); |
262 | } |
263 | REGISTER_GRADIENT_OP("Transpose" , TransposeGrad); |
264 | |
265 | Status ReverseSequenceGrad(const Scope& scope, const Operation& op, |
266 | const std::vector<Output>& grad_inputs, |
267 | std::vector<Output>* grad_outputs) { |
268 | auto seq_lengths = op.input(1); |
269 | int batch_dim; |
270 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim" , &batch_dim)); |
271 | int seq_dim; |
272 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim" , &seq_dim)); |
273 | grad_outputs->push_back( |
274 | ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, |
275 | ReverseSequence::BatchDim(batch_dim))); |
276 | grad_outputs->push_back(NoGradient()); |
277 | return scope.status(); |
278 | } |
279 | REGISTER_GRADIENT_OP("ReverseSequence" , ReverseSequenceGrad); |
280 | |
281 | Status ReverseGrad(const Scope& scope, const Operation& op, |
282 | const std::vector<Output>& grad_inputs, |
283 | std::vector<Output>* grad_outputs) { |
284 | auto reverse_dims = op.input(1); |
285 | grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims)); |
286 | grad_outputs->push_back(NoGradient()); |
287 | return scope.status(); |
288 | } |
289 | REGISTER_GRADIENT_OP("ReverseV2" , ReverseGrad); |
290 | |
291 | Status ScatterNdGrad(const Scope& scope, const Operation& op, |
292 | const std::vector<Output>& grad_inputs, |
293 | std::vector<Output>* grad_outputs) { |
294 | auto indices = op.input(0); |
295 | grad_outputs->push_back(NoGradient()); |
296 | grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); |
297 | grad_outputs->push_back(NoGradient()); |
298 | return scope.status(); |
299 | } |
300 | REGISTER_GRADIENT_OP("ScatterNd" , ScatterNdGrad); |
301 | |
302 | Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, |
303 | const std::vector<Output>& grad_inputs, |
304 | std::vector<Output>* grad_outputs) { |
305 | auto indices = op.input(1); |
306 | grad_outputs->push_back(Identity(scope, grad_inputs[0])); |
307 | grad_outputs->push_back(NoGradient()); |
308 | grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); |
309 | return scope.status(); |
310 | } |
311 | REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd" , ScatterNdNonAliasingAddGrad); |
312 | |
313 | template <bool IsPadV2> |
314 | Status PadGrad(const Scope& scope, const Operation& op, |
315 | const std::vector<Output>& grad_inputs, |
316 | std::vector<Output>* grad_outputs) { |
317 | auto x = op.input(0); |
318 | auto a = op.input(1); // [Rank(x), 2] |
319 | // Takes a slice of a. The 1st column. [Rank(x), 1]. |
320 | auto size = Stack(scope, {Rank(scope, x), 1}); |
321 | auto pad_before = Slice(scope, a, {0, 0}, size); |
322 | // Make it a 1-D tensor. |
323 | auto begin = Reshape(scope, pad_before, {-1}); |
324 | grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x))); |
325 | grad_outputs->push_back(NoGradient()); |
326 | // PadV2 adds a "constant_values" input. |
327 | if (IsPadV2) { |
328 | grad_outputs->push_back(NoGradient()); |
329 | } |
330 | return scope.status(); |
331 | } |
332 | REGISTER_GRADIENT_OP("Pad" , PadGrad<false>); |
333 | REGISTER_GRADIENT_OP("PadV2" , PadGrad<true>); |
334 | |
335 | Status SpaceToBatchGrad(const Scope& scope, const Operation& op, |
336 | const std::vector<Output>& grad_inputs, |
337 | std::vector<Output>* grad_outputs) { |
338 | int block_size; |
339 | TF_RETURN_IF_ERROR( |
340 | GetNodeAttr(op.node()->attrs(), "block_size" , &block_size)); |
341 | grad_outputs->push_back( |
342 | BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); |
343 | grad_outputs->push_back(NoGradient()); |
344 | return scope.status(); |
345 | } |
346 | REGISTER_GRADIENT_OP("SpaceToBatch" , SpaceToBatchGrad); |
347 | |
348 | Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op, |
349 | const std::vector<Output>& grad_inputs, |
350 | std::vector<Output>* grad_outputs) { |
351 | grad_outputs->push_back( |
352 | BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2))); |
353 | grad_outputs->push_back(NoGradient()); |
354 | grad_outputs->push_back(NoGradient()); |
355 | return scope.status(); |
356 | } |
357 | REGISTER_GRADIENT_OP("SpaceToBatchND" , SpaceToBatchNDGrad); |
358 | |
359 | Status BatchToSpaceGrad(const Scope& scope, const Operation& op, |
360 | const std::vector<Output>& grad_inputs, |
361 | std::vector<Output>* grad_outputs) { |
362 | int block_size; |
363 | TF_RETURN_IF_ERROR( |
364 | GetNodeAttr(op.node()->attrs(), "block_size" , &block_size)); |
365 | grad_outputs->push_back( |
366 | SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); |
367 | grad_outputs->push_back(NoGradient()); |
368 | return scope.status(); |
369 | } |
370 | REGISTER_GRADIENT_OP("BatchToSpace" , BatchToSpaceGrad); |
371 | |
372 | Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op, |
373 | const std::vector<Output>& grad_inputs, |
374 | std::vector<Output>* grad_outputs) { |
375 | grad_outputs->push_back( |
376 | SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2))); |
377 | grad_outputs->push_back(NoGradient()); |
378 | grad_outputs->push_back(NoGradient()); |
379 | return scope.status(); |
380 | } |
381 | REGISTER_GRADIENT_OP("BatchToSpaceND" , BatchToSpaceNDGrad); |
382 | |
383 | Status SpaceToDepthGrad(const Scope& scope, const Operation& op, |
384 | const std::vector<Output>& grad_inputs, |
385 | std::vector<Output>* grad_outputs) { |
386 | int block_size; |
387 | TF_RETURN_IF_ERROR( |
388 | GetNodeAttr(op.node()->attrs(), "block_size" , &block_size)); |
389 | grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); |
390 | return scope.status(); |
391 | } |
392 | REGISTER_GRADIENT_OP("SpaceToDepth" , SpaceToDepthGrad); |
393 | |
394 | Status DepthToSpaceGrad(const Scope& scope, const Operation& op, |
395 | const std::vector<Output>& grad_inputs, |
396 | std::vector<Output>* grad_outputs) { |
397 | int block_size; |
398 | TF_RETURN_IF_ERROR( |
399 | GetNodeAttr(op.node()->attrs(), "block_size" , &block_size)); |
400 | grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); |
401 | return scope.status(); |
402 | } |
403 | REGISTER_GRADIENT_OP("DepthToSpace" , DepthToSpaceGrad); |
404 | |
405 | Status MirrorPadGrad(const Scope& scope, const Operation& op, |
406 | const std::vector<Output>& grad_inputs, |
407 | std::vector<Output>* grad_outputs) { |
408 | string mode; |
409 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode" , &mode)); |
410 | grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( |
411 | scope, grad_inputs[0], op.input(1), mode)); |
412 | grad_outputs->push_back(NoGradient()); |
413 | return scope.status(); |
414 | } |
415 | REGISTER_GRADIENT_OP("MirrorPad" , MirrorPadGrad); |
416 | |
417 | // TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4. |
418 | Status MirrorPadGradGrad(const Scope& scope, const Operation& op, |
419 | const std::vector<Output>& grad_inputs, |
420 | std::vector<Output>* grad_outputs) { |
421 | string mode; |
422 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode" , &mode)); |
423 | grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); |
424 | grad_outputs->push_back(NoGradient()); |
425 | return scope.status(); |
426 | } |
427 | REGISTER_GRADIENT_OP("MirrorPadGrad" , MirrorPadGradGrad); |
428 | |
429 | Status StridedSliceGradHelper(const Scope& scope, const Operation& op, |
430 | const std::vector<Output>& grad_inputs, |
431 | std::vector<Output>* grad_outputs) { |
432 | Input x = Shape(scope, op.input(0)); |
433 | Input begin = op.input(1); |
434 | Input end = op.input(2); |
435 | Input strides = op.input(3); |
436 | int64_t begin_mask; |
437 | int64_t end_mask; |
438 | int64_t ellipsis_mask; |
439 | int64_t new_axis_mask; |
440 | int64_t shrink_axis_mask; |
441 | TF_RETURN_IF_ERROR( |
442 | GetNodeAttr(op.node()->attrs(), "begin_mask" , &begin_mask)); |
443 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask" , &end_mask)); |
444 | TF_RETURN_IF_ERROR( |
445 | GetNodeAttr(op.node()->attrs(), "ellipsis_mask" , &ellipsis_mask)); |
446 | TF_RETURN_IF_ERROR( |
447 | GetNodeAttr(op.node()->attrs(), "new_axis_mask" , &new_axis_mask)); |
448 | TF_RETURN_IF_ERROR( |
449 | GetNodeAttr(op.node()->attrs(), "shrink_axis_mask" , &shrink_axis_mask)); |
450 | grad_outputs->push_back( |
451 | StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0], |
452 | StridedSliceGrad::BeginMask(begin_mask) |
453 | .EndMask(end_mask) |
454 | .EllipsisMask(ellipsis_mask) |
455 | .NewAxisMask(new_axis_mask) |
456 | .ShrinkAxisMask(shrink_axis_mask))); |
457 | // No gradients returned for begin, end and strides |
458 | grad_outputs->push_back(NoGradient()); |
459 | grad_outputs->push_back(NoGradient()); |
460 | grad_outputs->push_back(NoGradient()); |
461 | return scope.status(); |
462 | } |
463 | REGISTER_GRADIENT_OP("StridedSlice" , StridedSliceGradHelper); |
464 | |
465 | Status SliceGrad(const Scope& scope, const Operation& op, |
466 | const std::vector<Output>& grad_inputs, |
467 | std::vector<Output>* grad_outputs) { |
468 | // Propagate the incoming gradient along all the selected values, |
469 | // and zero everywhere else. Use the Pad operator for this. |
470 | // |
471 | // First create an Nx2 padding where N is the number of input |
472 | // dimensions. The first column is the number of prepended zeros |
473 | // for each dimension, and the second column is the number of |
474 | // appended zeros. |
475 | // |
476 | // The first column is just the begin vector. |
477 | // The second column is the shape of the input element-wise |
478 | // subtracted by begin+size |
479 | |
480 | // Running example: |
481 | // input.shape = [3, 5, 3] |
482 | // begin = [1, 2, 1], size = [1, 3, 2] |
483 | Input input = op.input(0); |
484 | Input begin = op.input(1); |
485 | // input_rank = 3 |
486 | auto input_rank = Rank(scope, input); |
487 | // slice_size = [1, 3, 2] |
488 | auto slice_size = Shape(scope, op.output(0)); |
489 | // padding_shape = [3, 1] |
490 | auto padding_shape = Stack(scope, {input_rank, 1}); |
491 | // before_padding = [[1] |
492 | // [2] |
493 | // [1]] |
494 | Input before_padding = Reshape(scope, begin, padding_shape); |
495 | // after_padding_sizes = shape(input) - slice_size - begin |
496 | // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1] |
497 | // = [1, 0, 0] |
498 | auto after_padding_sizes = |
499 | Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin); |
500 | // after_padding = [[1] |
501 | // [0] |
502 | // [0]] |
503 | Input after_padding = Reshape(scope, after_padding_sizes, padding_shape); |
504 | // paddings = [[1 1] |
505 | // [2 0] |
506 | // [1 0]] |
507 | auto paddings = |
508 | Concat(scope, {before_padding, after_padding}, Const(scope, 1)); |
509 | grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings)); |
510 | // Nothing propagated for "begin" and "size" inputs |
511 | grad_outputs->push_back(NoGradient()); |
512 | grad_outputs->push_back(NoGradient()); |
513 | return scope.status(); |
514 | } |
515 | REGISTER_GRADIENT_OP("Slice" , SliceGrad); |
516 | |
517 | Status ConcatGradHelper(const Scope& scope, const Operation& op, |
518 | const std::vector<Output>& grad_inputs, |
519 | std::vector<Output>* grad_outputs, |
520 | int start_value_index, int end_value_index, |
521 | int dim_index) { |
522 | if (end_value_index >= op.num_inputs()) { |
523 | return errors::Internal("Invalid input index" ); |
524 | } |
525 | std::vector<Output> inputs; |
526 | inputs.reserve(end_value_index - start_value_index); |
527 | for (int i = start_value_index; i < end_value_index; ++i) { |
528 | inputs.push_back(op.input(i)); |
529 | } |
530 | |
531 | auto shapes = ShapeN(scope, inputs); |
532 | const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset" ); |
533 | auto builder = |
534 | ::tensorflow::NodeBuilder(unique_name, "ConcatOffset" ) |
535 | .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index))) |
536 | .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output)); |
537 | scope.UpdateBuilder(&builder); |
538 | ::tensorflow::Node* concat_offset_node; |
539 | scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node)); |
540 | scope.UpdateStatus(scope.DoShapeInference(concat_offset_node)); |
541 | if (concat_offset_node->num_outputs() != inputs.size()) { |
542 | return errors::Internal("ConcatOffset has invalid output count" ); |
543 | } |
544 | if (grad_inputs.size() != 1) { |
545 | return errors::InvalidArgument("Concat grad should have 1 input" ); |
546 | } |
547 | |
548 | // For each dx[i], we take a slice of dy. The offset and size of the |
549 | // slice is given by offset[i] and shape[i]. |
550 | const Output& dy = grad_inputs[0]; |
551 | for (int i = 0; i < inputs.size(); ++i) { |
552 | grad_outputs->push_back( |
553 | Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i])); |
554 | } |
555 | |
556 | // Insert a NoGradient for the axis. |
557 | grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient()); |
558 | return scope.status(); |
559 | } |
560 | |
Status ConcatV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // ConcatV2(values..., axis): the first num_inputs()-1 inputs are the
  // concatenated values and the last input is the axis.
  return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
                          /*start_value_index=*/0,
                          /*end_value_index=*/op.num_inputs() - 1,
                          /*dim_index=*/op.num_inputs() - 1);
}

REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);
571 | |
572 | Status BroadcastToGrad(const Scope& scope, const Operation& op, |
573 | const std::vector<Output>& grad_inputs, |
574 | std::vector<Output>* grad_outputs) { |
575 | if (grad_inputs.size() != 1) { |
576 | return errors::InvalidArgument("BroadcastTo grad should have 1 grad input" ); |
577 | } |
578 | if (op.num_inputs() != 2) { |
579 | return errors::InvalidArgument("BroadcastTo requires 2 inputs" ); |
580 | } |
581 | |
582 | auto x_shape = Shape(scope, op.input(0)); |
583 | auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1)); |
584 | auto sum_gx = Sum(scope, grad_inputs[0], args.r0); |
585 | grad_outputs->push_back(Reshape(scope, sum_gx, x_shape)); |
586 | grad_outputs->push_back(NoGradient()); |
587 | return scope.status(); |
588 | } |
589 | |
590 | REGISTER_GRADIENT_OP("BroadcastTo" , BroadcastToGrad); |
591 | |
Status TileGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = tile(x, multiples): the gradient for x sums dy over all tiled
  // copies of each input element; `multiples` is non-differentiable.
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("Tile requires 2 inputs" );
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Tile grad requires 1 grad input" );
  }

  // Request the shape in the same integer dtype as the multiples input so
  // the Stack below combines tensors of matching type.
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = op.input_type(1);
  auto input_shape = Shape(scope, op.input(0), shape_attrs);
  // We interleave multiples and input_shape to get split_shape,
  // reshape grad to split_shape, and reduce along all even
  // dimensions (the tiled dimensions) to get the result
  // with shape input_shape. For example
  //   input_shape = [20, 30, 40]
  //   multiples = [2, 3, 4]
  //   split_shape = [2, 20, 3, 30, 4, 40]
  //   axes = [0, 2, 4]
  // `stack` has shape [2, rank]: row 0 is multiples, row 1 is input_shape.
  auto stack = Stack(scope, {op.input(1), input_shape.output});
  // Transposing with the reversed permutation [1, 0] gives [rank, 2] rows of
  // (multiple_i, size_i); flattening interleaves them into split_shape.
  auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
  auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
  auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
  auto input_grad = ReduceSum(
      scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
  grad_outputs->push_back(input_grad.output);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Tile", TileGrad);
624 | |
625 | // Create a constant of the provided d_type; |
626 | Output ConstHelper(const Scope& scope, int value, DataType d_type) { |
627 | return Cast(scope, Const(scope, value), d_type); |
628 | } |
629 | |
// Adds the batch offsets to the given indices and returns the results.
// `indices` holds per-batch indices into a tensor of shape `params_shape`;
// walking the batch dimensions from innermost to outermost, this adds the
// offset each batch position contributes once the batch dimensions are
// flattened into the gather dimension.
Output GetBatchIndices(const Scope& scope, const Output& params_shape,
                       const Output& indices, int batch_dims) {
  Output batch_indices = indices;
  auto indices_ndims = Rank(scope, indices);
  auto casted_params_shape = Cast(scope, params_shape, indices.type());
  // Running product of the sizes of dimensions inside the current one.
  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
  for (int dim = batch_dims; dim > 0; dim--) {
    // Size of batch dimension dim-1, as a 1-element tensor.
    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
    accum_dim_value = Multiply(scope, accum_dim_value,
                               Slice(scope, casted_params_shape, {dim}, {1}));
    auto start = ConstHelper(scope, 0, indices.type());
    auto step = ConstHelper(scope, 1, indices.type());
    // [0, 1, ..., dim_size-1] scaled by the accumulated inner size: the
    // flat offset contributed by each position along this batch dimension.
    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
    auto one = Cast(scope, Const(scope, {1}), indices.type());
    // Shape [1, ..., 1, dim_size, 1, ..., 1] so the offsets broadcast
    // against `indices` along this batch dimension only.
    auto dim_shape = Concat(
        scope,
        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
         Output(Tile(scope, one,
                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
        /*axis=*/0);
    batch_indices =
        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
  }

  return batch_indices;
}
658 | |
// Computes the gradient of a (possibly batched) gather along the first
// non-batch dimension: scatter-adds `values` (the incoming gradient) back
// into a tensor of shape `params_shape` using UnsortedSegmentSum. When
// batch_dims != 0, the batch dimensions are folded into the segment ids so
// a single segment-sum covers all batches at once.
Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
                       Output indices, int batch_dims, Output gather_dim_size) {
  // Axis is the first non-batch dimension.
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  Output outer_shape, flat_values_shape;
  if (batch_dims != 0) {
    auto values_shape = Shape(scope, values);
    // Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
    auto inner_shape =
        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
    // Each batch contributes `gather_dim_size` distinct segments.
    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
    values = Reshape(scope, values, flat_values_shape);
  }

  // Flatten indices to 1-D segment ids and scatter-add the values.
  indices = Reshape(scope, indices, indices_size);
  Output params_grad =
      UnsortedSegmentSum(scope, values, indices, gather_dim_size);

  if (batch_dims != 0) {
    // Put back the batch dimensions.
    params_grad = Reshape(scope, params_grad, params_shape);
  }
  return params_grad;
}
687 | |
Status GatherV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // Gradient for GatherV2(params, indices, axis): scatter-add the incoming
  // gradient back into a tensor shaped like params. Only params is
  // differentiable; indices and axis get NoGradient.
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("Gather requires 3 inputs" );
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Gather grad requires 1 grad input" );
  }

  // params can be large, so colocate the shape calculation with it.
  // params can be very large for sparse model, array_ops.shape raises
  // exception on the Windows platform when any dimension is larger than
  // int32. params_shape is not used in optimizer apply_sparse gradients,
  // so it's fine to convert it back to int32 regardless of truncation.
  auto params = op.input(0);
  auto colocate_scope = scope.ColocateWith(params);
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = DT_INT64;
  auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
  Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);

  auto indices = op.input(1);
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  auto axis = op.input(2);
  auto axis_expand = ExpandDims(scope, axis, 0);

  int batch_dims;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "batch_dims" , &batch_dims));
  if (batch_dims < 0) {
    // TODO(bdodson): Figure out if we can find the param rank here, like the
    // python implementation does.
    return errors::InvalidArgument(
        "C++ GatherV2 gradient does not support negative batch_dims." );
  }

  // Handle axis by transposing the axis dimension to be the first non-batch
  // dimension, compute the gradient and transpose the result back.
  // outer_shape: dims of params before `axis`; inner_shape: dims after it.
  auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
  auto inner_shape =
      Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
  // Collapse all gathered index dims of dy into one dim at position `axis`.
  auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
  auto values_dims = Size(scope, values_shape);
  auto axis_dims = Size(scope, outer_shape);

  Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
  Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
  Output inner_axes_indices =
      Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
  Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);

  auto values = Reshape(scope, grad_inputs[0], values_shape);

  // Move values[axis] up to values[batch_dims]
  Output transpose_dims = Concat(scope,
                                 {outer_batches_indices, axis_dims_expand,
                                  batch_axis_indices, inner_axes_indices},
                                 0);
  auto values_transpose = Transpose(scope, values, transpose_dims);
  // Number of segments along the gather axis of params.
  Output gather_dim_size =
      Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
  // Permute params_shape the same way so BatchGatherGrad sees a consistent
  // layout.
  params_shape = Gather(scope, params_shape, transpose_dims);

  auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
                                     indices, batch_dims, gather_dim_size);

  // Inverts the above transpose by moving dimension batch_dims back to its
  // original position.
  Output invert_transpose_dims = Concat(scope,
                                        {outer_batches_indices,
                                         Add(scope, batch_axis_indices, 1),
                                         {batch_dims},
                                         inner_axes_indices},
                                        0);

  params_grad = Transpose(scope, params_grad, invert_transpose_dims);

  grad_outputs->push_back(params_grad);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);
773 | |
774 | } // anonymous namespace |
775 | } // namespace ops |
776 | } // namespace tensorflow |
777 | |