1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | |
18 | #include "tensorflow/core/framework/function.h" |
19 | #include "tensorflow/core/framework/types.pb.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | #include "tensorflow/core/lib/gtl/array_slice.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | typedef FunctionDefHelper FDH; |
26 | |
// Cwise unary ops
Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}
45 | |
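// d|x|/dx = sign(x) (with TF's convention sign(0) == 0), so dx = dy * sign(x).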
Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sign"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Abs", AbsGrad);
55 | |
Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Neg", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Neg", NegGrad);
64 | |
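// With y = 1/x, dy/dx = -1/x^2 = -y^2, so dx = dy * (-y^2).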
Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Reciprocal", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      {{"y2_neg"}, "Neg", {"y2"}},
      {{"dx"}, "Mul", {"dy", "y2_neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
REGISTER_OP_GRADIENT("Reciprocal", InvGrad);
77 | |
Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
      {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Square", SquareGrad);
89 | |
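// With y = sqrt(x), dy/dx = 1/(2*sqrt(x)) = 0.5 * (1/y), so dx = dy * 0.5/y.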
Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sqrt", {"x"}},
      {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
      FDH::Const("const", 0.5f),
      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
      {{"dx"}, "Mul", {"dy", "a"}},       // dy * (.5 * 1/y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
103 | |
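// With y = x^(-1/2), dy/dx = -0.5 * x^(-3/2) = -0.5 * (1/x) * y,
// so dx = dy * (-0.5 * (1/x) * y).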
Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"y"}, "Rsqrt", {"x"}},
      FDH::Const("const", -.5f),
      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"neghalf", "x_inv"}},  // -0.5 * 1/x
      {{"b"}, "Mul", {"a", "y"}},            // -0.5 * 1/x * y
      {{"dx"}, "Mul", {"dy", "b"}},          // dy * (-0.5 * 1/x * y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
118 | |
Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Exp", ExpGrad);
128 | |
Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Expm1", Expm1Grad);
138 | |
Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "x_inv"}},  // dy * 1/x
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log", LogGrad);
148 | |
Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x"}},
      {{"dx"}, "Div", {"dy", "a"}},  // dy / (1 + x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);
160 | |
Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sinh", SinhGrad);
170 | |
Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cosh", CoshGrad);
180 | |
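// d(tanh x)/dx = 1 - tanh^2(x) = 1 - y^2, so dx = dy * (1 - y^2).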
Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},  // dy * (1 - y*y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tanh", TanhGrad);
194 | |
Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Asinh", {"x"}},
      {{"cosh"}, "Cosh", {"y"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asinh", AsinhGrad);
205 | |
Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Acosh", {"x"}},
      {{"sinh"}, "Sinh", {"y"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acosh", AcoshGrad);
216 | |
Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atanh", AtanhGrad);
230 | |
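// With y = sigmoid(x) = 1/(1 + exp(-x)), dy/dx = y * (1 - y),
// so dx = dy * y * (1 - y).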
Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sigmoid", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
      {{"b"}, "Mul", {"y", "a"}},    // y * (1 - y)
      {{"dx"}, "Mul", {"dy", "b"}},  // dy * y * (1 - y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);
244 | |
Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Shape", {"x"}},
      FDH::Const("zero", 0.f),
      {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"dx"}, "Fill", {"s", "val"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sign", SignGrad);
256 | |
Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sin", SinGrad);
266 | |
Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
      {{"neg"}, "Neg", {"sin"}},
      {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cos", CosGrad);
277 | |
Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"neg"}, "Neg", {"inv"}},
      {{"dx"}, "Mul", {"dy", "neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acos", AcosGrad);
293 | |
Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asin", AsinGrad);
308 | |
Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x2"}},  // 1 + x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atan", AtanGrad);
322 | |
Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosx"}, "Cos", {"x"}},
      {{"secx"}, "Reciprocal", {"cosx"}},
      {{"secx2"}, "Square", {"secx"}},
      {{"dx"}, "Mul", {"dy", "secx2"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tan", TanGrad);
334 | |
Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"dy", "zero"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Real", RealGrad);
344 | |
Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"zero", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);
354 | |
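// angle(x) = atan2(Im(x), Re(x)). Under TF's convention for gradients of
// real-valued outputs with respect to complex inputs (grad = d/dRe + i*d/dIm),
// the gradient is dy * (-Im(x) + i*Re(x)) / |x|^2, which equals
// dy * -1/Complex(Im(x), Re(x)) as computed below.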
Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"re"}, "Real", {"x"}},
      {{"im"}, "Imag", {"x"}},
      {{"z"}, "Complex", {"im", "re"}},
      {{"z_inv"}, "Reciprocal", {"z"}},
      {{"neg"}, "Neg", {"z_inv"}},
      {{"dx"}, "Mul", {"neg", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Angle", AngleGrad);
368 | |
Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Conj", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Conj", ConjGrad);
377 | |
Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return OkStatus();
  // clang-format on
}
REGISTER_OP_GRADIENT("Cast", CastGrad);
393 | |
394 | // Cwise binary ops |
395 | // |
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}},
      {{"sy"}, "Shape", {"y"}},
  };
  nodes.insert(nodes.end(), body.begin(), body.end());
  std::vector<FDH::Node> reshapes = {
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
      {{"sum_gx"}, "Sum", {"gx", "rx"}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}},
      {{"sum_gy"}, "Sum", {"gy", "ry"}},
      {{"dy"}, "Reshape", {"sum_gy", "sy"}},
  };
  nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());

  // clang-format on
  for (auto& n : nodes) {
    // "BroadcastGradientArgs" doesn't need any attrs.
    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}
432 | |
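// Example of the broadcast reduction above: for AddGrad below with x of shape
// [2, 3] and y of shape [3], BroadcastGradientArgs yields rx = [] and ry = [0],
// so gy is summed over axis 0 and reshaped to [3], while gx passes through
// unchanged.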
Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Identity", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Add", AddGrad);
REGISTER_OP_GRADIENT("AddV2", AddGrad);
443 | |
Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Neg", {"dz"}},  // -dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sub", SubGrad);
453 | |
Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return GradForBinaryCwise(
        g, {
               {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
               {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
           });
  } else {
    // clang-format off
    return GradForBinaryCwise(g, {
        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
    });
    // clang-format on
  }
}
REGISTER_OP_GRADIENT("Mul", MulGrad);
475 | |
Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
      {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("MulNoNan", MulNoNanGrad);
485 | |
Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Div", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "Div", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Div", DivGrad);
498 | |
Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "RealDiv", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
511 | |
Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "DivNoNan", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
524 | |
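// z = x^y has partials dz/dx = y * x^(y-1) and dz/dy = z * log(x). The log(x)
// term is only kept where it is well defined (x > 0 for real types, x != 0 for
// complex types); elsewhere it is replaced by zero via the Select below.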
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"z"}, "Pow", {"x", "y"}},
      // dz * y * Pow(x, y - 1)
      FDH::Const("const_zero", 0.0f),
      FDH::Const("const_one", 1.0f),
      {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
      {{"t1"}, "Pow", {"x", "t0"}},
      {{"t2"}, "Mul", {"dz", "y"}},
      {{"gx"}, "Mul", {"t1", "t2"}},
      {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
      {{"zeros"}, "ZerosLike", {"x"}}};
  // clang-format on
  std::vector<FDH::Node> log_x_handling;
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    // dz * z * (x != 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"nz_x"}, "NotEqual", {"x", "zero"}},
        {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
    // clang-format on
  } else {
    // dz * z * (x > 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"pos_x"}, "Greater", {"x", "zero"}},
        {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
    // clang-format on
  }
  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
  return GradForBinaryCwise(g, nodes);
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
565 | |
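// xlogy(x, y) = x * log(y), with xlogy(0, y) == 0 by convention. Hence
// d/dx = log(y) where x != 0 (else 0) and d/dy = x/y; both are computed with
// the safe ops so that x == 0 contributes no gradient.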
Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
      {{"xlogygrad"}, "Xdivy", {"x", "y"}},
      {{"gx"}, "Mul", {"safe_logy", "dz"}},
      {{"gy"}, "Mul", {"xlogygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
581 | |
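// xlog1py(x, y) = x * log1p(y) with xlog1py(0, y) == 0, so d/dx = log1p(y)
// where x != 0 (else 0) and d/dy = x / (1 + y).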
Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"yp1"}, "Add", {"y", "one"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}},
      {{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}},
      {{"gx"}, "Mul", {"safe_log1py", "dz"}},
      {{"gy"}, "Mul", {"xlog1pygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad);
600 | |
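// xdivy(x, y) = x / y with xdivy(0, y) == 0, so d/dx = 1/y where x != 0
// (else 0) and d/dy = -x / y^2.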
Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
      {{"y2"}, "Square", {"y"}},
      {{"negy2"}, "Neg", {"y2"}},
      {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
      {{"gx"}, "Mul", {"safe_divy", "dz"}},
      {{"gy"}, "Mul", {"xdivygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
618 | |
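// z = (x - y)^2 has partials dz/dx = 2 * (x - y) and dz/dy = -2 * (x - y).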
Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x_sub_y"}, "Sub", {"x", "y"}},
      {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
      {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
      {{"gy"}, "Neg", {"gx"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);
632 | |
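// The gradient flows to whichever input won the comparison. On ties
// ("GreaterEqual"/"LessEqual" are true), x receives the full gradient and y
// receives none, since gy = dz - gx.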
Status MaximumMinimumGradHelper(const string& comparator,
                                const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"gx"}, "Mul", {"dz", "mask"}},
      {{"gy"}, "Sub", {"dz", "gx"}},
  });
  // clang-format on
}

Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Maximum", MaximumGrad);

Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("LessEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Minimum", MinimumGrad);
654 | |
Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Real", {"dz"}},
      {{"gy"}, "Imag", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Complex", ComplexGrad);
664 | |
// Cwise ternary ops.
Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      {"c:bool", "x:T", "y:T", "dz:T"},
      {"dc:bool", "dx:T", "dy:T"},
      {{"T: {half, float, double}"}},
      {
        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Select", SelectGrad);
682 | |
683 | // N-ry ops |
684 | // REGISTER_OP_GRADIENT("AddN", AddNGrad); |
685 | |
686 | // Reduction ops |
687 | // |
688 | // TODO(zhifengc): This helper is pretty ugly. Do something better. |
// TODO(zhifengc): This can be arranged as a function in the standard library.
Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
  // Shape manipulation nodes.

  // clang-format off
  std::vector<FDH::Node> nodes = {
   {{"x_shape"}, "Shape", {"x"}},
   {{"x_rank"}, "Rank", {"x"}},
   {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
   FDH::Const("zero", 0),
   FDH::Const("one", 1),
   // stitch_idx0 = Range(0, x_rank, 1)
   {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
    {{"T", DT_INT32}}},
   {{"y_shape"}, "DynamicStitch",
    {"stitch_idx0:output:0", "i",
     "x_shape:output:0", "stitch_val1:output:0"},
    {{"N", 2}, {"T", DT_INT32}}},
   {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
    {{"T", DT_INT32}}},
   {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
  };
  // clang-format on
  nodes.insert(nodes.end(), body.begin(), body.end());
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  // "Range" doesn't need any attr.
  nodes.push_back({{"stitch_idx0"},
                   "Range",
                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
                   {}});
  *g = FDH::Create("_",
                   // Input defs
                   {"x:T", "i:int32", "dy:T"},
                   // Ret val defs
                   {"dx:T", "di:int32"},
                   // Attr defs
                   {{"T: {half, float, double}"}},
                   // Nodes
                   nodes,
                   // Return values
                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
  return OkStatus();
}
736 | |
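// Worked example for SumGrad: with x of shape [2, 3] and i = [0], the
// DynamicStitch above produces y_shape = [1, 3] (reduced dims set to 1) and
// tile_scaling = [2, 1], so dy is reshaped to [1, 3] and tiled back to [2, 3].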
Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sum", SumGrad);
746 | |
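// MeanGrad additionally divides dy by the number of reduced elements,
// factor = Prod(tile_scaling), before tiling.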
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
       {{"T", DT_INT32}}},
      {{"factor_T"}, "Cast", {"factor:output:0"},
       {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
      {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
      {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);
761 | |
762 | // REGISTER_OP_GRADIENT("Prod", ProdGrad); |
763 | // REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad); |
764 | // REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad); |
765 | // REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad); |
766 | // REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad); |
767 | // REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad); |
768 | // REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad); |
769 | // REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad); |
770 | // REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad); |
771 | // REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad); |
772 | |
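// The gradient flows only to the positions that attained the extreme value;
// when several elements tie, dy is split evenly among them (mask / mask_sum).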
Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x:T", "i:int32", "dy:T"},
      // Ret val defs
      {"dx:T", "di:int32"},
      // Attr defs
      {{"T: {half, float, double}"}},
      {
        // keep_dims because we need to do x == y, which requires x
        // and y to be broadcastable.
        {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
        {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
        {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
        {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
        {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
        {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
        {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
      });
  // clang-format on
  return OkStatus();
}
799 | |
Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Max", attrs, g);
}
REGISTER_OP_GRADIENT("Max", MaxGrad);

Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Min", attrs, g);
}
REGISTER_OP_GRADIENT("Min", MinGrad);
809 | |
810 | static Status MatMulGradHelper(FunctionDef* g, const string& opname, |
811 | const string& attr_adj_x, |
812 | const string& attr_adj_y, const string& x0, |
813 | bool ax0, const string& x1, bool ax1, |
814 | const string& y0, bool ay0, const string& y1, |
815 | bool ay1, bool enable_broadcasting) { |
  // The final outputs are "dx" and "dy". If we're broadcasting, the matmul
  // ops below first produce intermediate gradients "gx" and "gy", which are
  // then reduced and reshaped to "dx" and "dy".
  std::vector<FDH::Node> nodes = {
      {{(enable_broadcasting ? "gx" : "dx")},
       opname,
       {x0, x1},
       {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
      {{(enable_broadcasting ? "gy" : "dy")},
       opname,
       {y0, y1},
       {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  };
828 | // TODO(anudhyan): Figure out a way to inspect the static shapes of "x" and |
829 | // "y". If they have the same batch dimensions, then we can omit adding the |
830 | // broadcasting-specific ops. |
831 | if (enable_broadcasting) { |
    std::vector<FDH::Node> unbroadcast_gradients = {
        FDH::Const<int32>("zero", gtl::ArraySlice<int32>{0}),
        FDH::Const<int32>("one", gtl::ArraySlice<int32>{1}),
        FDH::Const<int32>("minustwo", gtl::ArraySlice<int32>{-2}),
        // Compute the batch shapes of the inputs (all but last two dims).
        {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"batch_sx"},
         "StridedSlice",
         {"sx", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        {{"batch_sy"},
         "StridedSlice",
         {"sy", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        // Sum along dimensions that the inputs were broadcasted across.
        {{"rx", "ry"}, "BroadcastGradientArgs", {"batch_sx", "batch_sy"}},
        {{"sum_gx"}, "Sum", {"gx", "rx"}, {{"T", "$T"}}},
        {{"sum_gy"}, "Sum", {"gy", "ry"}, {{"T", "$T"}}},
        {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
        {{"dy"}, "Reshape", {"sum_gy", "sy"}, {{"T", "$T"}}}};
853 | nodes.insert(nodes.end(), unbroadcast_gradients.begin(), |
854 | unbroadcast_gradients.end()); |
855 | } |
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
865 | return OkStatus(); |
866 | } |
867 | |
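// For z = MatMul(x, y) the gradients are dx = dz . y^T and dy = x^T . dz; the
// four transpose/adjoint cases below are the corresponding rearrangements:
//   ta=0, tb=0:  dx = dz . y^T,    dy = x^T . dz
//   ta=0, tb=1:  dx = dz . y,      dy = dz^T . x
//   ta=1, tb=0:  dx = y . dz^T,    dy = x . dz
//   ta=1, tb=1:  dx = y^T . dz^T,  dy = dz^T . x^T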
Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
                        const string& attr_adj_y, const AttrSlice& attrs,
                        FunctionDef* g, bool enable_broadcasting) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false, enable_broadcasting);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false, enable_broadcasting);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false, enable_broadcasting);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true, enable_broadcasting);
}
897 | |
Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("MatMul", MatMulGrad);

Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);

Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g,
                          true /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMulV2", BatchMatMulV2Grad);
915 | |
916 | // REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad); |
917 | |
918 | // Comparison ops. |
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");
938 | |
939 | } // end namespace tensorflow |
940 | |