1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cmath>
17
18#include "tensorflow/cc/framework/grad_op_registry.h"
19#include "tensorflow/cc/framework/gradients.h"
20#include "tensorflow/cc/gradients/grad_helper.h"
21#include "tensorflow/cc/ops/array_ops.h"
22#include "tensorflow/cc/ops/array_ops_internal.h"
23#include "tensorflow/cc/ops/math_ops.h"
24#include "tensorflow/cc/ops/math_ops_internal.h"
25#include "tensorflow/cc/ops/standard_ops.h"
26
27namespace tensorflow {
28namespace ops {
29namespace {
30
31// Logical operations have no gradients.
32REGISTER_NO_GRADIENT_OP("Less");
33REGISTER_NO_GRADIENT_OP("LessEqual");
34REGISTER_NO_GRADIENT_OP("Greater");
35REGISTER_NO_GRADIENT_OP("GreaterEqual");
36REGISTER_NO_GRADIENT_OP("Equal");
37REGISTER_NO_GRADIENT_OP("ApproximateEqual");
38REGISTER_NO_GRADIENT_OP("NotEqual");
39REGISTER_NO_GRADIENT_OP("LogicalAnd");
40REGISTER_NO_GRADIENT_OP("LogicalOr");
41REGISTER_NO_GRADIENT_OP("LogicalNot");
42REGISTER_NO_GRADIENT_OP("Floor");
43
44// Conjugate helper function returns the conjugate of an Output if it
45// is complex valued.
46Output ConjugateHelper(const Scope& scope, const Output& out) {
47 DataType dtype = out.type();
48 if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
49 return Conj(scope, out);
50 } else {
51 return out;
52 }
53}
54
55// TODO(andydavis) Add control dependencies to gradient functions (as needed).
56
57Status AbsGrad(const Scope& scope, const Operation& op,
58 const std::vector<Output>& grad_inputs,
59 std::vector<Output>* grad_outputs) {
60 // dx = dy * sign(x)
61 grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
62 return scope.status();
63}
64REGISTER_GRADIENT_OP("Abs", AbsGrad);
65
66Status NegGrad(const Scope& scope, const Operation& op,
67 const std::vector<Output>& grad_inputs,
68 std::vector<Output>* grad_outputs) {
69 // dx = -dy;
70 grad_outputs->push_back(Neg(scope, grad_inputs[0]));
71 return scope.status();
72}
73REGISTER_GRADIENT_OP("Neg", NegGrad);
74
75Status InvGrad(const Scope& scope, const Operation& op,
76 const std::vector<Output>& grad_inputs,
77 std::vector<Output>* grad_outputs) {
78 // Use the built-in operator.
79 grad_outputs->push_back(
80 internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
81 return scope.status();
82}
83REGISTER_GRADIENT_OP("Inv", InvGrad);
84REGISTER_GRADIENT_OP("Reciprocal", InvGrad);
85
86Status SquareGrad(const Scope& scope, const Operation& op,
87 const std::vector<Output>& grad_inputs,
88 std::vector<Output>* grad_outputs) {
89 // dy/dx = (2 * x)
90 auto two = Cast(scope, Const(scope, 2), op.input(0).type());
91 auto dydx = Mul(scope, two, op.input(0));
92 // grad(x) = grad(y) * conj(dy/dx)
93 grad_outputs->push_back(
94 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
95 return scope.status();
96}
97REGISTER_GRADIENT_OP("Square", SquareGrad);
98
99Status SqrtGrad(const Scope& scope, const Operation& op,
100 const std::vector<Output>& grad_inputs,
101 std::vector<Output>* grad_outputs) {
102 // Use the built-in operator.
103 grad_outputs->push_back(
104 internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
105 return scope.status();
106}
107REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);
108
109Status RsqrtGrad(const Scope& scope, const Operation& op,
110 const std::vector<Output>& grad_inputs,
111 std::vector<Output>* grad_outputs) {
112 // Use the built-in operator.
113 grad_outputs->push_back(
114 internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
115 return scope.status();
116}
117REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
118
119Status ExpGrad(const Scope& scope, const Operation& op,
120 const std::vector<Output>& grad_inputs,
121 std::vector<Output>* grad_outputs) {
122 // dy/dx = exp(x) = y
123 // grad(x) = grad(y) * conj(dy/dx)
124 // = grad(y) * conj(y)
125 grad_outputs->push_back(
126 Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
127 return scope.status();
128}
129REGISTER_GRADIENT_OP("Exp", ExpGrad);
130
131Status Expm1Grad(const Scope& scope, const Operation& op,
132 const std::vector<Output>& grad_inputs,
133 std::vector<Output>* grad_outputs) {
134 // y = expm1(x)
135 // dy/dx = exp(x)
136 auto dydx = Exp(scope, op.input(0));
137 // grad(x) = grad(y) * conj(dy/dx)
138 grad_outputs->push_back(
139 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
140 return scope.status();
141}
142REGISTER_GRADIENT_OP("Expm1", Expm1Grad);
143
144Status LogGrad(const Scope& scope, const Operation& op,
145 const std::vector<Output>& grad_inputs,
146 std::vector<Output>* grad_outputs) {
147 // y = log(x)
148 // dy/dx = 1 / x
149 auto dydx = Reciprocal(scope, op.input(0));
150 // grad(x) = grad(y) * conj(dy/dx)
151 grad_outputs->push_back(
152 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
153 return scope.status();
154}
155REGISTER_GRADIENT_OP("Log", LogGrad);
156
157Status Log1pGrad(const Scope& scope, const Operation& op,
158 const std::vector<Output>& grad_inputs,
159 std::vector<Output>* grad_outputs) {
160 // y = log1p(x)
161 // dy/dx = 1 / (1 + x)
162 auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
163 auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
164 // grad(x) = grad(y) * conj(dy/dx)
165 grad_outputs->push_back(
166 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
167 return scope.status();
168}
169REGISTER_GRADIENT_OP("Log1p", Log1pGrad);
170
171Status SinhGrad(const Scope& scope, const Operation& op,
172 const std::vector<Output>& grad_inputs,
173 std::vector<Output>* grad_outputs) {
174 // y = sinh(x)
175 // dy/dx = cosh(x)
176 auto dydx = Cosh(scope, op.input(0));
177 // grad(x) = grad(y) * conj(dy/dx)
178 grad_outputs->push_back(
179 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
180 return scope.status();
181}
182REGISTER_GRADIENT_OP("Sinh", SinhGrad);
183
184Status CoshGrad(const Scope& scope, const Operation& op,
185 const std::vector<Output>& grad_inputs,
186 std::vector<Output>* grad_outputs) {
187 // y = cosh(x)
188 // dy/dx = sinh(x)
189 auto dydx = Sinh(scope, op.input(0));
190 // grad(x) = grad(y) * conj(dy/dx)
191 grad_outputs->push_back(
192 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
193 return scope.status();
194}
195REGISTER_GRADIENT_OP("Cosh", CoshGrad);
196
197Status TanhGrad(const Scope& scope, const Operation& op,
198 const std::vector<Output>& grad_inputs,
199 std::vector<Output>* grad_outputs) {
200 // Use the built-in operator.
201 // Note that the built-in operator does not return the conjugate of
202 // the gradient.
203 auto grad = grad_inputs[0];
204 // Optimization to avoid calculating conj(y) until the gradient is
205 // evaluated.
206 Scope grad_scope = scope.WithControlDependencies(grad);
207 auto y = ConjugateHelper(grad_scope, op.output(0));
208 grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
209 return grad_scope.status();
210}
211REGISTER_GRADIENT_OP("Tanh", TanhGrad);
212
213Status AsinhGrad(const Scope& scope, const Operation& op,
214 const std::vector<Output>& grad_inputs,
215 std::vector<Output>* grad_outputs) {
216 // y = asinh(x)
217 // dy/dx = 1 / cosh(y)
218 auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
219 // grad(x) = grad(y) * conj(dy/dx)
220 grad_outputs->push_back(
221 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
222 return scope.status();
223}
224REGISTER_GRADIENT_OP("Asinh", AsinhGrad);
225
226Status AcoshGrad(const Scope& scope, const Operation& op,
227 const std::vector<Output>& grad_inputs,
228 std::vector<Output>* grad_outputs) {
229 // y = acosh(x)
230 // dy/dx = 1 / sinh(y)
231 auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
232 // grad(x) = grad(y) * conj(dy/dx)
233 grad_outputs->push_back(
234 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
235 return scope.status();
236}
237REGISTER_GRADIENT_OP("Acosh", AcoshGrad);
238
239Status AtanhGrad(const Scope& scope, const Operation& op,
240 const std::vector<Output>& grad_inputs,
241 std::vector<Output>* grad_outputs) {
242 // y = atanh(x)
243 // dy/dx = 1 / (1 - x^2)
244 auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
245 auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
246 // grad(x) = grad(y) * conj(dy/dx)
247 grad_outputs->push_back(
248 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
249 return scope.status();
250}
251REGISTER_GRADIENT_OP("Atanh", AtanhGrad);
252
253Status SigmoidGrad(const Scope& scope, const Operation& op,
254 const std::vector<Output>& grad_inputs,
255 std::vector<Output>* grad_outputs) {
256 // Use the built-in operator.
257 // Note that the built-in operator does not return the conjugate of
258 // the gradient.
259 auto grad = grad_inputs[0];
260 // Optimization to avoid calculating conj(y) until the gradient is
261 // evaluated.
262 Scope grad_scope = scope.WithControlDependencies(grad);
263 auto y = ConjugateHelper(grad_scope, op.output(0));
264 grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
265 return grad_scope.status();
266}
267REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);
268
269Status SignGrad(const Scope& scope, const Operation& op,
270 const std::vector<Output>& grad_inputs,
271 std::vector<Output>* grad_outputs) {
272 auto shape = Shape(scope, op.input(0));
273 auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
274 auto dx = Fill(scope, shape, zero);
275 grad_outputs->push_back(dx);
276 return scope.status();
277}
278REGISTER_GRADIENT_OP("Sign", SignGrad);
279
280Status SinGrad(const Scope& scope, const Operation& op,
281 const std::vector<Output>& grad_inputs,
282 std::vector<Output>* grad_outputs) {
283 // y = sin(x)
284 // dy/dx = cos(x)
285 auto dydx = Cos(scope, op.input(0));
286 // grad(x) = grad(y) * conj(dy/dx)
287 grad_outputs->push_back(
288 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
289 return scope.status();
290}
291REGISTER_GRADIENT_OP("Sin", SinGrad);
292
293Status CosGrad(const Scope& scope, const Operation& op,
294 const std::vector<Output>& grad_inputs,
295 std::vector<Output>* grad_outputs) {
296 // y = cos(x)
297 // dy/dx = -sin(x)
298 auto dydx = Neg(scope, Sin(scope, op.input(0)));
299 // grad(x) = grad(y) * conj(dy/dx)
300 grad_outputs->push_back(
301 Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
302 return scope.status();
303}
304REGISTER_GRADIENT_OP("Cos", CosGrad);
305
306Status AsinGrad(const Scope& scope, const Operation& op,
307 const std::vector<Output>& grad_inputs,
308 std::vector<Output>* grad_outputs) {
309 // y = asin(x)
310 // dy/dx = 1 / sqrt(1 - x^2)
311 auto x2 = Square(scope, op.input(0));
312 auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
313 auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
314 // grad(x) = grad(y) * conj(dy/dx)
315 auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
316 grad_outputs->push_back(dx);
317 return scope.status();
318}
319REGISTER_GRADIENT_OP("Asin", AsinGrad);
320
321Status AcosGrad(const Scope& scope, const Operation& op,
322 const std::vector<Output>& grad_inputs,
323 std::vector<Output>* grad_outputs) {
324 // y = acos(x)
325 // dy/dx = - 1 / (1 - x * x)^1/2
326 // dx = dy * (- 1 / (1 - x * x)^1/2)
327 auto x2 = Square(scope, op.input(0));
328 auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
329 auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
330 auto dx = Mul(scope, grad_inputs[0], dydx);
331 grad_outputs->push_back(dx);
332 return scope.status();
333}
334REGISTER_GRADIENT_OP("Acos", AcosGrad);
335
336Status TanGrad(const Scope& scope, const Operation& op,
337 const std::vector<Output>& grad_inputs,
338 std::vector<Output>* grad_outputs) {
339 // y = tan(x)
340 // dy/dx = sec(x)^2 = 1 / cos(x)^2
341 auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
342 // grad(x) = grad(y) * conj(dy/dx)
343 auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
344 grad_outputs->push_back(dx);
345 return scope.status();
346}
347REGISTER_GRADIENT_OP("Tan", TanGrad);
348
349Status AtanGrad(const Scope& scope, const Operation& op,
350 const std::vector<Output>& grad_inputs,
351 std::vector<Output>* grad_outputs) {
352 // y = arctan(x)
353 // dy/dx = 1 / (1 + x^2)
354 // dx = dy * (1 / (1 + x^2)
355 auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
356 auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
357 auto dx = Mul(scope, grad_inputs[0], dydx);
358 grad_outputs->push_back(dx);
359 return scope.status();
360}
361REGISTER_GRADIENT_OP("Atan", AtanGrad);
362
363Status Atan2Grad(const Scope& scope, const Operation& op,
364 const std::vector<Output>& grad_inputs,
365 std::vector<Output>* grad_outputs) {
366 auto y = op.input(0);
367 auto x = op.input(1);
368 Output grad_inv = Div(scope, grad_inputs[0],
369 Add(scope, Square(scope, x), Square(scope, y)));
370 grad_outputs->push_back(Mul(scope, x, grad_inv));
371 grad_outputs->push_back(Mul(scope, Neg(scope, y), grad_inv));
372 return scope.status();
373}
374REGISTER_GRADIENT_OP("Atan2", Atan2Grad);
375
376// BinaryGradCommon handles the setup for binary ops that broadcast
377// their inputs.
378Status BinaryGradCommon(const Scope& scope, const Operation& op,
379 std::vector<Output>* grad_outputs, const Output& gx_1,
380 const Output& gx_2) {
381 auto sx_1 = Shape(scope, op.input(0));
382 auto sx_2 = Shape(scope, op.input(1));
383 auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
384 auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
385 auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
386 grad_outputs->push_back(dx_1);
387 grad_outputs->push_back(dx_2);
388 return scope.status();
389}
390
391Status AddGrad(const Scope& scope, const Operation& op,
392 const std::vector<Output>& grad_inputs,
393 std::vector<Output>* grad_outputs) {
394 // y = x_1 + x_2
395 // dy/dx_1 = dy/dx_2 = 1
396 auto gx_1 = Identity(scope, grad_inputs[0]);
397 auto gx_2 = Identity(scope, grad_inputs[0]);
398 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
399}
400REGISTER_GRADIENT_OP("Add", AddGrad);
401REGISTER_GRADIENT_OP("AddV2", AddGrad);
402
403Status SubGrad(const Scope& scope, const Operation& op,
404 const std::vector<Output>& grad_inputs,
405 std::vector<Output>* grad_outputs) {
406 // y = x_1 - x_2
407 // dy/dx_1 = 1
408 // dy/dx_2 = -1
409 auto gx_1 = Identity(scope, grad_inputs[0]);
410 auto gx_2 = Neg(scope, grad_inputs[0]);
411 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
412}
413REGISTER_GRADIENT_OP("Sub", SubGrad);
414
415Status MulGrad(const Scope& scope, const Operation& op,
416 const std::vector<Output>& grad_inputs,
417 std::vector<Output>* grad_outputs) {
418 auto x_1 = ConjugateHelper(scope, op.input(0));
419 auto x_2 = ConjugateHelper(scope, op.input(1));
420 // y = x_1 * x_2
421 // dy/dx_1 = x_2
422 // dy/dx_2 = x_1
423 auto gx_1 = Mul(scope, grad_inputs[0], x_2);
424 auto gx_2 = Mul(scope, grad_inputs[0], x_1);
425 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
426}
427REGISTER_GRADIENT_OP("Mul", MulGrad);
428
429Status DivGrad(const Scope& scope, const Operation& op,
430 const std::vector<Output>& grad_inputs,
431 std::vector<Output>* grad_outputs) {
432 auto x_1 = ConjugateHelper(scope, op.input(0));
433 auto x_2 = ConjugateHelper(scope, op.input(1));
434 // y = x_1 / x_2
435 // dy/dx_1 = 1/x_2
436 // dy/dx_2 = -x_1/x_2^2
437 auto gx_1 = Div(scope, grad_inputs[0], x_2);
438 auto gx_2 = Mul(scope, grad_inputs[0],
439 Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
440 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
441}
442REGISTER_GRADIENT_OP("Div", DivGrad);
443
444Status RealDivGrad(const Scope& scope, const Operation& op,
445 const std::vector<Output>& grad_inputs,
446 std::vector<Output>* grad_outputs) {
447 auto x_1 = ConjugateHelper(scope, op.input(0));
448 auto x_2 = ConjugateHelper(scope, op.input(1));
449 // y = x_1 / x_2
450 // dy/dx_1 = 1/x_2
451 // dy/dx_2 = -x_1/x_2^2
452 auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
453 auto gx_2 = Mul(scope, grad_inputs[0],
454 RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
455 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
456}
457REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
458
459Status DivNoNanGrad(const Scope& scope, const Operation& op,
460 const std::vector<Output>& grad_inputs,
461 std::vector<Output>* grad_outputs) {
462 auto x_1 = ConjugateHelper(scope, op.input(0));
463 auto x_2 = ConjugateHelper(scope, op.input(1));
464 // y = x_1 / x_2
465 // dy/dx_1 = 1/x_2
466 // dy/dx_2 = -x_1/x_2^2
467 auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
468 auto gx_2 = Mul(scope, grad_inputs[0],
469 DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
470 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
471}
472REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
473
474Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
475 const std::vector<Output>& grad_inputs,
476 std::vector<Output>* grad_outputs) {
477 auto x_1 = ConjugateHelper(scope, op.input(0));
478 auto x_2 = ConjugateHelper(scope, op.input(1));
479 // y = (x_1 - x_2)^2
480 // dy/dx_1 = 2 * (x_1 - x_2)
481 // dy/dx_2 = -2 * (x_1 - x_2)
482 auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
483 auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
484 auto gx_2 = Neg(scope, gx_1);
485 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
486}
487REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);
488
489Status AddNGrad(const Scope& scope, const Operation& op,
490 const std::vector<Output>& grad_inputs,
491 std::vector<Output>* grad_outputs) {
492 // AddN doesn't support broadcasting, so all the inputs must be the
493 // same shape.
494 // Note:
495 // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
496 // hence dx_k = dy for all x_k
497 // So the gradient for AddN just transfers the incoming gradient to
498 // all outgoing gradients.
499 auto incoming = Identity(scope, grad_inputs[0]);
500 for (int32_t i = 0; i < op.num_inputs(); ++i) {
501 grad_outputs->push_back(incoming);
502 }
503 return scope.status();
504}
505REGISTER_GRADIENT_OP("AddN", AddNGrad);
506
507Status PowGrad(const Scope& scope, const Operation& op,
508 const std::vector<Output>& grad_inputs,
509 std::vector<Output>* grad_outputs) {
510 auto x = ConjugateHelper(scope, op.input(0));
511 auto y = ConjugateHelper(scope, op.input(1));
512 auto z = ConjugateHelper(scope, op.output(0));
513 auto grad = grad_inputs[0];
514 // grad * y * pow(x, y - 1)
515 auto one = Cast(scope, Const(scope, 1.0), y.type());
516 auto gx_1 =
517 Mul(scope, Mul(scope, grad, y), Pow(scope, x, Sub(scope, y, one)));
518 // Avoid false singularity at x = 0
519 DataType x_dtype = x.type();
520 auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
521 if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
522 // real(x) < 0 is fine for the complex case
523 auto log_x = Where3(scope, NotEqual(scope, x, zero), Log(scope, x),
524 ZerosLike(scope, x));
525 auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
526 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
527 } else {
528 // There's no sensible real value to return if x < 0, so return 0
529 auto log_x = Where3(scope, Greater(scope, x, zero), Log(scope, x),
530 ZerosLike(scope, x));
531 auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
532 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
533 }
534}
535REGISTER_GRADIENT_OP("Pow", PowGrad);
536
537// MaximumMinimumGradCommon adds shared ops to calculate gradients for
538// the binary Maximum and Minimum ops.
539Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
540 const std::vector<Output>& grad_inputs,
541 std::vector<Output>* grad_outputs,
542 const Output& comparator) {
543 // comparator is a boolean tensor, with
544 // y = x_1 at points where comparator is true, and x_2 otherwise
545 // Therefore
546 // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
547 // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
548 auto grad = grad_inputs[0];
549 auto zeros = ZerosLike(scope, grad);
550 auto gx_1 = Where3(scope, comparator, grad, zeros);
551 auto gx_2 = Where3(scope, comparator, zeros, grad);
552 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
553}
554
555Status MaximumGrad(const Scope& scope, const Operation& op,
556 const std::vector<Output>& grad_inputs,
557 std::vector<Output>* grad_outputs) {
558 auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
559 return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
560 comparator);
561}
562REGISTER_GRADIENT_OP("Maximum", MaximumGrad);
563
564Status MinimumGrad(const Scope& scope, const Operation& op,
565 const std::vector<Output>& grad_inputs,
566 std::vector<Output>* grad_outputs) {
567 auto comparator = LessEqual(scope, op.input(0), op.input(1));
568 return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
569 comparator);
570}
571REGISTER_GRADIENT_OP("Minimum", MinimumGrad);
572
573Status RealGrad(const Scope& scope, const Operation& op,
574 const std::vector<Output>& grad_inputs,
575 std::vector<Output>* grad_outputs) {
576 auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
577 auto dx = Complex(scope, grad_inputs[0], zero);
578 grad_outputs->push_back(dx);
579 return scope.status();
580}
581REGISTER_GRADIENT_OP("Real", RealGrad);
582
583Status ImagGrad(const Scope& scope, const Operation& op,
584 const std::vector<Output>& grad_inputs,
585 std::vector<Output>* grad_outputs) {
586 auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
587 auto dx = Complex(scope, zero, grad_inputs[0]);
588 grad_outputs->push_back(dx);
589 return scope.status();
590}
591REGISTER_GRADIENT_OP("Imag", ImagGrad);
592
593Status ComplexGrad(const Scope& scope, const Operation& op,
594 const std::vector<Output>& grad_inputs,
595 std::vector<Output>* grad_outputs) {
596 auto gx_1 = Real(scope, grad_inputs[0]);
597 auto gx_2 = Imag(scope, grad_inputs[0]);
598 return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
599}
600REGISTER_GRADIENT_OP("Complex", ComplexGrad);
601
602Status AngleGrad(const Scope& scope, const Operation& op,
603 const std::vector<Output>& grad_inputs,
604 std::vector<Output>* grad_outputs) {
605 // y = Angle(x)
606 // dx = -dy / (Im(x) + iRe(x)) = -dy * z
607 auto re = Real(scope, op.input(0));
608 auto im = Imag(scope, op.input(0));
609 auto z_inv = Reciprocal(scope, Complex(scope, im, re));
610 auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
611 auto grad = Complex(scope, grad_inputs[0], zero);
612 auto dx = Neg(scope, Mul(scope, grad, z_inv));
613 grad_outputs->push_back(dx);
614 return scope.status();
615}
616REGISTER_GRADIENT_OP("Angle", AngleGrad);
617
618Status ConjGrad(const Scope& scope, const Operation& op,
619 const std::vector<Output>& grad_inputs,
620 std::vector<Output>* grad_outputs) {
621 grad_outputs->push_back(Conj(scope, grad_inputs[0]));
622 return scope.status();
623}
624REGISTER_GRADIENT_OP("Conj", ConjGrad);
625
626// Integer division x / y, assuming x and y >=0, but treats x/0 = x
627Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
628 return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
629}
630
631// SumGradHelper returns the gradient for the Sum operator, and is used
632// by SumGrad and MeanGrad.
633Output SumGradHelper(const Scope& scope, const Operation& op,
634 const std::vector<Output>& grad_inputs) {
635 // The partial derivative for any input along a "reduced" dimension
636 // is just 1, so we only need replicate the output gradient on such a
637 // dimension to its "expanded" shape.
638 // Running example:
639 // input is
640 // [[a, b, c],
641 // [d, e, f]]
642 // reduction_indices = [1]
643 // Sum = [a + b + c, d + e + f]
644 // if the gradient is [g1, g2]
645 // We want the propagated gradient to be
646 // [[g1, g1, g1],
647 // [g2, g2, g2]]
648
649 // input_shape = [2, 3]
650 auto input_shape = Shape(scope, op.input(0));
651
652 // output_shape_kept_dims = [2, 1]
653 auto output_shape_kept_dims =
654 ReducedShapeHelper(scope, input_shape, op.input(1));
655
656 // This step "flips" any 1s with values from the input_shape, and
657 // replaces remaining entries with 1. This creates a shape that
658 // shows how much each dimension in the incoming gradient should be
659 // replicated.
660 // tile_scaling = [1, 3]
661 auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);
662
663 // grad = [[g1], [g2]]
664 auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
665
666 // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
667 return Tile(scope, grad, tile_scaling);
668}
669
670Status SumGrad(const Scope& scope, const Operation& op,
671 const std::vector<Output>& grad_inputs,
672 std::vector<Output>* grad_outputs) {
673 grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));
674
675 // Stop propagation along reduction_indices
676 grad_outputs->push_back(NoGradient());
677 return scope.status();
678}
679REGISTER_GRADIENT_OP("Sum", SumGrad);
680
681Status MeanGrad(const Scope& scope, const Operation& op,
682 const std::vector<Output>& grad_inputs,
683 std::vector<Output>* grad_outputs) {
684 // The Mean gradient is just like the Sum gradient, except that
685 // all gradients are also divided by the size of reduced groups.
686 auto sum_grad = SumGradHelper(scope, op, grad_inputs);
687
688 // The product of all entries in a tensor's shape is the total
689 // number of entries in the tensor. This step calculates
690 // n_input_entries/n_output_entries
691 // = group_size
692 auto input_shape = Shape(scope, op.input(0));
693 auto output_shape = Shape(scope, op.output(0));
694 auto zero = Const(scope, 0);
695 auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
696 Prod(scope, output_shape, zero));
697
698 // propagate sum_grad/group_size
699 grad_outputs->push_back(
700 Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));
701
702 // Stop propagation along reduction_indices
703 grad_outputs->push_back(NoGradient());
704 return scope.status();
705}
706REGISTER_GRADIENT_OP("Mean", MeanGrad);
707
708Status ErfGrad(const Scope& scope, const Operation& op,
709 const std::vector<Output>& grad_inputs,
710 std::vector<Output>* grad_outputs) {
711 auto grad = grad_inputs[0];
712 auto two_over_root_pi =
713 Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type());
714 Scope grad_scope = scope.WithControlDependencies(grad);
715 auto x = ConjugateHelper(grad_scope, op.input(0));
716 // grad * 2/sqrt(pi) * exp(-x**2)
717 auto dx = Mul(grad_scope, Mul(grad_scope, grad, two_over_root_pi),
718 Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
719 grad_outputs->push_back(dx);
720 return grad_scope.status();
721}
722REGISTER_GRADIENT_OP("Erf", ErfGrad);
723
724Status ErfinvGrad(const Scope& scope, const Operation& op,
725 const std::vector<Output>& grad_inputs,
726 std::vector<Output>* grad_outputs) {
727 auto grad = grad_inputs[0];
728 auto root_pi_over_two =
729 Cast(scope, Const(scope, std::sqrt(M_PI) / 2), grad.type());
730 Scope grad_scope = scope.WithControlDependencies(grad);
731 auto x = ConjugateHelper(grad_scope, op.input(0));
732 // grad * sqrt(pi) / 2 * exp(erfinv(x) ** 2)
733 auto dx = Mul(grad_scope, Mul(grad_scope, grad, root_pi_over_two),
734 Exp(grad_scope, Square(grad_scope, op.output(0))));
735 grad_outputs->push_back(dx);
736 return grad_scope.status();
737}
738REGISTER_GRADIENT_OP("Erfinv", ErfinvGrad);
739
740Status NdtriGrad(const Scope& scope, const Operation& op,
741 const std::vector<Output>& grad_inputs,
742 std::vector<Output>* grad_outputs) {
743 auto grad = grad_inputs[0];
744 auto root_two_pi =
745 Cast(scope, Const(scope, std::sqrt(2 * M_PI)), grad.type());
746 auto two = Cast(scope, Const(scope, 2), grad.type());
747 Scope grad_scope = scope.WithControlDependencies(grad);
748 auto x = ConjugateHelper(grad_scope, op.input(0));
749 // grad * sqrt(2 * pi) * exp(ndtri(x) ** 2 / 2)
750 auto dx = Mul(
751 grad_scope, Mul(grad_scope, grad, root_two_pi),
752 Exp(grad_scope, Div(grad_scope, Square(grad_scope, op.output(0)), two)));
753 grad_outputs->push_back(dx);
754 return grad_scope.status();
755}
756REGISTER_GRADIENT_OP("Ndtri", NdtriGrad);
757
758Status LgammaGrad(const Scope& scope, const Operation& op,
759 const std::vector<Output>& grad_inputs,
760 std::vector<Output>* grad_outputs) {
761 auto grad = grad_inputs[0];
762 Scope grad_scope = scope.WithControlDependencies(grad);
763 auto x = ConjugateHelper(grad_scope, op.input(0));
764 auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
765 grad_outputs->push_back(dx);
766 return grad_scope.status();
767}
768REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
769
770Status MinOrMaxGrad(const Scope& scope, const Operation& op,
771 const std::vector<Output>& grad_inputs,
772 std::vector<Output>* grad_outputs) {
773 // The partial derivative for any input along a "reduced" dimension
774 // is 1 when it is the min (or max) and 0 everywhere else. So the
775 // gradient calculation is identical for both operators.
776 //
777 // There's a special case for propagating gradients when there are
778 // multiple minima (or maxima) - we choose to divide the gradient
779 // equally among all matching inputs.
780 //
781 // Please note this comment
782 // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
783 // for details.
784
785 // Running example:
786 // input: [[5, 5, 5],
787 // [1, 2, -3]]
788 // reduction_indices: [1]
789 auto input = op.input(0);
790 auto reduction_indices = op.input(1);
791
792 // [2, 3]
793 auto input_shape = Shape(scope, input);
794
795 // [2, 1]
796 auto output_shape_kept_dims =
797 ReducedShapeHelper(scope, input_shape, reduction_indices);
798
799 // for op=min (say)
800 // output = [5, -3]
801 // y = [[5],
802 // [-3]]
803 auto y = Reshape(scope, op.output(0), output_shape_kept_dims);
804
805 // reshape([g1, g2], [2, 1]) = [[g1],
806 // [g2]]
807 auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
808
809 // indicators = equal(y, input)
810 // = equal([[5], [[5, 5, 5],
811 // [-3]], [1, 2, -3]])
812 // = [[1, 1, 1],
813 // [0, 0, 1]]
814 auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());
815
816 // [[3],
817 // [1]]
818 auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
819 output_shape_kept_dims);
820
821 // [[1/3, 1/3, 1/3],
822 // [0, 0, 1]]
823 auto scale = Div(scope, indicators, num_selected);
824
825 // [[g1/3, g1/3, g1/3],
826 // [0, 0, g2]]
827 grad_outputs->push_back(Mul(scope, scale, grad));
828
829 // Stop propagation along reduction_indices
830 grad_outputs->push_back(NoGradient());
831 return scope.status();
832}
833REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
834REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
835
836Status ProdGrad(const Scope& scope, const Operation& op,
837 const std::vector<Output>& grad_inputs,
838 std::vector<Output>* grad_outputs) {
839 auto zero = Const(scope, 0);
840 auto one = Const(scope, 1);
841
842 // The gradient can be expressed by dividing the product by each entry of
843 // the input tensor. If our input is
844 // [
845 // [3, 4],
846 // [5, 6],
847 // [7, 8]
848 // ]
849 // and we do a Prod operation on the axis 1, we will obtain [[105, 192]].
850 // The gradient will have the same shape as the input
851 // [
852 // [105/3, 192/4],
853 // dz * [105/5, 192/6],
854 // [105/7, 192/6]
855 // ]
856 // If the input contains a zero, the division is impossible but
857 // if we take the calculation that gave the first gradient
858 // (3 * 5 * 6)/3 is equal to 5 * 6
859 // the trick will be to cumprod the elements on the axis without
860 // the element at the current position (3 in the example above).
861 // We will take as example:
862 // [
863 // [
864 // [3.0, 4.0],
865 // [5.0, 6.0],
866 // [7.0, 8.0]
867 // ],
868 // [
869 // [3.0, 5.0],
870 // [0.0, 6.0],
871 // [5.0, 6.0]
872 // ]
873 // ]
874
875 // [2, 3, 2]
876 auto input_shape = Shape(scope, op.input(0));
877
878 // The Reshape with -1 flattens the reduction indices.
879 // [1]
880 auto reduction_indices = Reshape(scope, op.input(1), {-1});
881
882 // [2, 1, 2]
883 auto output_shape_kept_dims =
884 ReducedShapeHelper(scope, input_shape, reduction_indices);
885
886 // [1, 3, 1]
887 auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);
888
889 // [[[105, 192]], [[0, 180]]]
890 auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
891
892 // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
893 auto grad_tiled = Tile(scope, grad, tile_scaling);
894
895 Scope cpu_scope = scope.WithDevice("/cpu:0");
896
897 // [3]
898 auto rank = Rank(cpu_scope, op.input(0));
899
900 // Normalize any negative indices in the reduction_axes to positive values.
901 auto reduction_indices_pos =
902 Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);
903
904 // [1]
905 auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);
906
907 // [0, 1, 2]
908 auto idx = Range(cpu_scope, zero, rank, one);
909
910 // [0, 2]
911 auto other = SetDiff1D(cpu_scope, idx, reduced).out;
912
913 // [1, 0, 2]
914 auto perm =
915 Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);
916
917 // 3 => [3]
918 auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);
919
920 // 2 * 2 => [2]
921 auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);
922
923 // [
924 // [
925 // [ 3., 4.],
926 // [ 3., 5.]
927 // ],
928 // [
929 // [ 5., 6.],
930 // [ 0., 6.]
931 // ],
932 // [
933 // [ 7., 8.],
934 // [ 5., 6.]
935 // ]
936 // ]
937 auto permuted = Transpose(scope, op.input(0), perm);
938
939 // [3, 2, 2]
940 auto permuted_shape = Shape(scope, permuted);
941
942 // [
943 // [ 3., 4., 3., 5.],
944 // [ 5., 6., 0., 6.],
945 // [ 7., 8., 5., 6.]
946 // ]
947 auto reshaped = Reshape(
948 scope, permuted,
949 Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));
950
951 // [
952 // [ 1., 1., 1., 1.],
953 // [ 3., 4., 3., 5.],
954 // [ 15., 24., 0., 30.]
955 // ]
956 auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));
957
958 // [
959 // [ 35., 48., 0., 36.],
960 // [ 7., 8., 5., 6.],
961 // [ 1., 1., 1., 1.]
962 // ]
963 auto right =
964 Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));
965
966 // left * right =
967 // [
968 // [ 35., 48., 0., 36.],
969 // [ 21., 32., 15., 30.],
970 // [ 15., 24., 0., 30.]
971 // ]
972 // y =
973 // [
974 // [
975 // [ 35., 48.],
976 // [ 0., 36.]
977 // ],
978 // [
979 // [ 21., 32.],
980 // [ 15., 30.]
981 // ],
982 // [
983 // [ 15., 24.],
984 // [ 0., 30.]
985 // ]
986 // ]
987 auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);
988
989 // out =
990 // [
991 // [
992 // [ 35., 48.],
993 // [ 21., 32.],
994 // [ 15., 24.]
995 // ],
996 // [
997 // [ 0., 36.],
998 // [ 15., 30.],
999 // [ 0., 30.]
1000 // ]
1001 // ]
1002 auto out = Mul(scope, grad_tiled,
1003 Transpose(scope, y, InvertPermutation(scope, perm)));
1004
1005 grad_outputs->push_back(Reshape(scope, out, input_shape));
1006
1007 // stop propagation along reduction_indices
1008 grad_outputs->push_back(NoGradient());
1009 return scope.status();
1010}
1011REGISTER_GRADIENT_OP("Prod", ProdGrad);
1012
1013Status SegmentSumGrad(const Scope& scope, const Operation& op,
1014 const std::vector<Output>& grad_inputs,
1015 std::vector<Output>* grad_outputs) {
1016 // The SegmentSum operation sums segments of the Tensor that have the same
1017 // index in the segment_ids parameter.
1018 // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
1019 // will produce [2 + 3 + 4, 5] = [9, 5]
1020 // The gradient that will flow back to the gather operation will look like
1021 // [x1, x2], it will have the same shape as the output of the SegmentSum
1022 // operation. The differentiation step of the SegmentSum operation just
1023 // broadcast the gradient in order to retrieve the z's shape.
1024 // dy/dz = [x1, x1, x1, x2]
1025 grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));
1026
1027 // stop propagation along segment_ids
1028 grad_outputs->push_back(NoGradient());
1029 return scope.status();
1030}
1031REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);
1032
1033// MatMulGrad helper function used to compute two MatMul operations
1034// based on input matrix transposition combinations.
1035Status MatMulGradHelper(const Scope& scope, const bool is_batch,
1036 const Output& x0, const bool adj_x0, const Output& x1,
1037 const bool adj_x1, const Output& y0, const bool adj_y0,
1038 const Output& y1, const bool adj_y1,
1039 std::vector<Output>* grad_outputs) {
1040 if (is_batch == false) {
1041 auto dx =
1042 MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
1043 grad_outputs->push_back(dx);
1044 auto dy =
1045 MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
1046 grad_outputs->push_back(dy);
1047 } else {
1048 auto dx =
1049 BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
1050 grad_outputs->push_back(dx);
1051 auto dy =
1052 BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
1053 grad_outputs->push_back(dy);
1054 }
1055 return scope.status();
1056}
1057
1058// MatMulGrad common used to read and check node attr state, and determine
1059// proper MatMul products for gradients based on input matrix transposition
1060// combinations.
1061Status MatMulGradCommon(const Scope& scope, const Operation& op,
1062 const bool is_batch,
1063 const std::vector<Output>& grad_inputs,
1064 const string& attr_adj_x, const string& attr_adj_y,
1065 std::vector<Output>* grad_outputs) {
1066 auto a = op.input(0);
1067 auto b = op.input(1);
1068 // Use conjugate of the inputs for MatMul
1069 if (is_batch == false) {
1070 a = ConjugateHelper(scope, a);
1071 b = ConjugateHelper(scope, b);
1072 }
1073 auto product = op.output(0);
1074
1075 bool ta;
1076 bool tb;
1077 TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
1078 TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));
1079
1080 if (!ta && !tb) {
1081 return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
1082 true, grad_inputs[0], false, grad_outputs);
1083 } else if (!ta && tb) {
1084 return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
1085 grad_inputs[0], true, a, false, grad_outputs);
1086 } else if (ta && !tb) {
1087 return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
1088 false, grad_inputs[0], false, grad_outputs);
1089 }
1090 return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
1091 grad_inputs[0], true, a, true, grad_outputs);
1092}
1093
1094Status MatMulGrad(const Scope& scope, const Operation& op,
1095 const std::vector<Output>& grad_inputs,
1096 std::vector<Output>* grad_outputs) {
1097 return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
1098 "transpose_b", grad_outputs);
1099}
1100REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
1101
1102Status BatchMatMulGrad(const Scope& scope, const Operation& op,
1103 const std::vector<Output>& grad_inputs,
1104 std::vector<Output>* grad_outputs) {
1105 return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
1106 grad_outputs);
1107}
1108REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);
1109
1110Status CumsumGrad(const Scope& scope, const Operation& op,
1111 const std::vector<Output>& grad_inputs,
1112 std::vector<Output>* grad_outputs) {
1113 if (op.num_inputs() != 2) {
1114 return errors::InvalidArgument("Cumsum requires 2 arguments");
1115 }
1116 if (grad_inputs.size() != 1) {
1117 return errors::InvalidArgument("Cumsum grad requires 1 grad input");
1118 }
1119
1120 Cumsum::Attrs attrs;
1121 TF_RETURN_IF_ERROR(
1122 GetNodeAttr(op.node()->attrs(), "exclusive", &attrs.exclusive_));
1123 bool reverse;
1124 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "reverse", &reverse));
1125 attrs.reverse_ = !reverse;
1126
1127 auto axis = op.input(1);
1128 auto sum = Cumsum(scope, grad_inputs[0], axis, attrs);
1129 grad_outputs->push_back(sum.out);
1130 grad_outputs->push_back(NoGradient());
1131 return scope.status();
1132}
1133REGISTER_GRADIENT_OP("Cumsum", CumsumGrad);
1134
1135bool IsFloatingPointDtype(DataType dtype) {
1136 static constexpr DataType valid_dtypes[] = {
1137 DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128};
1138 return std::find(std::begin(valid_dtypes), std::end(valid_dtypes), dtype) !=
1139 std::end(valid_dtypes);
1140}
1141
1142Status CastGrad(const Scope& scope, const Operation& op,
1143 const std::vector<Output>& grad_inputs,
1144 std::vector<Output>* grad_outputs) {
1145 if (op.num_inputs() != 1) {
1146 return errors::InvalidArgument("Cast requires 2 arguments");
1147 }
1148 if (grad_inputs.size() != 1) {
1149 return errors::InvalidArgument("Cast grad requires 1 grad input");
1150 }
1151
1152 auto src_type = op.input_type(0);
1153 auto dst_type = grad_inputs[0].type();
1154 if (IsFloatingPointDtype(src_type) && IsFloatingPointDtype(dst_type)) {
1155 grad_outputs->push_back(Cast(scope, grad_inputs[0], src_type));
1156 } else {
1157 grad_outputs->push_back(NoGradient());
1158 }
1159 return scope.status();
1160}
1161REGISTER_GRADIENT_OP("Cast", CastGrad);
1162
1163Status SelectGrad(const Scope& scope, const Operation& op,
1164 const std::vector<Output>& grad_inputs,
1165 std::vector<Output>* grad_outputs) {
1166 if (op.num_inputs() != 3) {
1167 return errors::InvalidArgument("Select requires 3 arguments");
1168 }
1169 if (grad_inputs.size() != 1) {
1170 return errors::InvalidArgument("Select grad requires 1 grad input");
1171 }
1172
1173 auto c = op.input(0);
1174 auto zeros = ZerosLike(scope, grad_inputs[0]);
1175 grad_outputs->push_back(NoGradient()); // Condition
1176 grad_outputs->push_back(Where3(scope, c, grad_inputs[0], zeros));
1177 grad_outputs->push_back(Where3(scope, c, zeros, grad_inputs[0]));
1178 return scope.status();
1179}
1180REGISTER_GRADIENT_OP("Select", SelectGrad);
1181
1182Status SelectV2Grad(const Scope& scope, const Operation& op,
1183 const std::vector<Output>& grad_inputs,
1184 std::vector<Output>* grad_outputs) {
1185 if (op.num_inputs() != 3) {
1186 return errors::InvalidArgument("Select requires 3 arguments");
1187 }
1188
1189 if (grad_inputs.size() != 1) {
1190 return errors::InvalidArgument("Select grad requires 1 grad input");
1191 }
1192
1193 auto c = op.input(0);
1194 auto x = op.input(1);
1195 auto y = op.input(2);
1196
1197 auto zeros = ZerosLike(scope, grad_inputs[0]);
1198 auto gx = SelectV2(scope, c, grad_inputs[0], zeros);
1199 auto x_shape = Shape(scope, x);
1200 auto output_shape = Shape(scope, op.output(0));
1201
1202 // Reduce away broadcasted leading dims.
1203 auto reduce_x = internal::BroadcastGradientArgs(scope, x_shape, output_shape);
1204 auto gx_sum =
1205 ReduceSum(scope, gx, /*axis=*/reduce_x.r0, ReduceSum::KeepDims(true));
1206 auto gx_sum_reshape = Reshape(scope, gx_sum, x_shape);
1207
1208 auto gy = SelectV2(scope, c, zeros, grad_inputs[0]);
1209 auto y_shape = Shape(scope, y);
1210
1211 // Reduce away broadcasted leading dims.
1212 auto reduce_y = internal::BroadcastGradientArgs(scope, y_shape, output_shape);
1213 auto gy_sum =
1214 ReduceSum(scope, gy, /*axis=*/reduce_y.r0, ReduceSum::KeepDims(true));
1215 auto gy_sum_reshape = Reshape(scope, gy_sum, y_shape);
1216
1217 grad_outputs->push_back(NoGradient()); // Condition
1218 grad_outputs->push_back(gx_sum_reshape);
1219 grad_outputs->push_back(gy_sum_reshape);
1220 return scope.status();
1221}
1222
1223REGISTER_GRADIENT_OP("SelectV2", SelectV2Grad);
1224
1225// Helper function for unsorted segment ops.
1226// Returns 'ids' with negative elements replaced by 0.
1227Output GetZeroClippedIndices(const Scope& scope, const Output& ids) {
1228 return Maximum(scope, ids, ZerosLike(scope, ids));
1229}
1230
1231// Helper function for unsorted segment ops.
1232// Returns a mask of where 'ids' are positive, reshaped so that it will be
1233// broadcastable to the result shape of gathering params by ids.
1234Output GetIsPositive(const Scope& scope, const Output& params,
1235 const Output& ids) {
1236 Output is_positive = GreaterEqual(scope, ids, ZerosLike(scope, ids));
1237 // tf.where(condition, x, y) requires condition to have the same shape as x
1238 // and y.
1239 Output is_positive_shape = Shape(scope, is_positive);
1240 Output ones =
1241 Tile(scope, Const(scope, {1}), Subtract(scope, Rank(scope, params), {1}));
1242 auto broadcastable_shape = Concat(scope, {is_positive_shape, ones},
1243 /*axis=*/0);
1244 is_positive = Reshape(scope, is_positive, broadcastable_shape);
1245 is_positive = LogicalAnd(scope, is_positive, OnesLike(scope, is_positive));
1246 return is_positive;
1247}
1248
1249// Helper function for unsorted segment ops.
1250// Gathers params for positive segment ids and gathers 0 for inputs with
1251// negative segment id.
1252Output GatherDropNegatives(const Scope& scope, const Output& params,
1253 Output& zero_clipped_indices, Output& is_positive) {
1254 auto gathered = Gather(scope, params, zero_clipped_indices);
1255 // Replace gathered params of negative indices with 0.
1256 auto zero_slice = ZerosLike(scope, gathered);
1257 return SelectV2(scope, is_positive, gathered, zero_slice);
1258}
1259
1260Status UnsortedSegmentMinOrMaxGrad(const Scope& scope, const Operation& op,
1261 const std::vector<Output>& grad_inputs,
1262 std::vector<Output>* grad_outputs) {
1263 if (op.num_inputs() != 3) {
1264 return errors::InvalidArgument("UnsortedSegmentMax requires 3 arguments");
1265 }
1266
1267 if (grad_inputs.size() != 1) {
1268 return errors::InvalidArgument(
1269 "UnsortedSegmentMax grad requires 1 grad input");
1270 }
1271
1272 auto grad = grad_inputs[0];
1273 // Get the number of selected (minimum or maximum) elements in each segment.
1274 auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
1275 auto is_positive = GetIsPositive(scope, op.output(0), op.input(1));
1276 Output gathered_outputs = GatherDropNegatives(
1277 scope, op.output(0), zero_clipped_indices, is_positive);
1278 Output is_selected = Equal(scope, op.input(0), gathered_outputs);
1279 is_selected = LogicalAnd(scope, is_selected, is_positive);
1280 auto num_selected = UnsortedSegmentSum(
1281 scope, Cast(scope, is_selected, grad.type()), op.input(1), op.input(2));
1282 // Compute the gradient for each segment.The gradient for the ith segment is
1283 // divided evenly among the selected elements in that segment.
1284 auto weighted_grads = Div(scope, grad, num_selected);
1285 auto gathered_grads = GatherDropNegatives(scope, weighted_grads,
1286 zero_clipped_indices, is_positive);
1287 auto zeros = ZerosLike(scope, gathered_grads);
1288 grad_outputs->push_back(SelectV2(scope, is_selected, gathered_grads, zeros));
1289 grad_outputs->push_back(NoGradient());
1290 grad_outputs->push_back(NoGradient());
1291 return scope.status();
1292}
1293
1294REGISTER_GRADIENT_OP("UnsortedSegmentMax", UnsortedSegmentMinOrMaxGrad);
1295REGISTER_GRADIENT_OP("UnsortedSegmentMin", UnsortedSegmentMinOrMaxGrad);
1296
1297Status UnsortedSegmentSumGrad(const Scope& scope, const Operation& op,
1298 const std::vector<Output>& grad_inputs,
1299 std::vector<Output>* grad_outputs) {
1300 if (op.num_inputs() != 3) {
1301 return errors::InvalidArgument("UnsortedSegmentSum requires 3 arguments");
1302 }
1303
1304 if (grad_inputs.size() != 1) {
1305 return errors::InvalidArgument(
1306 "UnsortedSegmentSum grad requires 1 grad input");
1307 }
1308
1309 auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
1310 auto is_positive = GetIsPositive(scope, grad_inputs[0], op.input(1));
1311 grad_outputs->push_back(GatherDropNegatives(
1312 scope, grad_inputs[0], zero_clipped_indices, is_positive));
1313 grad_outputs->push_back(NoGradient());
1314 grad_outputs->push_back(NoGradient());
1315 return scope.status();
1316}
1317
1318REGISTER_GRADIENT_OP("UnsortedSegmentSum", UnsortedSegmentSumGrad);
1319
1320} // anonymous namespace
1321} // namespace ops
1322} // namespace tensorflow
1323