1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <vector>
17
18#include "tensorflow/cc/framework/grad_op_registry.h"
19#include "tensorflow/cc/framework/gradients.h"
20#include "tensorflow/cc/ops/array_ops_internal.h"
21#include "tensorflow/cc/ops/standard_ops.h"
22#include "tensorflow/core/lib/strings/strcat.h"
23
24namespace tensorflow {
25namespace ops {
26namespace {
27
// Ops with no meaningful gradient: their outputs are constants, shape/index
// metadata, or otherwise non-differentiable, so backprop stops at them.
REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");
40
41Status PackGrad(const Scope& scope, const Operation& op,
42 const std::vector<Output>& grad_inputs,
43 std::vector<Output>* grad_outputs) {
44 int N;
45 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
46 int axis;
47 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
48
49 grad_outputs->reserve(N);
50 auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
51 for (const Output& o : grad_op.output) {
52 grad_outputs->emplace_back(o);
53 }
54 return scope.status();
55}
56REGISTER_GRADIENT_OP("Pack", PackGrad);
57
58Status UnpackGrad(const Scope& scope, const Operation& op,
59 const std::vector<Output>& grad_inputs,
60 std::vector<Output>* grad_outputs) {
61 int axis;
62 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
63 grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
64 return scope.status();
65}
66REGISTER_GRADIENT_OP("Unpack", UnpackGrad);
67
68Status IdentityGrad(const Scope& scope, const Operation& op,
69 const std::vector<Output>& grad_inputs,
70 std::vector<Output>* grad_outputs) {
71 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
72 return scope.status();
73}
74REGISTER_GRADIENT_OP("Identity", IdentityGrad);
75
76Status RefIdentityGrad(const Scope& scope, const Operation& op,
77 const std::vector<Output>& grad_inputs,
78 std::vector<Output>* grad_outputs) {
79 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
80 return scope.status();
81}
82REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);
83
84Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
85 const std::vector<Output>& grad_inputs,
86 std::vector<Output>* grad_outputs) {
87 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
88 return scope.status();
89}
90REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
91
// Gradient for QuantizeAndDequantizeV4, delegating to the dedicated
// QuantizeAndDequantizeV4Grad op, which produces backprops for the input
// tensor as well as for input_min/input_max.
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  // NOTE(review): `Shape(op.input(0))` (the shape vector) — not the input
  // tensor itself — is passed as the `input` argument below; confirm this
  // matches the QuantizeAndDequantizeV4Grad kernel's expectation.
  Input input = Shape(scope, op.input(0));
  Input input_min = op.input(1);
  Input input_max = op.input(2);
  int64_t axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
      scope, grad_inputs[0], input, input_min, input_max,
      QuantizeAndDequantizeV4Grad::Axis(axis));
  // One gradient per forward input: input, input_min, input_max.
  grad_outputs->push_back(qdq_v4_grad.input_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
                     QuantizeAndDequantizeV4GradHelper);
