1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <algorithm>
#include <cmath>
#include <cstdint>
17 | |
18 | #include "tensorflow/cc/framework/grad_op_registry.h" |
19 | #include "tensorflow/cc/framework/gradients.h" |
20 | #include "tensorflow/cc/gradients/grad_helper.h" |
21 | #include "tensorflow/cc/ops/array_ops.h" |
22 | #include "tensorflow/cc/ops/array_ops_internal.h" |
23 | #include "tensorflow/cc/ops/math_ops.h" |
24 | #include "tensorflow/cc/ops/math_ops_internal.h" |
25 | #include "tensorflow/cc/ops/standard_ops.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace ops { |
29 | namespace { |
30 | |
31 | // Logical operations have no gradients. |
32 | REGISTER_NO_GRADIENT_OP("Less" ); |
33 | REGISTER_NO_GRADIENT_OP("LessEqual" ); |
34 | REGISTER_NO_GRADIENT_OP("Greater" ); |
35 | REGISTER_NO_GRADIENT_OP("GreaterEqual" ); |
36 | REGISTER_NO_GRADIENT_OP("Equal" ); |
37 | REGISTER_NO_GRADIENT_OP("ApproximateEqual" ); |
38 | REGISTER_NO_GRADIENT_OP("NotEqual" ); |
39 | REGISTER_NO_GRADIENT_OP("LogicalAnd" ); |
40 | REGISTER_NO_GRADIENT_OP("LogicalOr" ); |
41 | REGISTER_NO_GRADIENT_OP("LogicalNot" ); |
42 | REGISTER_NO_GRADIENT_OP("Floor" ); |
43 | |
// ConjugateHelper returns the conjugate of an Output if it is
// complex-valued, and returns the Output unchanged otherwise.
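// TensorFlow's convention for elementwise complex gradients is
// grad(x) = grad(y) * conj(dy/dx); the gradient functions below use this
// helper to conjugate dy/dx (a no-op for real types) before multiplying.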
46 | Output ConjugateHelper(const Scope& scope, const Output& out) { |
47 | DataType dtype = out.type(); |
48 | if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { |
49 | return Conj(scope, out); |
50 | } else { |
51 | return out; |
52 | } |
53 | } |
54 | |
55 | // TODO(andydavis) Add control dependencies to gradient functions (as needed). |
56 | |
57 | Status AbsGrad(const Scope& scope, const Operation& op, |
58 | const std::vector<Output>& grad_inputs, |
59 | std::vector<Output>* grad_outputs) { |
60 | // dx = dy * sign(x) |
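  // e.g. x = [-2, 0, 3] gives sign(x) = [-1, 0, 1], so an incoming
  // gradient [g1, g2, g3] becomes [-g1, 0, g3].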
61 | grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0)))); |
62 | return scope.status(); |
63 | } |
64 | REGISTER_GRADIENT_OP("Abs" , AbsGrad); |
65 | |
66 | Status NegGrad(const Scope& scope, const Operation& op, |
67 | const std::vector<Output>& grad_inputs, |
68 | std::vector<Output>* grad_outputs) { |
69 | // dx = -dy; |
70 | grad_outputs->push_back(Neg(scope, grad_inputs[0])); |
71 | return scope.status(); |
72 | } |
73 | REGISTER_GRADIENT_OP("Neg" , NegGrad); |
74 | |
75 | Status InvGrad(const Scope& scope, const Operation& op, |
76 | const std::vector<Output>& grad_inputs, |
77 | std::vector<Output>* grad_outputs) { |
78 | // Use the built-in operator. |
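  // internal::ReciprocalGrad(y, dy) computes -dy * y * y, i.e. the gradient
  // of y = 1/x expressed in terms of the output y.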
79 | grad_outputs->push_back( |
80 | internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0])); |
81 | return scope.status(); |
82 | } |
83 | REGISTER_GRADIENT_OP("Inv" , InvGrad); |
84 | REGISTER_GRADIENT_OP("Reciprocal" , InvGrad); |
85 | |
86 | Status SquareGrad(const Scope& scope, const Operation& op, |
87 | const std::vector<Output>& grad_inputs, |
88 | std::vector<Output>* grad_outputs) { |
89 | // dy/dx = (2 * x) |
90 | auto two = Cast(scope, Const(scope, 2), op.input(0).type()); |
91 | auto dydx = Mul(scope, two, op.input(0)); |
92 | // grad(x) = grad(y) * conj(dy/dx) |
93 | grad_outputs->push_back( |
94 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
95 | return scope.status(); |
96 | } |
97 | REGISTER_GRADIENT_OP("Square" , SquareGrad); |
98 | |
99 | Status SqrtGrad(const Scope& scope, const Operation& op, |
100 | const std::vector<Output>& grad_inputs, |
101 | std::vector<Output>* grad_outputs) { |
102 | // Use the built-in operator. |
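  // internal::SqrtGrad(y, dy) computes dy * 0.5 / y, i.e. the gradient of
  // y = sqrt(x) expressed in terms of the output y.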
103 | grad_outputs->push_back( |
104 | internal::SqrtGrad(scope, op.output(0), grad_inputs[0])); |
105 | return scope.status(); |
106 | } |
107 | REGISTER_GRADIENT_OP("Sqrt" , SqrtGrad); |
108 | |
109 | Status RsqrtGrad(const Scope& scope, const Operation& op, |
110 | const std::vector<Output>& grad_inputs, |
111 | std::vector<Output>* grad_outputs) { |
112 | // Use the built-in operator. |
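  // internal::RsqrtGrad(y, dy) computes -0.5 * dy * y^3, i.e. the gradient
  // of y = 1/sqrt(x) expressed in terms of the output y.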
113 | grad_outputs->push_back( |
114 | internal::RsqrtGrad(scope, op.output(0), grad_inputs[0])); |
115 | return scope.status(); |
116 | } |
117 | REGISTER_GRADIENT_OP("Rsqrt" , RsqrtGrad); |
118 | |
119 | Status ExpGrad(const Scope& scope, const Operation& op, |
120 | const std::vector<Output>& grad_inputs, |
121 | std::vector<Output>* grad_outputs) { |
122 | // dy/dx = exp(x) = y |
123 | // grad(x) = grad(y) * conj(dy/dx) |
124 | // = grad(y) * conj(y) |
125 | grad_outputs->push_back( |
126 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0)))); |
127 | return scope.status(); |
128 | } |
129 | REGISTER_GRADIENT_OP("Exp" , ExpGrad); |
130 | |
131 | Status Expm1Grad(const Scope& scope, const Operation& op, |
132 | const std::vector<Output>& grad_inputs, |
133 | std::vector<Output>* grad_outputs) { |
134 | // y = expm1(x) |
135 | // dy/dx = exp(x) |
136 | auto dydx = Exp(scope, op.input(0)); |
137 | // grad(x) = grad(y) * conj(dy/dx) |
138 | grad_outputs->push_back( |
139 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
140 | return scope.status(); |
141 | } |
142 | REGISTER_GRADIENT_OP("Expm1" , Expm1Grad); |
143 | |
144 | Status LogGrad(const Scope& scope, const Operation& op, |
145 | const std::vector<Output>& grad_inputs, |
146 | std::vector<Output>* grad_outputs) { |
147 | // y = log(x) |
148 | // dy/dx = 1 / x |
149 | auto dydx = Reciprocal(scope, op.input(0)); |
150 | // grad(x) = grad(y) * conj(dy/dx) |
151 | grad_outputs->push_back( |
152 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
153 | return scope.status(); |
154 | } |
155 | REGISTER_GRADIENT_OP("Log" , LogGrad); |
156 | |
157 | Status Log1pGrad(const Scope& scope, const Operation& op, |
158 | const std::vector<Output>& grad_inputs, |
159 | std::vector<Output>* grad_outputs) { |
160 | // y = log1p(x) |
161 | // dy/dx = 1 / (1 + x) |
162 | auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); |
163 | auto dydx = Reciprocal(scope, Add(scope, one, op.input(0))); |
164 | // grad(x) = grad(y) * conj(dy/dx) |
165 | grad_outputs->push_back( |
166 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
167 | return scope.status(); |
168 | } |
169 | REGISTER_GRADIENT_OP("Log1p" , Log1pGrad); |
170 | |
171 | Status SinhGrad(const Scope& scope, const Operation& op, |
172 | const std::vector<Output>& grad_inputs, |
173 | std::vector<Output>* grad_outputs) { |
174 | // y = sinh(x) |
175 | // dy/dx = cosh(x) |
176 | auto dydx = Cosh(scope, op.input(0)); |
177 | // grad(x) = grad(y) * conj(dy/dx) |
178 | grad_outputs->push_back( |
179 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
180 | return scope.status(); |
181 | } |
182 | REGISTER_GRADIENT_OP("Sinh" , SinhGrad); |
183 | |
184 | Status CoshGrad(const Scope& scope, const Operation& op, |
185 | const std::vector<Output>& grad_inputs, |
186 | std::vector<Output>* grad_outputs) { |
187 | // y = cosh(x) |
188 | // dy/dx = sinh(x) |
189 | auto dydx = Sinh(scope, op.input(0)); |
190 | // grad(x) = grad(y) * conj(dy/dx) |
191 | grad_outputs->push_back( |
192 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
193 | return scope.status(); |
194 | } |
195 | REGISTER_GRADIENT_OP("Cosh" , CoshGrad); |
196 | |
197 | Status TanhGrad(const Scope& scope, const Operation& op, |
198 | const std::vector<Output>& grad_inputs, |
199 | std::vector<Output>* grad_outputs) { |
200 | // Use the built-in operator. |
201 | // Note that the built-in operator does not return the conjugate of |
202 | // the gradient. |
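  // internal::TanhGrad(y, dy) computes dy * (1 - y*y), so it is given
  // conj(y) below to produce grad * conj(1 - y^2) = grad * conj(dy/dx).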
203 | auto grad = grad_inputs[0]; |
204 | // Optimization to avoid calculating conj(y) until the gradient is |
205 | // evaluated. |
206 | Scope grad_scope = scope.WithControlDependencies(grad); |
207 | auto y = ConjugateHelper(grad_scope, op.output(0)); |
208 | grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad)); |
209 | return grad_scope.status(); |
210 | } |
211 | REGISTER_GRADIENT_OP("Tanh" , TanhGrad); |
212 | |
213 | Status AsinhGrad(const Scope& scope, const Operation& op, |
214 | const std::vector<Output>& grad_inputs, |
215 | std::vector<Output>* grad_outputs) { |
216 | // y = asinh(x) |
217 | // dy/dx = 1 / cosh(y) |
218 | auto dydx = Reciprocal(scope, Cosh(scope, op.output(0))); |
219 | // grad(x) = grad(y) * conj(dy/dx) |
220 | grad_outputs->push_back( |
221 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
222 | return scope.status(); |
223 | } |
224 | REGISTER_GRADIENT_OP("Asinh" , AsinhGrad); |
225 | |
226 | Status AcoshGrad(const Scope& scope, const Operation& op, |
227 | const std::vector<Output>& grad_inputs, |
228 | std::vector<Output>* grad_outputs) { |
229 | // y = acosh(x) |
230 | // dy/dx = 1 / sinh(y) |
231 | auto dydx = Reciprocal(scope, Sinh(scope, op.output(0))); |
232 | // grad(x) = grad(y) * conj(dy/dx) |
233 | grad_outputs->push_back( |
234 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
235 | return scope.status(); |
236 | } |
237 | REGISTER_GRADIENT_OP("Acosh" , AcoshGrad); |
238 | |
239 | Status AtanhGrad(const Scope& scope, const Operation& op, |
240 | const std::vector<Output>& grad_inputs, |
241 | std::vector<Output>* grad_outputs) { |
242 | // y = atanh(x) |
243 | // dy/dx = 1 / (1 - x^2) |
244 | auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); |
245 | auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0)))); |
246 | // grad(x) = grad(y) * conj(dy/dx) |
247 | grad_outputs->push_back( |
248 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
249 | return scope.status(); |
250 | } |
251 | REGISTER_GRADIENT_OP("Atanh" , AtanhGrad); |
252 | |
253 | Status SigmoidGrad(const Scope& scope, const Operation& op, |
254 | const std::vector<Output>& grad_inputs, |
255 | std::vector<Output>* grad_outputs) { |
256 | // Use the built-in operator. |
257 | // Note that the built-in operator does not return the conjugate of |
258 | // the gradient. |
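  // internal::SigmoidGrad(y, dy) computes dy * y * (1 - y), so it is given
  // conj(y) below to produce grad * conj(dy/dx).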
259 | auto grad = grad_inputs[0]; |
260 | // Optimization to avoid calculating conj(y) until the gradient is |
261 | // evaluated. |
262 | Scope grad_scope = scope.WithControlDependencies(grad); |
263 | auto y = ConjugateHelper(grad_scope, op.output(0)); |
264 | grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad)); |
265 | return grad_scope.status(); |
266 | } |
267 | REGISTER_GRADIENT_OP("Sigmoid" , SigmoidGrad); |
268 | |
269 | Status SignGrad(const Scope& scope, const Operation& op, |
270 | const std::vector<Output>& grad_inputs, |
271 | std::vector<Output>* grad_outputs) { |
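  // sign(x) is piecewise constant, so its derivative is 0 wherever it is
  // defined; propagate a zero tensor with the shape of x.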
272 | auto shape = Shape(scope, op.input(0)); |
273 | auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type()); |
274 | auto dx = Fill(scope, shape, zero); |
275 | grad_outputs->push_back(dx); |
276 | return scope.status(); |
277 | } |
278 | REGISTER_GRADIENT_OP("Sign" , SignGrad); |
279 | |
280 | Status SinGrad(const Scope& scope, const Operation& op, |
281 | const std::vector<Output>& grad_inputs, |
282 | std::vector<Output>* grad_outputs) { |
283 | // y = sin(x) |
284 | // dy/dx = cos(x) |
285 | auto dydx = Cos(scope, op.input(0)); |
286 | // grad(x) = grad(y) * conj(dy/dx) |
287 | grad_outputs->push_back( |
288 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
289 | return scope.status(); |
290 | } |
291 | REGISTER_GRADIENT_OP("Sin" , SinGrad); |
292 | |
293 | Status CosGrad(const Scope& scope, const Operation& op, |
294 | const std::vector<Output>& grad_inputs, |
295 | std::vector<Output>* grad_outputs) { |
296 | // y = cos(x) |
297 | // dy/dx = -sin(x) |
298 | auto dydx = Neg(scope, Sin(scope, op.input(0))); |
299 | // grad(x) = grad(y) * conj(dy/dx) |
300 | grad_outputs->push_back( |
301 | Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); |
302 | return scope.status(); |
303 | } |
304 | REGISTER_GRADIENT_OP("Cos" , CosGrad); |
305 | |
306 | Status AsinGrad(const Scope& scope, const Operation& op, |
307 | const std::vector<Output>& grad_inputs, |
308 | std::vector<Output>* grad_outputs) { |
309 | // y = asin(x) |
310 | // dy/dx = 1 / sqrt(1 - x^2) |
311 | auto x2 = Square(scope, op.input(0)); |
312 | auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); |
313 | auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))); |
314 | // grad(x) = grad(y) * conj(dy/dx) |
315 | auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); |
316 | grad_outputs->push_back(dx); |
317 | return scope.status(); |
318 | } |
319 | REGISTER_GRADIENT_OP("Asin" , AsinGrad); |
320 | |
321 | Status AcosGrad(const Scope& scope, const Operation& op, |
322 | const std::vector<Output>& grad_inputs, |
323 | std::vector<Output>* grad_outputs) { |
324 | // y = acos(x) |
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
327 | auto x2 = Square(scope, op.input(0)); |
328 | auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); |
329 | auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)))); |
330 | auto dx = Mul(scope, grad_inputs[0], dydx); |
331 | grad_outputs->push_back(dx); |
332 | return scope.status(); |
333 | } |
334 | REGISTER_GRADIENT_OP("Acos" , AcosGrad); |
335 | |
336 | Status TanGrad(const Scope& scope, const Operation& op, |
337 | const std::vector<Output>& grad_inputs, |
338 | std::vector<Output>* grad_outputs) { |
339 | // y = tan(x) |
340 | // dy/dx = sec(x)^2 = 1 / cos(x)^2 |
341 | auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0)))); |
342 | // grad(x) = grad(y) * conj(dy/dx) |
343 | auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); |
344 | grad_outputs->push_back(dx); |
345 | return scope.status(); |
346 | } |
347 | REGISTER_GRADIENT_OP("Tan" , TanGrad); |
348 | |
349 | Status AtanGrad(const Scope& scope, const Operation& op, |
350 | const std::vector<Output>& grad_inputs, |
351 | std::vector<Output>* grad_outputs) { |
  // y = atan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
355 | auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); |
356 | auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0)))); |
357 | auto dx = Mul(scope, grad_inputs[0], dydx); |
358 | grad_outputs->push_back(dx); |
359 | return scope.status(); |
360 | } |
361 | REGISTER_GRADIENT_OP("Atan" , AtanGrad); |
362 | |
363 | Status Atan2Grad(const Scope& scope, const Operation& op, |
364 | const std::vector<Output>& grad_inputs, |
365 | std::vector<Output>* grad_outputs) { |
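  // z = atan2(y, x)
  // dz/dy = x / (x^2 + y^2)
  // dz/dx = -y / (x^2 + y^2)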
366 | auto y = op.input(0); |
367 | auto x = op.input(1); |
368 | Output grad_inv = Div(scope, grad_inputs[0], |
369 | Add(scope, Square(scope, x), Square(scope, y))); |
370 | grad_outputs->push_back(Mul(scope, x, grad_inv)); |
371 | grad_outputs->push_back(Mul(scope, Neg(scope, y), grad_inv)); |
372 | return scope.status(); |
373 | } |
374 | REGISTER_GRADIENT_OP("Atan2" , Atan2Grad); |
375 | |
376 | // BinaryGradCommon handles the setup for binary ops that broadcast |
377 | // their inputs. |
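// For example, if op.input(0) has shape [2, 3] and op.input(1) has shape
// [3], both raw gradients gx_1 and gx_2 have the broadcast shape [2, 3];
// BroadcastGradientArgs returns the axes to sum over (none for the first
// input, [0] for the second), after which each gradient is reshaped back
// to its input's shape.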
378 | Status BinaryGradCommon(const Scope& scope, const Operation& op, |
379 | std::vector<Output>* grad_outputs, const Output& gx_1, |
380 | const Output& gx_2) { |
381 | auto sx_1 = Shape(scope, op.input(0)); |
382 | auto sx_2 = Shape(scope, op.input(1)); |
383 | auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2); |
384 | auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1); |
385 | auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2); |
386 | grad_outputs->push_back(dx_1); |
387 | grad_outputs->push_back(dx_2); |
388 | return scope.status(); |
389 | } |
390 | |
391 | Status AddGrad(const Scope& scope, const Operation& op, |
392 | const std::vector<Output>& grad_inputs, |
393 | std::vector<Output>* grad_outputs) { |
394 | // y = x_1 + x_2 |
395 | // dy/dx_1 = dy/dx_2 = 1 |
396 | auto gx_1 = Identity(scope, grad_inputs[0]); |
397 | auto gx_2 = Identity(scope, grad_inputs[0]); |
398 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
399 | } |
400 | REGISTER_GRADIENT_OP("Add" , AddGrad); |
401 | REGISTER_GRADIENT_OP("AddV2" , AddGrad); |
402 | |
403 | Status SubGrad(const Scope& scope, const Operation& op, |
404 | const std::vector<Output>& grad_inputs, |
405 | std::vector<Output>* grad_outputs) { |
406 | // y = x_1 - x_2 |
407 | // dy/dx_1 = 1 |
408 | // dy/dx_2 = -1 |
409 | auto gx_1 = Identity(scope, grad_inputs[0]); |
410 | auto gx_2 = Neg(scope, grad_inputs[0]); |
411 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
412 | } |
413 | REGISTER_GRADIENT_OP("Sub" , SubGrad); |
414 | |
415 | Status MulGrad(const Scope& scope, const Operation& op, |
416 | const std::vector<Output>& grad_inputs, |
417 | std::vector<Output>* grad_outputs) { |
418 | auto x_1 = ConjugateHelper(scope, op.input(0)); |
419 | auto x_2 = ConjugateHelper(scope, op.input(1)); |
420 | // y = x_1 * x_2 |
421 | // dy/dx_1 = x_2 |
422 | // dy/dx_2 = x_1 |
423 | auto gx_1 = Mul(scope, grad_inputs[0], x_2); |
424 | auto gx_2 = Mul(scope, grad_inputs[0], x_1); |
425 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
426 | } |
427 | REGISTER_GRADIENT_OP("Mul" , MulGrad); |
428 | |
429 | Status DivGrad(const Scope& scope, const Operation& op, |
430 | const std::vector<Output>& grad_inputs, |
431 | std::vector<Output>* grad_outputs) { |
432 | auto x_1 = ConjugateHelper(scope, op.input(0)); |
433 | auto x_2 = ConjugateHelper(scope, op.input(1)); |
434 | // y = x_1 / x_2 |
435 | // dy/dx_1 = 1/x_2 |
436 | // dy/dx_2 = -x_1/x_2^2 |
437 | auto gx_1 = Div(scope, grad_inputs[0], x_2); |
438 | auto gx_2 = Mul(scope, grad_inputs[0], |
439 | Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2)); |
440 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
441 | } |
442 | REGISTER_GRADIENT_OP("Div" , DivGrad); |
443 | |
444 | Status RealDivGrad(const Scope& scope, const Operation& op, |
445 | const std::vector<Output>& grad_inputs, |
446 | std::vector<Output>* grad_outputs) { |
447 | auto x_1 = ConjugateHelper(scope, op.input(0)); |
448 | auto x_2 = ConjugateHelper(scope, op.input(1)); |
449 | // y = x_1 / x_2 |
450 | // dy/dx_1 = 1/x_2 |
451 | // dy/dx_2 = -x_1/x_2^2 |
452 | auto gx_1 = RealDiv(scope, grad_inputs[0], x_2); |
453 | auto gx_2 = Mul(scope, grad_inputs[0], |
454 | RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2)); |
455 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
456 | } |
457 | REGISTER_GRADIENT_OP("RealDiv" , RealDivGrad); |
458 | |
459 | Status DivNoNanGrad(const Scope& scope, const Operation& op, |
460 | const std::vector<Output>& grad_inputs, |
461 | std::vector<Output>* grad_outputs) { |
462 | auto x_1 = ConjugateHelper(scope, op.input(0)); |
463 | auto x_2 = ConjugateHelper(scope, op.input(1)); |
464 | // y = x_1 / x_2 |
465 | // dy/dx_1 = 1/x_2 |
466 | // dy/dx_2 = -x_1/x_2^2 |
467 | auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2); |
468 | auto gx_2 = Mul(scope, grad_inputs[0], |
469 | DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2)); |
470 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
471 | } |
472 | REGISTER_GRADIENT_OP("DivNoNan" , DivNoNanGrad); |
473 | |
474 | Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, |
475 | const std::vector<Output>& grad_inputs, |
476 | std::vector<Output>* grad_outputs) { |
477 | auto x_1 = ConjugateHelper(scope, op.input(0)); |
478 | auto x_2 = ConjugateHelper(scope, op.input(1)); |
479 | // y = (x_1 - x_2)^2 |
480 | // dy/dx_1 = 2 * (x_1 - x_2) |
481 | // dy/dx_2 = -2 * (x_1 - x_2) |
482 | auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type()); |
483 | auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2))); |
484 | auto gx_2 = Neg(scope, gx_1); |
485 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
486 | } |
487 | REGISTER_GRADIENT_OP("SquaredDifference" , SquaredDifferenceGrad); |
488 | |
489 | Status AddNGrad(const Scope& scope, const Operation& op, |
490 | const std::vector<Output>& grad_inputs, |
491 | std::vector<Output>* grad_outputs) { |
492 | // AddN doesn't support broadcasting, so all the inputs must be the |
493 | // same shape. |
494 | // Note: |
495 | // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k |
496 | // hence dx_k = dy for all x_k |
497 | // So the gradient for AddN just transfers the incoming gradient to |
498 | // all outgoing gradients. |
499 | auto incoming = Identity(scope, grad_inputs[0]); |
500 | for (int32_t i = 0; i < op.num_inputs(); ++i) { |
501 | grad_outputs->push_back(incoming); |
502 | } |
503 | return scope.status(); |
504 | } |
505 | REGISTER_GRADIENT_OP("AddN" , AddNGrad); |
506 | |
507 | Status PowGrad(const Scope& scope, const Operation& op, |
508 | const std::vector<Output>& grad_inputs, |
509 | std::vector<Output>* grad_outputs) { |
510 | auto x = ConjugateHelper(scope, op.input(0)); |
511 | auto y = ConjugateHelper(scope, op.input(1)); |
512 | auto z = ConjugateHelper(scope, op.output(0)); |
513 | auto grad = grad_inputs[0]; |
514 | // grad * y * pow(x, y - 1) |
515 | auto one = Cast(scope, Const(scope, 1.0), y.type()); |
516 | auto gx_1 = |
517 | Mul(scope, Mul(scope, grad, y), Pow(scope, x, Sub(scope, y, one))); |
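  // The gradient with respect to the exponent is
  //   grad * pow(x, y) * log(x) = grad * z * log(x),
  // with log(x) masked to zero where it would be undefined.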
518 | // Avoid false singularity at x = 0 |
519 | DataType x_dtype = x.type(); |
520 | auto zero = Cast(scope, Const(scope, 0.0), x_dtype); |
521 | if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) { |
522 | // real(x) < 0 is fine for the complex case |
523 | auto log_x = Where3(scope, NotEqual(scope, x, zero), Log(scope, x), |
524 | ZerosLike(scope, x)); |
525 | auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); |
526 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); |
527 | } else { |
528 | // There's no sensible real value to return if x < 0, so return 0 |
529 | auto log_x = Where3(scope, Greater(scope, x, zero), Log(scope, x), |
530 | ZerosLike(scope, x)); |
531 | auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); |
532 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); |
533 | } |
534 | } |
535 | REGISTER_GRADIENT_OP("Pow" , PowGrad); |
536 | |
537 | // MaximumMinimumGradCommon adds shared ops to calculate gradients for |
538 | // the binary Maximum and Minimum ops. |
539 | Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, |
540 | const std::vector<Output>& grad_inputs, |
541 | std::vector<Output>* grad_outputs, |
542 | const Output& comparator) { |
543 | // comparator is a boolean tensor, with |
544 | // y = x_1 at points where comparator is true, and x_2 otherwise |
545 | // Therefore |
546 | // dy/dx_1 = 1 where comparator is true, and 0 otherwise. |
547 | // dy/dx_2 = 0 where comparator is true, and 1 otherwise. |
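  // e.g. y = Maximum([3, 5], [4, 2]) = [4, 5] gives comparator
  // [false, true], so an incoming gradient [g1, g2] is routed as
  // gx_1 = [0, g2] and gx_2 = [g1, 0].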
548 | auto grad = grad_inputs[0]; |
549 | auto zeros = ZerosLike(scope, grad); |
550 | auto gx_1 = Where3(scope, comparator, grad, zeros); |
551 | auto gx_2 = Where3(scope, comparator, zeros, grad); |
552 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
553 | } |
554 | |
555 | Status MaximumGrad(const Scope& scope, const Operation& op, |
556 | const std::vector<Output>& grad_inputs, |
557 | std::vector<Output>* grad_outputs) { |
558 | auto comparator = GreaterEqual(scope, op.input(0), op.input(1)); |
559 | return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs, |
560 | comparator); |
561 | } |
562 | REGISTER_GRADIENT_OP("Maximum" , MaximumGrad); |
563 | |
564 | Status MinimumGrad(const Scope& scope, const Operation& op, |
565 | const std::vector<Output>& grad_inputs, |
566 | std::vector<Output>* grad_outputs) { |
567 | auto comparator = LessEqual(scope, op.input(0), op.input(1)); |
568 | return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs, |
569 | comparator); |
570 | } |
571 | REGISTER_GRADIENT_OP("Minimum" , MinimumGrad); |
572 | |
573 | Status RealGrad(const Scope& scope, const Operation& op, |
574 | const std::vector<Output>& grad_inputs, |
575 | std::vector<Output>* grad_outputs) { |
576 | auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type()); |
577 | auto dx = Complex(scope, grad_inputs[0], zero); |
578 | grad_outputs->push_back(dx); |
579 | return scope.status(); |
580 | } |
581 | REGISTER_GRADIENT_OP("Real" , RealGrad); |
582 | |
583 | Status ImagGrad(const Scope& scope, const Operation& op, |
584 | const std::vector<Output>& grad_inputs, |
585 | std::vector<Output>* grad_outputs) { |
586 | auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type()); |
587 | auto dx = Complex(scope, zero, grad_inputs[0]); |
588 | grad_outputs->push_back(dx); |
589 | return scope.status(); |
590 | } |
591 | REGISTER_GRADIENT_OP("Imag" , ImagGrad); |
592 | |
593 | Status ComplexGrad(const Scope& scope, const Operation& op, |
594 | const std::vector<Output>& grad_inputs, |
595 | std::vector<Output>* grad_outputs) { |
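  // z = Complex(x_1, x_2) = x_1 + i * x_2
  // The gradient for each real input is the corresponding real or
  // imaginary part of the incoming complex gradient.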
596 | auto gx_1 = Real(scope, grad_inputs[0]); |
597 | auto gx_2 = Imag(scope, grad_inputs[0]); |
598 | return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); |
599 | } |
600 | REGISTER_GRADIENT_OP("Complex" , ComplexGrad); |
601 | |
602 | Status AngleGrad(const Scope& scope, const Operation& op, |
603 | const std::vector<Output>& grad_inputs, |
604 | std::vector<Output>* grad_outputs) { |
605 | // y = Angle(x) |
  // dx = -dy / (Im(x) + i * Re(x)) = -dy * z_inv
607 | auto re = Real(scope, op.input(0)); |
608 | auto im = Imag(scope, op.input(0)); |
609 | auto z_inv = Reciprocal(scope, Complex(scope, im, re)); |
610 | auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type()); |
611 | auto grad = Complex(scope, grad_inputs[0], zero); |
612 | auto dx = Neg(scope, Mul(scope, grad, z_inv)); |
613 | grad_outputs->push_back(dx); |
614 | return scope.status(); |
615 | } |
616 | REGISTER_GRADIENT_OP("Angle" , AngleGrad); |
617 | |
618 | Status ConjGrad(const Scope& scope, const Operation& op, |
619 | const std::vector<Output>& grad_inputs, |
620 | std::vector<Output>* grad_outputs) { |
621 | grad_outputs->push_back(Conj(scope, grad_inputs[0])); |
622 | return scope.status(); |
623 | } |
624 | REGISTER_GRADIENT_OP("Conj" , ConjGrad); |
625 | |
// Integer division x / y, assuming x and y >= 0, treating x / 0 as x.
627 | Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) { |
628 | return Div(scope, x, Maximum(scope, y, Const(scope, 1))); |
629 | } |
630 | |
631 | // SumGradHelper returns the gradient for the Sum operator, and is used |
632 | // by SumGrad and MeanGrad. |
633 | Output SumGradHelper(const Scope& scope, const Operation& op, |
634 | const std::vector<Output>& grad_inputs) { |
635 | // The partial derivative for any input along a "reduced" dimension |
  // is just 1, so we only need to replicate the output gradient along
  // such a dimension to its "expanded" shape.
638 | // Running example: |
639 | // input is |
640 | // [[a, b, c], |
641 | // [d, e, f]] |
642 | // reduction_indices = [1] |
643 | // Sum = [a + b + c, d + e + f] |
644 | // if the gradient is [g1, g2] |
645 | // We want the propagated gradient to be |
646 | // [[g1, g1, g1], |
647 | // [g2, g2, g2]] |
648 | |
649 | // input_shape = [2, 3] |
650 | auto input_shape = Shape(scope, op.input(0)); |
651 | |
652 | // output_shape_kept_dims = [2, 1] |
653 | auto output_shape_kept_dims = |
654 | ReducedShapeHelper(scope, input_shape, op.input(1)); |
655 | |
656 | // This step "flips" any 1s with values from the input_shape, and |
657 | // replaces remaining entries with 1. This creates a shape that |
658 | // shows how much each dimension in the incoming gradient should be |
659 | // replicated. |
660 | // tile_scaling = [1, 3] |
661 | auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); |
662 | |
663 | // grad = [[g1], [g2]] |
664 | auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); |
665 | |
666 | // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]] |
667 | return Tile(scope, grad, tile_scaling); |
668 | } |
669 | |
670 | Status SumGrad(const Scope& scope, const Operation& op, |
671 | const std::vector<Output>& grad_inputs, |
672 | std::vector<Output>* grad_outputs) { |
673 | grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs)); |
674 | |
675 | // Stop propagation along reduction_indices |
676 | grad_outputs->push_back(NoGradient()); |
677 | return scope.status(); |
678 | } |
679 | REGISTER_GRADIENT_OP("Sum" , SumGrad); |
680 | |
681 | Status MeanGrad(const Scope& scope, const Operation& op, |
682 | const std::vector<Output>& grad_inputs, |
683 | std::vector<Output>* grad_outputs) { |
684 | // The Mean gradient is just like the Sum gradient, except that |
685 | // all gradients are also divided by the size of reduced groups. |
686 | auto sum_grad = SumGradHelper(scope, op, grad_inputs); |
687 | |
688 | // The product of all entries in a tensor's shape is the total |
689 | // number of entries in the tensor. This step calculates |
690 | // n_input_entries/n_output_entries |
691 | // = group_size |
692 | auto input_shape = Shape(scope, op.input(0)); |
693 | auto output_shape = Shape(scope, op.output(0)); |
694 | auto zero = Const(scope, 0); |
695 | auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero), |
696 | Prod(scope, output_shape, zero)); |
697 | |
698 | // propagate sum_grad/group_size |
699 | grad_outputs->push_back( |
700 | Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type()))); |
701 | |
702 | // Stop propagation along reduction_indices |
703 | grad_outputs->push_back(NoGradient()); |
704 | return scope.status(); |
705 | } |
706 | REGISTER_GRADIENT_OP("Mean" , MeanGrad); |
707 | |
708 | Status ErfGrad(const Scope& scope, const Operation& op, |
709 | const std::vector<Output>& grad_inputs, |
710 | std::vector<Output>* grad_outputs) { |
711 | auto grad = grad_inputs[0]; |
712 | auto two_over_root_pi = |
713 | Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type()); |
714 | Scope grad_scope = scope.WithControlDependencies(grad); |
715 | auto x = ConjugateHelper(grad_scope, op.input(0)); |
716 | // grad * 2/sqrt(pi) * exp(-x**2) |
717 | auto dx = Mul(grad_scope, Mul(grad_scope, grad, two_over_root_pi), |
718 | Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x)))); |
719 | grad_outputs->push_back(dx); |
720 | return grad_scope.status(); |
721 | } |
722 | REGISTER_GRADIENT_OP("Erf" , ErfGrad); |
723 | |
724 | Status ErfinvGrad(const Scope& scope, const Operation& op, |
725 | const std::vector<Output>& grad_inputs, |
726 | std::vector<Output>* grad_outputs) { |
727 | auto grad = grad_inputs[0]; |
728 | auto root_pi_over_two = |
729 | Cast(scope, Const(scope, std::sqrt(M_PI) / 2), grad.type()); |
730 | Scope grad_scope = scope.WithControlDependencies(grad); |
731 | auto x = ConjugateHelper(grad_scope, op.input(0)); |
732 | // grad * sqrt(pi) / 2 * exp(erfinv(x) ** 2) |
733 | auto dx = Mul(grad_scope, Mul(grad_scope, grad, root_pi_over_two), |
734 | Exp(grad_scope, Square(grad_scope, op.output(0)))); |
735 | grad_outputs->push_back(dx); |
736 | return grad_scope.status(); |
737 | } |
738 | REGISTER_GRADIENT_OP("Erfinv" , ErfinvGrad); |
739 | |
740 | Status NdtriGrad(const Scope& scope, const Operation& op, |
741 | const std::vector<Output>& grad_inputs, |
742 | std::vector<Output>* grad_outputs) { |
743 | auto grad = grad_inputs[0]; |
744 | auto root_two_pi = |
745 | Cast(scope, Const(scope, std::sqrt(2 * M_PI)), grad.type()); |
746 | auto two = Cast(scope, Const(scope, 2), grad.type()); |
747 | Scope grad_scope = scope.WithControlDependencies(grad); |
748 | auto x = ConjugateHelper(grad_scope, op.input(0)); |
749 | // grad * sqrt(2 * pi) * exp(ndtri(x) ** 2 / 2) |
750 | auto dx = Mul( |
751 | grad_scope, Mul(grad_scope, grad, root_two_pi), |
752 | Exp(grad_scope, Div(grad_scope, Square(grad_scope, op.output(0)), two))); |
753 | grad_outputs->push_back(dx); |
754 | return grad_scope.status(); |
755 | } |
756 | REGISTER_GRADIENT_OP("Ndtri" , NdtriGrad); |
757 | |
758 | Status LgammaGrad(const Scope& scope, const Operation& op, |
759 | const std::vector<Output>& grad_inputs, |
760 | std::vector<Output>* grad_outputs) { |
761 | auto grad = grad_inputs[0]; |
762 | Scope grad_scope = scope.WithControlDependencies(grad); |
763 | auto x = ConjugateHelper(grad_scope, op.input(0)); |
764 | auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x)); |
765 | grad_outputs->push_back(dx); |
766 | return grad_scope.status(); |
767 | } |
768 | REGISTER_GRADIENT_OP("Lgamma" , LgammaGrad); |
769 | |
770 | Status MinOrMaxGrad(const Scope& scope, const Operation& op, |
771 | const std::vector<Output>& grad_inputs, |
772 | std::vector<Output>* grad_outputs) { |
773 | // The partial derivative for any input along a "reduced" dimension |
774 | // is 1 when it is the min (or max) and 0 everywhere else. So the |
775 | // gradient calculation is identical for both operators. |
776 | // |
777 | // There's a special case for propagating gradients when there are |
778 | // multiple minima (or maxima) - we choose to divide the gradient |
779 | // equally among all matching inputs. |
780 | // |
  // See
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.
784 | |
785 | // Running example: |
786 | // input: [[5, 5, 5], |
787 | // [1, 2, -3]] |
788 | // reduction_indices: [1] |
789 | auto input = op.input(0); |
790 | auto reduction_indices = op.input(1); |
791 | |
792 | // [2, 3] |
793 | auto input_shape = Shape(scope, input); |
794 | |
795 | // [2, 1] |
796 | auto output_shape_kept_dims = |
797 | ReducedShapeHelper(scope, input_shape, reduction_indices); |
798 | |
799 | // for op=min (say) |
800 | // output = [5, -3] |
801 | // y = [[5], |
802 | // [-3]] |
803 | auto y = Reshape(scope, op.output(0), output_shape_kept_dims); |
804 | |
805 | // reshape([g1, g2], [2, 1]) = [[g1], |
806 | // [g2]] |
807 | auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); |
808 | |
809 | // indicators = equal(y, input) |
810 | // = equal([[5], [[5, 5, 5], |
811 | // [-3]], [1, 2, -3]]) |
812 | // = [[1, 1, 1], |
813 | // [0, 0, 1]] |
814 | auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type()); |
815 | |
816 | // [[3], |
817 | // [1]] |
818 | auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices), |
819 | output_shape_kept_dims); |
820 | |
821 | // [[1/3, 1/3, 1/3], |
822 | // [0, 0, 1]] |
823 | auto scale = Div(scope, indicators, num_selected); |
824 | |
825 | // [[g1/3, g1/3, g1/3], |
826 | // [0, 0, g2]] |
827 | grad_outputs->push_back(Mul(scope, scale, grad)); |
828 | |
829 | // Stop propagation along reduction_indices |
830 | grad_outputs->push_back(NoGradient()); |
831 | return scope.status(); |
832 | } |
833 | REGISTER_GRADIENT_OP("Min" , MinOrMaxGrad); |
834 | REGISTER_GRADIENT_OP("Max" , MinOrMaxGrad); |
835 | |
836 | Status ProdGrad(const Scope& scope, const Operation& op, |
837 | const std::vector<Output>& grad_inputs, |
838 | std::vector<Output>* grad_outputs) { |
839 | auto zero = Const(scope, 0); |
840 | auto one = Const(scope, 1); |
841 | |
842 | // The gradient can be expressed by dividing the product by each entry of |
843 | // the input tensor. If our input is |
844 | // [ |
845 | // [3, 4], |
846 | // [5, 6], |
847 | // [7, 8] |
848 | // ] |
  // and we do a Prod operation along axis 0, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //     [
  //       [105/3, 192/4],
  // dz *  [105/5, 192/6],
  //       [105/7, 192/8]
  //     ]
  // If the input contains a zero the division is impossible, but
  // notice that the calculation that gave the first gradient,
  // (3 * 5 * 7)/3, is equal to 5 * 7. So the trick is to cumprod the
  // elements along the axis, leaving out the element at the current
  // position (the 3 in the example above).
861 | // We will take as example: |
862 | // [ |
863 | // [ |
864 | // [3.0, 4.0], |
865 | // [5.0, 6.0], |
866 | // [7.0, 8.0] |
867 | // ], |
868 | // [ |
869 | // [3.0, 5.0], |
870 | // [0.0, 6.0], |
871 | // [5.0, 6.0] |
872 | // ] |
873 | // ] |
874 | |
875 | // [2, 3, 2] |
876 | auto input_shape = Shape(scope, op.input(0)); |
877 | |
878 | // The Reshape with -1 flattens the reduction indices. |
879 | // [1] |
880 | auto reduction_indices = Reshape(scope, op.input(1), {-1}); |
881 | |
882 | // [2, 1, 2] |
883 | auto output_shape_kept_dims = |
884 | ReducedShapeHelper(scope, input_shape, reduction_indices); |
885 | |
886 | // [1, 3, 1] |
887 | auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); |
888 | |
889 | // [[[105, 192]], [[0, 180]]] |
890 | auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); |
891 | |
892 | // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]] |
893 | auto grad_tiled = Tile(scope, grad, tile_scaling); |
894 | |
895 | Scope cpu_scope = scope.WithDevice("/cpu:0" ); |
896 | |
897 | // [3] |
898 | auto rank = Rank(cpu_scope, op.input(0)); |
899 | |
900 | // Normalize any negative indices in the reduction_axes to positive values. |
901 | auto reduction_indices_pos = |
902 | Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank); |
903 | |
904 | // [1] |
905 | auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32); |
906 | |
907 | // [0, 1, 2] |
908 | auto idx = Range(cpu_scope, zero, rank, one); |
909 | |
910 | // [0, 2] |
911 | auto other = SetDiff1D(cpu_scope, idx, reduced).out; |
912 | |
913 | // [1, 0, 2] |
914 | auto perm = |
915 | Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0); |
916 | |
917 | // 3 => [3] |
918 | auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0); |
919 | |
920 | // 2 * 2 => [2] |
921 | auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0); |
922 | |
923 | // [ |
924 | // [ |
925 | // [ 3., 4.], |
926 | // [ 3., 5.] |
927 | // ], |
928 | // [ |
929 | // [ 5., 6.], |
930 | // [ 0., 6.] |
931 | // ], |
932 | // [ |
933 | // [ 7., 8.], |
934 | // [ 5., 6.] |
935 | // ] |
936 | // ] |
937 | auto permuted = Transpose(scope, op.input(0), perm); |
938 | |
939 | // [3, 2, 2] |
940 | auto permuted_shape = Shape(scope, permuted); |
941 | |
942 | // [ |
943 | // [ 3., 4., 3., 5.], |
944 | // [ 5., 6., 0., 6.], |
945 | // [ 7., 8., 5., 6.] |
946 | // ] |
947 | auto reshaped = Reshape( |
948 | scope, permuted, |
949 | Stack(scope, std::initializer_list<Input>{reduced_num, other_num})); |
950 | |
951 | // [ |
952 | // [ 1., 1., 1., 1.], |
953 | // [ 3., 4., 3., 5.], |
954 | // [ 15., 24., 0., 30.] |
955 | // ] |
956 | auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true)); |
957 | |
958 | // [ |
959 | // [ 35., 48., 0., 36.], |
960 | // [ 7., 8., 5., 6.], |
961 | // [ 1., 1., 1., 1.] |
962 | // ] |
963 | auto right = |
964 | Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true)); |
965 | |
966 | // left * right = |
967 | // [ |
968 | // [ 35., 48., 0., 36.], |
969 | // [ 21., 32., 15., 30.], |
970 | // [ 15., 24., 0., 30.] |
971 | // ] |
972 | // y = |
973 | // [ |
974 | // [ |
975 | // [ 35., 48.], |
976 | // [ 0., 36.] |
977 | // ], |
978 | // [ |
979 | // [ 21., 32.], |
980 | // [ 15., 30.] |
981 | // ], |
982 | // [ |
983 | // [ 15., 24.], |
984 | // [ 0., 30.] |
985 | // ] |
986 | // ] |
987 | auto y = Reshape(scope, Mul(scope, left, right), permuted_shape); |
988 | |
  // Transpose(y, InvertPermutation(perm)) =
990 | // [ |
991 | // [ |
992 | // [ 35., 48.], |
993 | // [ 21., 32.], |
994 | // [ 15., 24.] |
995 | // ], |
996 | // [ |
997 | // [ 0., 36.], |
998 | // [ 15., 30.], |
999 | // [ 0., 30.] |
1000 | // ] |
1001 | // ] |
1002 | auto out = Mul(scope, grad_tiled, |
1003 | Transpose(scope, y, InvertPermutation(scope, perm))); |
1004 | |
1005 | grad_outputs->push_back(Reshape(scope, out, input_shape)); |
1006 | |
1007 | // stop propagation along reduction_indices |
1008 | grad_outputs->push_back(NoGradient()); |
1009 | return scope.status(); |
1010 | } |
1011 | REGISTER_GRADIENT_OP("Prod" , ProdGrad); |
1012 | |
1013 | Status SegmentSumGrad(const Scope& scope, const Operation& op, |
1014 | const std::vector<Output>& grad_inputs, |
1015 | std::vector<Output>* grad_outputs) { |
  // The SegmentSum operation sums segments of the Tensor that have the same
  // index in the segment_ids parameter.
  // i.e. z = [2, 3, 4, 5], segment_ids = [0, 0, 0, 1]
  // will produce [2 + 3 + 4, 5] = [9, 5]
  // The incoming gradient, say [x1, x2], has the same shape as the output
  // of the SegmentSum operation. The gradient step simply gathers it by
  // segment id, broadcasting it back to z's shape:
  // dL/dz = [x1, x1, x1, x2]
1025 | grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); |
1026 | |
1027 | // stop propagation along segment_ids |
1028 | grad_outputs->push_back(NoGradient()); |
1029 | return scope.status(); |
1030 | } |
1031 | REGISTER_GRADIENT_OP("SegmentSum" , SegmentSumGrad); |
1032 | |
// MatMulGradHelper emits the two MatMul (or BatchMatMul) products that make
// up the gradient for a given combination of input transpositions.
1035 | Status MatMulGradHelper(const Scope& scope, const bool is_batch, |
1036 | const Output& x0, const bool adj_x0, const Output& x1, |
1037 | const bool adj_x1, const Output& y0, const bool adj_y0, |
1038 | const Output& y1, const bool adj_y1, |
1039 | std::vector<Output>* grad_outputs) { |
  if (!is_batch) {
1041 | auto dx = |
1042 | MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1)); |
1043 | grad_outputs->push_back(dx); |
1044 | auto dy = |
1045 | MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1)); |
1046 | grad_outputs->push_back(dy); |
1047 | } else { |
1048 | auto dx = |
1049 | BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1)); |
1050 | grad_outputs->push_back(dx); |
1051 | auto dy = |
1052 | BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1)); |
1053 | grad_outputs->push_back(dy); |
1054 | } |
1055 | return scope.status(); |
1056 | } |
1057 | |
// MatMulGradCommon reads and checks the node's transposition attrs, then
// selects the proper MatMul products for the gradients based on the input
// transposition combination.
1061 | Status MatMulGradCommon(const Scope& scope, const Operation& op, |
1062 | const bool is_batch, |
1063 | const std::vector<Output>& grad_inputs, |
1064 | const string& attr_adj_x, const string& attr_adj_y, |
1065 | std::vector<Output>* grad_outputs) { |
1066 | auto a = op.input(0); |
1067 | auto b = op.input(1); |
1068 | // Use conjugate of the inputs for MatMul |
  if (!is_batch) {
1070 | a = ConjugateHelper(scope, a); |
1071 | b = ConjugateHelper(scope, b); |
1072 | } |
1073 | auto product = op.output(0); |
1074 | |
1075 | bool ta; |
1076 | bool tb; |
1077 | TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta)); |
1078 | TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb)); |
1079 | |
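  // With Z = op(A) * op(B) and incoming gradient dZ, the four transposition
  // cases below reduce to:
  //   ta=false, tb=false: dA = dZ * B^T,   dB = A^T * dZ
  //   ta=false, tb=true:  dA = dZ * B,     dB = dZ^T * A
  //   ta=true,  tb=false: dA = B * dZ^T,   dB = A * dZ
  //   ta=true,  tb=true:  dA = B^T * dZ^T, dB = dZ^T * A^T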
1080 | if (!ta && !tb) { |
1081 | return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a, |
1082 | true, grad_inputs[0], false, grad_outputs); |
1083 | } else if (!ta && tb) { |
1084 | return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false, |
1085 | grad_inputs[0], true, a, false, grad_outputs); |
1086 | } else if (ta && !tb) { |
1087 | return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a, |
1088 | false, grad_inputs[0], false, grad_outputs); |
1089 | } |
1090 | return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true, |
1091 | grad_inputs[0], true, a, true, grad_outputs); |
1092 | } |
1093 | |
1094 | Status MatMulGrad(const Scope& scope, const Operation& op, |
1095 | const std::vector<Output>& grad_inputs, |
1096 | std::vector<Output>* grad_outputs) { |
1097 | return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a" , |
1098 | "transpose_b" , grad_outputs); |
1099 | } |
1100 | REGISTER_GRADIENT_OP("MatMul" , MatMulGrad); |
1101 | |
1102 | Status BatchMatMulGrad(const Scope& scope, const Operation& op, |
1103 | const std::vector<Output>& grad_inputs, |
1104 | std::vector<Output>* grad_outputs) { |
1105 | return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x" , "adj_y" , |
1106 | grad_outputs); |
1107 | } |
1108 | REGISTER_GRADIENT_OP("BatchMatMul" , BatchMatMulGrad); |
1109 | |
1110 | Status CumsumGrad(const Scope& scope, const Operation& op, |
1111 | const std::vector<Output>& grad_inputs, |
1112 | std::vector<Output>* grad_outputs) { |
1113 | if (op.num_inputs() != 2) { |
1114 | return errors::InvalidArgument("Cumsum requires 2 arguments" ); |
1115 | } |
1116 | if (grad_inputs.size() != 1) { |
1117 | return errors::InvalidArgument("Cumsum grad requires 1 grad input" ); |
1118 | } |
1119 | |
1120 | Cumsum::Attrs attrs; |
1121 | TF_RETURN_IF_ERROR( |
1122 | GetNodeAttr(op.node()->attrs(), "exclusive" , &attrs.exclusive_)); |
1123 | bool reverse; |
1124 | TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "reverse" , &reverse)); |
1125 | attrs.reverse_ = !reverse; |
1126 | |
1127 | auto axis = op.input(1); |
1128 | auto sum = Cumsum(scope, grad_inputs[0], axis, attrs); |
1129 | grad_outputs->push_back(sum.out); |
1130 | grad_outputs->push_back(NoGradient()); |
1131 | return scope.status(); |
1132 | } |
1133 | REGISTER_GRADIENT_OP("Cumsum" , CumsumGrad); |
1134 | |
1135 | bool IsFloatingPointDtype(DataType dtype) { |
1136 | static constexpr DataType valid_dtypes[] = { |
1137 | DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}; |
1138 | return std::find(std::begin(valid_dtypes), std::end(valid_dtypes), dtype) != |
1139 | std::end(valid_dtypes); |
1140 | } |
1141 | |
1142 | Status CastGrad(const Scope& scope, const Operation& op, |
1143 | const std::vector<Output>& grad_inputs, |
1144 | std::vector<Output>* grad_outputs) { |
1145 | if (op.num_inputs() != 1) { |
1146 | return errors::InvalidArgument("Cast requires 2 arguments" ); |
1147 | } |
1148 | if (grad_inputs.size() != 1) { |
1149 | return errors::InvalidArgument("Cast grad requires 1 grad input" ); |
1150 | } |
1151 | |
1152 | auto src_type = op.input_type(0); |
1153 | auto dst_type = grad_inputs[0].type(); |
1154 | if (IsFloatingPointDtype(src_type) && IsFloatingPointDtype(dst_type)) { |
1155 | grad_outputs->push_back(Cast(scope, grad_inputs[0], src_type)); |
1156 | } else { |
1157 | grad_outputs->push_back(NoGradient()); |
1158 | } |
1159 | return scope.status(); |
1160 | } |
1161 | REGISTER_GRADIENT_OP("Cast" , CastGrad); |
1162 | |
1163 | Status SelectGrad(const Scope& scope, const Operation& op, |
1164 | const std::vector<Output>& grad_inputs, |
1165 | std::vector<Output>* grad_outputs) { |
1166 | if (op.num_inputs() != 3) { |
1167 | return errors::InvalidArgument("Select requires 3 arguments" ); |
1168 | } |
1169 | if (grad_inputs.size() != 1) { |
1170 | return errors::InvalidArgument("Select grad requires 1 grad input" ); |
1171 | } |
1172 | |
1173 | auto c = op.input(0); |
1174 | auto zeros = ZerosLike(scope, grad_inputs[0]); |
1175 | grad_outputs->push_back(NoGradient()); // Condition |
1176 | grad_outputs->push_back(Where3(scope, c, grad_inputs[0], zeros)); |
1177 | grad_outputs->push_back(Where3(scope, c, zeros, grad_inputs[0])); |
1178 | return scope.status(); |
1179 | } |
1180 | REGISTER_GRADIENT_OP("Select" , SelectGrad); |
1181 | |
1182 | Status SelectV2Grad(const Scope& scope, const Operation& op, |
1183 | const std::vector<Output>& grad_inputs, |
1184 | std::vector<Output>* grad_outputs) { |
1185 | if (op.num_inputs() != 3) { |
1186 | return errors::InvalidArgument("Select requires 3 arguments" ); |
1187 | } |
1188 | |
1189 | if (grad_inputs.size() != 1) { |
1190 | return errors::InvalidArgument("Select grad requires 1 grad input" ); |
1191 | } |
1192 | |
1193 | auto c = op.input(0); |
1194 | auto x = op.input(1); |
1195 | auto y = op.input(2); |
1196 | |
1197 | auto zeros = ZerosLike(scope, grad_inputs[0]); |
1198 | auto gx = SelectV2(scope, c, grad_inputs[0], zeros); |
1199 | auto x_shape = Shape(scope, x); |
1200 | auto output_shape = Shape(scope, op.output(0)); |
1201 | |
1202 | // Reduce away broadcasted leading dims. |
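  // When x was broadcast up to the output shape, the raw gradient gx has
  // the output shape and must be summed back down: e.g. for x of shape [3]
  // and an output of shape [2, 3], BroadcastGradientArgs yields reduction
  // axes [0] for x.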
1203 | auto reduce_x = internal::BroadcastGradientArgs(scope, x_shape, output_shape); |
1204 | auto gx_sum = |
1205 | ReduceSum(scope, gx, /*axis=*/reduce_x.r0, ReduceSum::KeepDims(true)); |
1206 | auto gx_sum_reshape = Reshape(scope, gx_sum, x_shape); |
1207 | |
1208 | auto gy = SelectV2(scope, c, zeros, grad_inputs[0]); |
1209 | auto y_shape = Shape(scope, y); |
1210 | |
1211 | // Reduce away broadcasted leading dims. |
1212 | auto reduce_y = internal::BroadcastGradientArgs(scope, y_shape, output_shape); |
1213 | auto gy_sum = |
1214 | ReduceSum(scope, gy, /*axis=*/reduce_y.r0, ReduceSum::KeepDims(true)); |
1215 | auto gy_sum_reshape = Reshape(scope, gy_sum, y_shape); |
1216 | |
1217 | grad_outputs->push_back(NoGradient()); // Condition |
1218 | grad_outputs->push_back(gx_sum_reshape); |
1219 | grad_outputs->push_back(gy_sum_reshape); |
1220 | return scope.status(); |
1221 | } |
1222 | |
1223 | REGISTER_GRADIENT_OP("SelectV2" , SelectV2Grad); |
1224 | |
1225 | // Helper function for unsorted segment ops. |
1226 | // Returns 'ids' with negative elements replaced by 0. |
1227 | Output GetZeroClippedIndices(const Scope& scope, const Output& ids) { |
1228 | return Maximum(scope, ids, ZerosLike(scope, ids)); |
1229 | } |
1230 | |
1231 | // Helper function for unsorted segment ops. |
1232 | // Returns a mask of where 'ids' are positive, reshaped so that it will be |
1233 | // broadcastable to the result shape of gathering params by ids. |
1234 | Output GetIsPositive(const Scope& scope, const Output& params, |
1235 | const Output& ids) { |
1236 | Output is_positive = GreaterEqual(scope, ids, ZerosLike(scope, ids)); |
1237 | // tf.where(condition, x, y) requires condition to have the same shape as x |
1238 | // and y. |
1239 | Output is_positive_shape = Shape(scope, is_positive); |
1240 | Output ones = |
1241 | Tile(scope, Const(scope, {1}), Subtract(scope, Rank(scope, params), {1})); |
1242 | auto broadcastable_shape = Concat(scope, {is_positive_shape, ones}, |
1243 | /*axis=*/0); |
1244 | is_positive = Reshape(scope, is_positive, broadcastable_shape); |
1245 | is_positive = LogicalAnd(scope, is_positive, OnesLike(scope, is_positive)); |
1246 | return is_positive; |
1247 | } |
1248 | |
1249 | // Helper function for unsorted segment ops. |
1250 | // Gathers params for positive segment ids and gathers 0 for inputs with |
1251 | // negative segment id. |
1252 | Output GatherDropNegatives(const Scope& scope, const Output& params, |
1253 | Output& zero_clipped_indices, Output& is_positive) { |
1254 | auto gathered = Gather(scope, params, zero_clipped_indices); |
1255 | // Replace gathered params of negative indices with 0. |
1256 | auto zero_slice = ZerosLike(scope, gathered); |
1257 | return SelectV2(scope, is_positive, gathered, zero_slice); |
1258 | } |
1259 | |
1260 | Status UnsortedSegmentMinOrMaxGrad(const Scope& scope, const Operation& op, |
1261 | const std::vector<Output>& grad_inputs, |
1262 | std::vector<Output>* grad_outputs) { |
1263 | if (op.num_inputs() != 3) { |
1264 | return errors::InvalidArgument("UnsortedSegmentMax requires 3 arguments" ); |
1265 | } |
1266 | |
1267 | if (grad_inputs.size() != 1) { |
1268 | return errors::InvalidArgument( |
1269 | "UnsortedSegmentMax grad requires 1 grad input" ); |
1270 | } |
1271 | |
1272 | auto grad = grad_inputs[0]; |
1273 | // Get the number of selected (minimum or maximum) elements in each segment. |
1274 | auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1)); |
1275 | auto is_positive = GetIsPositive(scope, op.output(0), op.input(1)); |
1276 | Output gathered_outputs = GatherDropNegatives( |
1277 | scope, op.output(0), zero_clipped_indices, is_positive); |
1278 | Output is_selected = Equal(scope, op.input(0), gathered_outputs); |
1279 | is_selected = LogicalAnd(scope, is_selected, is_positive); |
1280 | auto num_selected = UnsortedSegmentSum( |
1281 | scope, Cast(scope, is_selected, grad.type()), op.input(1), op.input(2)); |
  // Compute the gradient for each segment. The gradient for the ith segment
  // is divided evenly among the selected elements in that segment.
1284 | auto weighted_grads = Div(scope, grad, num_selected); |
1285 | auto gathered_grads = GatherDropNegatives(scope, weighted_grads, |
1286 | zero_clipped_indices, is_positive); |
1287 | auto zeros = ZerosLike(scope, gathered_grads); |
1288 | grad_outputs->push_back(SelectV2(scope, is_selected, gathered_grads, zeros)); |
1289 | grad_outputs->push_back(NoGradient()); |
1290 | grad_outputs->push_back(NoGradient()); |
1291 | return scope.status(); |
1292 | } |
1293 | |
1294 | REGISTER_GRADIENT_OP("UnsortedSegmentMax" , UnsortedSegmentMinOrMaxGrad); |
1295 | REGISTER_GRADIENT_OP("UnsortedSegmentMin" , UnsortedSegmentMinOrMaxGrad); |
1296 | |
1297 | Status UnsortedSegmentSumGrad(const Scope& scope, const Operation& op, |
1298 | const std::vector<Output>& grad_inputs, |
1299 | std::vector<Output>* grad_outputs) { |
1300 | if (op.num_inputs() != 3) { |
1301 | return errors::InvalidArgument("UnsortedSegmentSum requires 3 arguments" ); |
1302 | } |
1303 | |
1304 | if (grad_inputs.size() != 1) { |
1305 | return errors::InvalidArgument( |
1306 | "UnsortedSegmentSum grad requires 1 grad input" ); |
1307 | } |
1308 | |
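  // The gradient of UnsortedSegmentSum is a gather of the incoming gradient
  // at each element's segment id; elements with a negative segment id did
  // not contribute to any segment and receive a zero gradient.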
1309 | auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1)); |
1310 | auto is_positive = GetIsPositive(scope, grad_inputs[0], op.input(1)); |
1311 | grad_outputs->push_back(GatherDropNegatives( |
1312 | scope, grad_inputs[0], zero_clipped_indices, is_positive)); |
1313 | grad_outputs->push_back(NoGradient()); |
1314 | grad_outputs->push_back(NoGradient()); |
1315 | return scope.status(); |
1316 | } |
1317 | |
1318 | REGISTER_GRADIENT_OP("UnsortedSegmentSum" , UnsortedSegmentSumGrad); |
1319 | |
1320 | } // anonymous namespace |
1321 | } // namespace ops |
1322 | } // namespace tensorflow |
1323 | |