/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;

// Cwise unary ops
Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}

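// d|x|/dx = sign(x), so dx = dy * sign(x).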
Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sign"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Abs", AbsGrad);

Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Neg", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Neg", NegGrad);

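// d(1/x)/dx = -1/x^2; with y = 1/x this is -y^2, so dx = dy * (-y^2).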
Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Reciprocal", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      {{"y2_neg"}, "Neg", {"y2"}},
      {{"dx"}, "Mul", {"dy", "y2_neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
REGISTER_OP_GRADIENT("Reciprocal", InvGrad);

Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
      {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Square", SquareGrad);

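// d(sqrt(x))/dx = 0.5 / sqrt(x); with y = sqrt(x), dx = dy * (0.5 * 1/y).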
Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sqrt", {"x"}},
      {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
      FDH::Const("const", 0.5f),
      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
      {{"dx"}, "Mul", {"dy", "a"}},       // dy * (.5 * 1/y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);

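// d(x^-1/2)/dx = -0.5 * x^-3/2; with y = rsqrt(x), dx = dy * (-0.5 * (1/x) * y).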
Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"y"}, "Rsqrt", {"x"}},
      FDH::Const("const", -.5f),
      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"neghalf", "x_inv"}},  // -0.5 * 1/x
      {{"b"}, "Mul", {"a", "y"}},            // -0.5 * 1/x * y
      {{"dx"}, "Mul", {"dy", "b"}},          // dy * (-0.5 * 1/x * y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);

Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Exp", ExpGrad);

Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Expm1", Expm1Grad);

Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "x_inv"}},  // dy * 1/x
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log", LogGrad);

Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x"}},
      {{"dx"}, "Div", {"dy", "a"}},  // dy / (1 + x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);

Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sinh", SinhGrad);

Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cosh", CoshGrad);

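// d(tanh(x))/dx = 1 - tanh(x)^2; with y = tanh(x), dx = dy * (1 - y^2).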
Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},  // dy * (1 - y*y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tanh", TanhGrad);

Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Asinh", {"x"}},
      {{"cosh"}, "Cosh", {"y"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asinh", AsinhGrad);

Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Acosh", {"x"}},
      {{"sinh"}, "Sinh", {"y"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acosh", AcoshGrad);

Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atanh", AtanhGrad);

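// d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)); with y = sigmoid(x),
// dx = dy * y * (1 - y).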
Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sigmoid", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
      {{"b"}, "Mul", {"y", "a"}},    // y * (1 - y)
      {{"dx"}, "Mul", {"dy", "b"}},  // dy * y * (1 - y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);

Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Shape", {"x"}},
      FDH::Const("zero", 0.f),
      {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"dx"}, "Fill", {"s", "val"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sign", SignGrad);

Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sin", SinGrad);

Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
      {{"neg"}, "Neg", {"sin"}},
      {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cos", CosGrad);

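// d(acos(x))/dx = -1 / sqrt(1 - x^2), so dx = -dy / sqrt(1 - x^2).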
Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"neg"}, "Neg", {"inv"}},
      {{"dx"}, "Mul", {"dy", "neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acos", AcosGrad);

Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asin", AsinGrad);

Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x2"}},  // 1 + x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atan", AtanGrad);

Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosx"}, "Cos", {"x"}},
      {{"secx"}, "Reciprocal", {"cosx"}},
      {{"secx2"}, "Square", {"secx"}},
      {{"dx"}, "Mul", {"dy", "secx2"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tan", TanGrad);

Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"dy", "zero"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Real", RealGrad);

Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"zero", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);

Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"re"}, "Real", {"x"}},
      {{"im"}, "Imag", {"x"}},
      {{"z"}, "Complex", {"im", "re"}},
      {{"z_inv"}, "Reciprocal", {"z"}},
      {{"neg"}, "Neg", {"z_inv"}},
      {{"dx"}, "Mul", {"neg", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Angle", AngleGrad);

Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Conj", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Conj", ConjGrad);

Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return OkStatus();
  // clang-format on
}
REGISTER_OP_GRADIENT("Cast", CastGrad);

// Cwise binary ops
//
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
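// Wraps "body", which must define "gx" and "gy" (the un-reduced gradients,
// shaped like "dz"), with the standard broadcasting clean-up:
// BroadcastGradientArgs computes, for each input, the axes along which it was
// broadcast; the gradients are summed over those axes and reshaped back to
// the original input shapes "sx" and "sy".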
Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}},
      {{"sy"}, "Shape", {"y"}},
  };
  nodes.insert(nodes.end(), body.begin(), body.end());
  std::vector<FDH::Node> reshapes = {
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
      {{"sum_gx"}, "Sum", {"gx", "rx"}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}},
      {{"sum_gy"}, "Sum", {"gy", "ry"}},
      {{"dy"}, "Reshape", {"sum_gy", "sy"}},
  };
  nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());

  // clang-format on
  for (auto& n : nodes) {
    // "BroadcastGradientArgs" doesn't need any attrs.
    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}

Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Identity", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Add", AddGrad);
REGISTER_OP_GRADIENT("AddV2", AddGrad);

Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Neg", {"dz"}},  // -dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sub", SubGrad);

Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return GradForBinaryCwise(
        g, {
               {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
               {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
           });
  } else {
    // clang-format off
    return GradForBinaryCwise(g, {
        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
    });
    // clang-format on
  }
}
REGISTER_OP_GRADIENT("Mul", MulGrad);

Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
      {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("MulNoNan", MulNoNanGrad);

Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Div", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "Div", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Div", DivGrad);

Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "RealDiv", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);

Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "DivNoNan", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);

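// For z = x^y: dz/dx = y * x^(y-1) and dz/dy = z * log(x). Since log(x) is
// only defined for x > 0 (x != 0 for complex types), the Select nodes below
// replace it with 0 elsewhere.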
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"z"}, "Pow", {"x", "y"}},
      // dz * y * Pow(x, y - 1)
      FDH::Const("const_zero", 0.0f),
      FDH::Const("const_one", 1.0f),
      {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
      {{"t1"}, "Pow", {"x", "t0"}},
      {{"t2"}, "Mul", {"dz", "y"}},
      {{"gx"}, "Mul", {"t1", "t2"}},
      {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
      {{"zeros"}, "ZerosLike", {"x"}}};
  // clang-format on
  std::vector<FDH::Node> log_x_handling;
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    // dz * z * (x != 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"nz_x"}, "NotEqual", {"x", "zero"}},
        {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
    // clang-format on
  } else {
    // dz * z * (x > 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"pos_x"}, "Greater", {"x", "zero"}},
        {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
    // clang-format on
  }
  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
  return GradForBinaryCwise(g, nodes);
}
REGISTER_OP_GRADIENT("Pow", PowGrad);

Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
      {{"xlogygrad"}, "Xdivy", {"x", "y"}},
      {{"gx"}, "Mul", {"safe_logy", "dz"}},
      {{"gy"}, "Mul", {"xlogygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);

Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"yp1"}, "Add", {"y", "one"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}},
      {{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}},
      {{"gx"}, "Mul", {"safe_log1py", "dz"}},
      {{"gy"}, "Mul", {"xlog1pygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad);

Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
      {{"y2"}, "Square", {"y"}},
      {{"negy2"}, "Neg", {"y2"}},
      {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
      {{"gx"}, "Mul", {"safe_divy", "dz"}},
      {{"gy"}, "Mul", {"xdivygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);

Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x_sub_y"}, "Sub", {"x", "y"}},
      {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
      {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
      {{"gy"}, "Neg", {"gx"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);

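// The gradient flows entirely to the selected input: mask = comparator(x, y)
// (GreaterEqual for Maximum, LessEqual for Minimum), gx = dz * mask, and
// gy = dz - gx, so ties send the whole gradient to x.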
Status MaximumMinimumGradHelper(const string& comparator,
                                const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"gx"}, "Mul", {"dz", "mask"}},
      {{"gy"}, "Sub", {"dz", "gx"}},
  });
  // clang-format on
}

Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Maximum", MaximumGrad);

Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("LessEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Minimum", MinimumGrad);

Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Real", {"dz"}},
      {{"gy"}, "Imag", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Complex", ComplexGrad);

// Cwise ternary ops.
Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      {"c:bool", "x:T", "y:T", "dz:T"},
      {"dc:bool", "dx:T", "dy:T"},
      {{"T: {half, float, double}"}},
      {
        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Select", SelectGrad);

// N-ary ops
// REGISTER_OP_GRADIENT("AddN", AddNGrad);

// Reduction ops
//
// TODO(zhifengc): This helper is pretty ugly. Do something better.
// TODO(zhifengc): This can be arranged as a function in the standard library.
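// Emits the shape bookkeeping shared by reduction gradients:
//   y_shape      = x_shape with the reduced dimensions "i" set to 1 (built
//                  with DynamicStitch), i.e. the keep_dims=true output shape.
//   tile_scaling = x_shape / y_shape, how many times each output element was
//                  reduced over.
// The body then typically reshapes "dy" to y_shape and tiles it back up to
// x_shape.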
Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
  // Shape manipulation nodes.

  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"x_shape"}, "Shape", {"x"}},
      {{"x_rank"}, "Rank", {"x"}},
      {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
      FDH::Const("zero", 0),
      FDH::Const("one", 1),
      // stitch_idx0 = Range(0, x_rank, 1)
      {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
       {{"T", DT_INT32}}},
      {{"y_shape"}, "DynamicStitch",
       {"stitch_idx0:output:0", "i",
        "x_shape:output:0", "stitch_val1:output:0"},
       {{"N", 2}, {"T", DT_INT32}}},
      {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
       {{"T", DT_INT32}}},
      {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
  };
  // clang-format on
  nodes.insert(nodes.end(), body.begin(), body.end());
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  // "Range" doesn't need any attr.
  nodes.push_back({{"stitch_idx0"},
                   "Range",
                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
                   {}});
  *g = FDH::Create("_",
                   // Input defs
                   {"x:T", "i:int32", "dy:T"},
                   // Ret val defs
                   {"dx:T", "di:int32"},
                   // Attr defs
                   {{"T: {half, float, double}"}},
                   // Nodes
                   nodes,
                   // Return values
                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
  return OkStatus();
}

Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sum", SumGrad);

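// Mean is Sum scaled by the number of reduced elements, so its gradient is
// the Sum gradient with "dy" first divided by Prod(tile_scaling).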
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
       {{"T", DT_INT32}}},
      {{"factor_T"}, "Cast", {"factor:output:0"},
       {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
      {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
      {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);

// REGISTER_OP_GRADIENT("Prod", ProdGrad);
// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x:T", "i:int32", "dy:T"},
      // Ret val defs
      {"dx:T", "di:int32"},
      // Attr defs
      {{"T: {half, float, double}"}},
      {
        // keep_dims because we need to do x == y, which requires that x
        // and y be broadcastable.
        {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
        {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
        {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
        {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
        {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
        {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
        {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
      });
  // clang-format on
  return OkStatus();
}

Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Max", attrs, g);
}
REGISTER_OP_GRADIENT("Max", MaxGrad);

Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Min", attrs, g);
}
REGISTER_OP_GRADIENT("Min", MinGrad);

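// For z = MatMul(x, y) with no transposes: dx = MatMul(dz, y, transpose_b) and
// dy = MatMul(x, dz, transpose_a). The helper emits these two product nodes
// with operands and adjoint flags supplied by the caller (covering every
// transpose/adjoint combination) and, when enable_broadcasting is set, adds
// the Sum/Reshape nodes that undo broadcasting of the batch dimensions.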
static Status MatMulGradHelper(FunctionDef* g, const string& opname,
                               const string& attr_adj_x,
                               const string& attr_adj_y, const string& x0,
                               bool ax0, const string& x1, bool ax1,
                               const string& y0, bool ay0, const string& y1,
                               bool ay1, bool enable_broadcasting) {
  // The final outputs are "dx" and "dy". If we're broadcasting, compute
  // intermediate nodes "gx" and "gy" first and reduce them to "dx" and "dy"
  // below.
  std::vector<FDH::Node> nodes = {
      {{(enable_broadcasting ? "gx" : "dx")},
       opname,
       {x0, x1},
       {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
      {{(enable_broadcasting ? "gy" : "dy")},
       opname,
       {y0, y1},
       {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  };
  // TODO(anudhyan): Figure out a way to inspect the static shapes of "x" and
  // "y". If they have the same batch dimensions, then we can omit adding the
  // broadcasting-specific ops.
  if (enable_broadcasting) {
    std::vector<FDH::Node> unbroadcast_gradients = {
        FDH::Const<int32>("zero", gtl::ArraySlice<int32>{0}),
        FDH::Const<int32>("one", gtl::ArraySlice<int32>{1}),
        FDH::Const<int32>("minustwo", gtl::ArraySlice<int32>{-2}),
        // Compute the batch shapes of the inputs (all but last two dims).
        {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"batch_sx"},
         "StridedSlice",
         {"sx", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        {{"batch_sy"},
         "StridedSlice",
         {"sy", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        // Sum along dimensions that the inputs were broadcasted across.
        {{"rx", "ry"}, "BroadcastGradientArgs", {"batch_sx", "batch_sy"}},
        {{"sum_gx"}, "Sum", {"gx", "rx"}, {{"T", "$T"}}},
        {{"sum_gy"}, "Sum", {"gy", "ry"}, {{"T", "$T"}}},
        {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
        {{"dy"}, "Reshape", {"sum_gy", "sy"}, {{"T", "$T"}}}};
    nodes.insert(nodes.end(), unbroadcast_gradients.begin(),
                 unbroadcast_gradients.end());
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}

Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
                        const string& attr_adj_y, const AttrSlice& attrs,
                        FunctionDef* g, bool enable_broadcasting) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false, enable_broadcasting);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false, enable_broadcasting);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false, enable_broadcasting);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true, enable_broadcasting);
}

Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("MatMul", MatMulGrad);

Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);

Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g,
                          true /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMulV2", BatchMatMulV2Grad);

// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);

// Comparison ops.
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");

}  // end namespace tensorflow