111
112Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
113 const std::vector<Output>& grad_inputs,
114 std::vector<Output>* grad_outputs) {
115 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
116 grad_outputs->push_back(NoGradient());
117 grad_outputs->push_back(NoGradient());
118 grad_outputs->push_back(NoGradient());
119 return scope.status();
120}
121REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);
122
123Status SplitGrad(const Scope& scope, const Operation& op,
124 const std::vector<Output>& grad_inputs,
125 std::vector<Output>* grad_outputs) {
126 grad_outputs->push_back(NoGradient());
127 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
128 return scope.status();
129}
130REGISTER_GRADIENT_OP("Split", SplitGrad);
131
132Status SplitVGrad(const Scope& scope, const Operation& op,
133 const std::vector<Output>& grad_inputs,
134 std::vector<Output>* grad_outputs) {
135 if (op.num_inputs() < 3) {
136 return errors::InvalidArgument("SplitV requires 3 arguments");
137 }
138 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2)));
139 for (int i = 0; i < op.num_inputs() - 1; ++i) {
140 grad_outputs->push_back(NoGradient());
141 }
142 return scope.status();
143}
144REGISTER_GRADIENT_OP("SplitV", SplitVGrad);
145
146Status FillGrad(const Scope& scope, const Operation& op,
147 const std::vector<Output>& grad_inputs,
148 std::vector<Output>* grad_outputs) {
149 // y = fill(fill_shape, x)
150 // No gradient returned for the fill_shape argument.
151 grad_outputs->push_back(NoGradient());
152 // The gradient for x (which must be a scalar) is just the sum of
153 // all the gradients from the shape it fills.
154 // We use ReduceSum to implement this, which needs an argument providing
155 // the indices of all the dimensions of the incoming gradient.
156 // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
157 auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
158 Const(scope, 1));
159 grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
160 return scope.status();
161}
162REGISTER_GRADIENT_OP("Fill", FillGrad);
163
164Status DiagGrad(const Scope& scope, const Operation& op,
165 const std::vector<Output>& grad_inputs,
166 std::vector<Output>* grad_outputs) {
167 grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
168 return scope.status();
169}
170REGISTER_GRADIENT_OP("Diag", DiagGrad);
171
172Status DiagPartGrad(const Scope& scope, const Operation& op,
173 const std::vector<Output>& grad_inputs,
174 std::vector<Output>* grad_outputs) {
175 grad_outputs->push_back(Diag(scope, grad_inputs[0]));
176 return scope.status();
177}
178REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);
179
180Status MatrixDiagGrad(const Scope& scope, const Operation& op,
181 const std::vector<Output>& grad_inputs,
182 std::vector<Output>* grad_outputs) {
183 grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
184 return scope.status();
185}
186REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);
187
188Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
189 const std::vector<Output>& grad_inputs,
190 std::vector<Output>* grad_outputs) {
191 auto num_lower = op.input(1);
192 auto num_upper = op.input(2);
193 grad_outputs->push_back(
194 MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
195 grad_outputs->push_back(NoGradient());
196 grad_outputs->push_back(NoGradient());
197 return scope.status();
198}
199REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);
200
201Status GatherNdGrad(const Scope& scope, const Operation& op,
202 const std::vector<Output>& grad_inputs,
203 std::vector<Output>* grad_outputs) {
204 auto ref = op.input(0);
205 auto ref_shape = Shape(scope, ref);
206 auto indices = op.input(1);
207 grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
208 grad_outputs->push_back(NoGradient());
209 return scope.status();
210}
211REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);
212
213Status CheckNumericsGrad(const Scope& scope, const Operation& op,
214 const std::vector<Output>& grad_inputs,
215 std::vector<Output>* grad_outputs) {
216 string message;
217 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
218 string err_msg = strings::StrCat(
219 "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
220 message);
221 grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
222 return scope.status();
223}
224REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);
225
226Status ReshapeGrad(const Scope& scope, const Operation& op,
227 const std::vector<Output>& grad_inputs,
228 std::vector<Output>* grad_outputs) {
229 auto input_shape = Shape(scope, op.input(0));
230 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
231 grad_outputs->push_back(NoGradient());
232 return scope.status();
233}
234REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);
235
236Status ExpandDimsGrad(const Scope& scope, const Operation& op,
237 const std::vector<Output>& grad_inputs,
238 std::vector<Output>* grad_outputs) {
239 auto input_shape = Shape(scope, op.input(0));
240 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
241 grad_outputs->push_back(NoGradient());
242 return scope.status();
243}
244REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);
245
246Status SqueezeGrad(const Scope& scope, const Operation& op,
247 const std::vector<Output>& grad_inputs,
248 std::vector<Output>* grad_outputs) {
249 auto input_shape = Shape(scope, op.input(0));
250 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
251 return scope.status();
252}
253REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);
254
255Status TransposeGrad(const Scope& scope, const Operation& op,
256 const std::vector<Output>& grad_inputs,
257 std::vector<Output>* grad_outputs) {
258 auto inverted_perm = InvertPermutation(scope, op.input(1));
259 grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
260 grad_outputs->push_back(NoGradient());
261 return scope.status();
262}
263REGISTER_GRADIENT_OP("Transpose", TransposeGrad);
264
265Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
266 const std::vector<Output>& grad_inputs,
267 std::vector<Output>* grad_outputs) {
268 auto seq_lengths = op.input(1);
269 int batch_dim;
270 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
271 int seq_dim;
272 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
273 grad_outputs->push_back(
274 ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
275 ReverseSequence::BatchDim(batch_dim)));
276 grad_outputs->push_back(NoGradient());
277 return scope.status();
278}
279REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);
280
281Status ReverseGrad(const Scope& scope, const Operation& op,
282 const std::vector<Output>& grad_inputs,
283 std::vector<Output>* grad_outputs) {
284 auto reverse_dims = op.input(1);
285 grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
286 grad_outputs->push_back(NoGradient());
287 return scope.status();
288}
289REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);
290
291Status ScatterNdGrad(const Scope& scope, const Operation& op,
292 const std::vector<Output>& grad_inputs,
293 std::vector<Output>* grad_outputs) {
294 auto indices = op.input(0);
295 grad_outputs->push_back(NoGradient());
296 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
297 grad_outputs->push_back(NoGradient());
298 return scope.status();
299}
300REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
301
302Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
303 const std::vector<Output>& grad_inputs,
304 std::vector<Output>* grad_outputs) {
305 auto indices = op.input(1);
306 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
307 grad_outputs->push_back(NoGradient());
308 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
309 return scope.status();
310}
311REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
312
313template <bool IsPadV2>
314Status PadGrad(const Scope& scope, const Operation& op,
315 const std::vector<Output>& grad_inputs,
316 std::vector<Output>* grad_outputs) {
317 auto x = op.input(0);
318 auto a = op.input(1); // [Rank(x), 2]
319 // Takes a slice of a. The 1st column. [Rank(x), 1].
320 auto size = Stack(scope, {Rank(scope, x), 1});
321 auto pad_before = Slice(scope, a, {0, 0}, size);
322 // Make it a 1-D tensor.
323 auto begin = Reshape(scope, pad_before, {-1});
324 grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
325 grad_outputs->push_back(NoGradient());
326 // PadV2 adds a "constant_values" input.
327 if (IsPadV2) {
328 grad_outputs->push_back(NoGradient());
329 }
330 return scope.status();
331}
332REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
333REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
334
335Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
336 const std::vector<Output>& grad_inputs,
337 std::vector<Output>* grad_outputs) {
338 int block_size;
339 TF_RETURN_IF_ERROR(
340 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
341 grad_outputs->push_back(
342 BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
343 grad_outputs->push_back(NoGradient());
344 return scope.status();
345}
346REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);
347
348Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
349 const std::vector<Output>& grad_inputs,
350 std::vector<Output>* grad_outputs) {
351 grad_outputs->push_back(
352 BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
353 grad_outputs->push_back(NoGradient());
354 grad_outputs->push_back(NoGradient());
355 return scope.status();
356}
357REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);
358
359Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
360 const std::vector<Output>& grad_inputs,
361 std::vector<Output>* grad_outputs) {
362 int block_size;
363 TF_RETURN_IF_ERROR(
364 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
365 grad_outputs->push_back(
366 SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
367 grad_outputs->push_back(NoGradient());
368 return scope.status();
369}
370REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);
371
372Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
373 const std::vector<Output>& grad_inputs,
374 std::vector<Output>* grad_outputs) {
375 grad_outputs->push_back(
376 SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
377 grad_outputs->push_back(NoGradient());
378 grad_outputs->push_back(NoGradient());
379 return scope.status();
380}
381REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);
382
383Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
384 const std::vector<Output>& grad_inputs,
385 std::vector<Output>* grad_outputs) {
386 int block_size;
387 TF_RETURN_IF_ERROR(
388 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
389 grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
390 return scope.status();
391}
392REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);
393
394Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
395 const std::vector<Output>& grad_inputs,
396 std::vector<Output>* grad_outputs) {
397 int block_size;
398 TF_RETURN_IF_ERROR(
399 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
400 grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
401 return scope.status();
402}
403REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);
404
405Status MirrorPadGrad(const Scope& scope, const Operation& op,
406 const std::vector<Output>& grad_inputs,
407 std::vector<Output>* grad_outputs) {
408 string mode;
409 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
410 grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
411 scope, grad_inputs[0], op.input(1), mode));
412 grad_outputs->push_back(NoGradient());
413 return scope.status();
414}
415REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);
416
417// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
418Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
419 const std::vector<Output>& grad_inputs,
420 std::vector<Output>* grad_outputs) {
421 string mode;
422 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
423 grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
424 grad_outputs->push_back(NoGradient());
425 return scope.status();
426}
427REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
428
// Gradient for StridedSlice: the dedicated StridedSliceGrad op scatters the
// incoming gradient into a zero tensor of the original input's shape, using
// the same begin/end/strides spec and masks as the forward op.
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  // The mask attrs must be forwarded unchanged so the backward op addresses
  // exactly the elements the forward op selected.
  int64_t begin_mask;
  int64_t end_mask;
  int64_t ellipsis_mask;
  int64_t new_axis_mask;
  int64_t shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for begin, end and strides
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
464
// Gradient for Slice: pad the incoming gradient with zeros so it lands back
// in the region of the input that the forward slice selected. Outputs are
// ordered (d_input, d_begin, d_size).
Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the shape of the input element-wise
  // subtracted by begin+size

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  // The slice size is recovered from the forward op's output shape, which
  // already accounts for size entries of -1.
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for "begin" and "size" inputs
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);
516
517Status ConcatGradHelper(const Scope& scope, const Operation& op,
518 const std::vector<Output>& grad_inputs,
519 std::vector<Output>* grad_outputs,
520 int start_value_index, int end_value_index,
521 int dim_index) {
522 if (end_value_index >= op.num_inputs()) {
523 return errors::Internal("Invalid input index");
524 }
525 std::vector<Output> inputs;
526 inputs.reserve(end_value_index - start_value_index);
527 for (int i = start_value_index; i < end_value_index; ++i) {
528 inputs.push_back(op.input(i));
529 }
530
531 auto shapes = ShapeN(scope, inputs);
532 const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset");
533 auto builder =
534 ::tensorflow::NodeBuilder(unique_name, "ConcatOffset")
535 .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
536 .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
537 scope.UpdateBuilder(&builder);
538 ::tensorflow::Node* concat_offset_node;
539 scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node));
540 scope.UpdateStatus(scope.DoShapeInference(concat_offset_node));
541 if (concat_offset_node->num_outputs() != inputs.size()) {
542 return errors::Internal("ConcatOffset has invalid output count");
543 }
544 if (grad_inputs.size() != 1) {
545 return errors::InvalidArgument("Concat grad should have 1 input");
546 }
547
548 // For each dx[i], we take a slice of dy. The offset and size of the
549 // slice is given by offset[i] and shape[i].
550 const Output& dy = grad_inputs[0];
551 for (int i = 0; i < inputs.size(); ++i) {
552 grad_outputs->push_back(
553 Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i]));
554 }
555
556 // Insert a NoGradient for the axis.
557 grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient());
558 return scope.status();
559}
560
Status ConcatV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // ConcatV2's inputs are (values..., axis): the values span
  // [0, num_inputs - 1) and the axis is the last input.
  return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
                          /*start_value_index=*/0,
                          /*end_value_index=*/op.num_inputs() - 1,
                          /*dim_index=*/op.num_inputs() - 1);
}

REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);
571
572Status BroadcastToGrad(const Scope& scope, const Operation& op,
573 const std::vector<Output>& grad_inputs,
574 std::vector<Output>* grad_outputs) {
575 if (grad_inputs.size() != 1) {
576 return errors::InvalidArgument("BroadcastTo grad should have 1 grad input");
577 }
578 if (op.num_inputs() != 2) {
579 return errors::InvalidArgument("BroadcastTo requires 2 inputs");
580 }
581
582 auto x_shape = Shape(scope, op.input(0));
583 auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1));
584 auto sum_gx = Sum(scope, grad_inputs[0], args.r0);
585 grad_outputs->push_back(Reshape(scope, sum_gx, x_shape));
586 grad_outputs->push_back(NoGradient());
587 return scope.status();
588}
589
590REGISTER_GRADIENT_OP("BroadcastTo", BroadcastToGrad);
591
// Gradient for Tile: each input element was replicated, so its gradient is
// the sum of the corresponding elements of the incoming gradient.
Status TileGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("Tile requires 2 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Tile grad requires 1 grad input");
  }

  // Query the input shape with the same integer type as the multiples input
  // so the two can be stacked together below.
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = op.input_type(1);
  auto input_shape = Shape(scope, op.input(0), shape_attrs);
  // We interleave multiples and input_shape to get split_shape,
  // reshape grad to split_shape, and reduce along all even
  // dimensions (the tiled dimensions) to get the result
  // with shape input_shape. For example
  //   input_shape = [20, 30, 40]
  //   multiples = [2, 3, 4]
  //   split_shape = [2, 20, 3, 30, 4, 40]
  //   axes = [0, 2, 4]
  auto stack = Stack(scope, {op.input(1), input_shape.output});
  // Reversing the axes of `stack` and flattening produces the interleaving
  // [multiples[0], input_shape[0], multiples[1], ...].
  auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
  auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
  // Even positions of split_shape are the tiled (multiples) dimensions.
  auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
  auto input_grad = ReduceSum(
      scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
  grad_outputs->push_back(input_grad.output);
  // The multiples input gets no gradient.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Tile", TileGrad);
624
625// Create a constant of the provided d_type;
626Output ConstHelper(const Scope& scope, int value, DataType d_type) {
627 return Cast(scope, Const(scope, value), d_type);
628}
629
// Adds the batch offsets to the given indices and returns the results.
// For each batch dimension (processed innermost-first), builds per-row
// offsets into the flattened parameter tensor and broadcast-adds them onto
// `indices`, so the flattened indices address the correct batch slices.
Output GetBatchIndices(const Scope& scope, const Output& params_shape,
                       const Output& indices, int batch_dims) {
  Output batch_indices = indices;
  auto indices_ndims = Rank(scope, indices);
  auto casted_params_shape = Cast(scope, params_shape, indices.type());
  // Running product of the dimension sizes inside the current batch dim.
  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
  for (int dim = batch_dims; dim > 0; dim--) {
    // Size of batch dimension dim-1.
    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
    accum_dim_value = Multiply(scope, accum_dim_value,
                               Slice(scope, casted_params_shape, {dim}, {1}));
    auto start = ConstHelper(scope, 0, indices.type());
    auto step = ConstHelper(scope, 1, indices.type());
    // [0, 1, ..., dim_value-1] scaled by the accumulated inner size.
    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
    // Shape [1, ..., 1, dim_value, 1, ..., 1] so the offsets broadcast
    // against `indices` along dimension dim-1.
    auto one = Cast(scope, Const(scope, {1}), indices.type());
    auto dim_shape = Concat(
        scope,
        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
         Output(Tile(scope, one,
                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
        /*axis=*/0);
    batch_indices =
        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
  }

  return batch_indices;
}
658
// Computes the params gradient of a (possibly batched) gather by
// scatter-adding `values` (the incoming gradient, with the gather axis
// already moved to the first non-batch dimension) into a tensor of shape
// `params_shape` via UnsortedSegmentSum. When batch_dims > 0, the batch
// dimensions are folded into the segment ids first and restored afterwards.
Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
                       Output indices, int batch_dims, Output gather_dim_size) {
  // Axis is the first non-batch dimension.
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  Output outer_shape, flat_values_shape;
  if (batch_dims != 0) {
    auto values_shape = Shape(scope, values);
    // Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
    auto inner_shape =
        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
    values = Reshape(scope, values, flat_values_shape);
  }

  indices = Reshape(scope, indices, indices_size);
  // Contributions from duplicate indices are summed into the same row.
  Output params_grad =
      UnsortedSegmentSum(scope, values, indices, gather_dim_size);

  if (batch_dims != 0) {
    // Put back the batch dimensions.
    params_grad = Reshape(scope, params_grad, params_shape);
  }
  return params_grad;
}
687
// Gradient for GatherV2. Transposes the gradient so the gather axis becomes
// the first non-batch dimension, scatter-adds it back into the params shape
// via BatchGatherGrad, then undoes the transpose. Outputs are ordered
// (d_params, d_indices=None, d_axis=None).
Status GatherV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("Gather requires 3 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Gather grad requires 1 grad input");
  }

  // params can be large, so colocate the shape calculation with it.
  // params can be very large for sparse model, array_ops.shape raises
  // exception on the Windows platform when any dimension is larger than
  // int32. params_shape is not used in optimizer apply_sparse gradients,
  // so it's fine to convert it back to int32 regardless of truncation.
  auto params = op.input(0);
  auto colocate_scope = scope.ColocateWith(params);
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = DT_INT64;
  auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
  Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);

  auto indices = op.input(1);
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  auto axis = op.input(2);
  auto axis_expand = ExpandDims(scope, axis, 0);

  int batch_dims;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "batch_dims", &batch_dims));
  if (batch_dims < 0) {
    // TODO(bdodson): Figure out if we can find the param rank here, like the
    // python implementation does.
    return errors::InvalidArgument(
        "C++ GatherV2 gradient does not support negative batch_dims.");
  }

  // Handle axis by transposing the axis dimension to be the first non-batch
  // dimension, compute the gradient and transpose the result back.
  auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
  auto inner_shape =
      Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
  // dy reshaped so the gathered dimension is a single -1 axis.
  auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
  auto values_dims = Size(scope, values_shape);
  auto axis_dims = Size(scope, outer_shape);

  Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
  Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
  Output inner_axes_indices =
      Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
  Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);

  auto values = Reshape(scope, grad_inputs[0], values_shape);

  // Move values[axis] up to values[batch_dims]
  Output transpose_dims = Concat(scope,
                                 {outer_batches_indices, axis_dims_expand,
                                  batch_axis_indices, inner_axes_indices},
                                 0);
  auto values_transpose = Transpose(scope, values, transpose_dims);
  // Size of the gathered dimension of params.
  Output gather_dim_size =
      Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
  // Permute params_shape the same way so BatchGatherGrad sees a consistent
  // target shape.
  params_shape = Gather(scope, params_shape, transpose_dims);

  auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
                                     indices, batch_dims, gather_dim_size);

  // Inverts the above transpose by moving dimension batch_dims back to its
  // original position.
  Output invert_transpose_dims = Concat(scope,
                                        {outer_batches_indices,
                                         Add(scope, batch_axis_indices, 1),
                                         {batch_dims},
                                         inner_axes_indices},
                                        0);

  params_grad = Transpose(scope, params_grad, invert_transpose_dims);

  grad_outputs->push_back(params_grad);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);
773
774} // anonymous namespace
775} // namespace ops
776} // namespace tensorflow
777