/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;

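// The ops below have no registered gradient: their outputs are either
// independent of the input values (Shape, Rank, Size, ZerosLike, OnesLike,
// Const), discrete and hence non-differentiable (EditDistance), or defined
// specifically to block backpropagation (StopGradient).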
REGISTER_OP_NO_GRADIENT("Shape");
REGISTER_OP_NO_GRADIENT("Rank");
REGISTER_OP_NO_GRADIENT("Size");
REGISTER_OP_NO_GRADIENT("ZerosLike");
REGISTER_OP_NO_GRADIENT("OnesLike");
REGISTER_OP_NO_GRADIENT("Const");
REGISTER_OP_NO_GRADIENT("EditDistance");
REGISTER_OP_NO_GRADIENT("StopGradient");

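// Reshape only rearranges elements, so its gradient reshapes dy back to the
// shape of x; the integer shape input gets a zero gradient. Illustrative
// example (hypothetical shapes): for y = Reshape(x /* [2, 3] */, [3, 2]),
// dx = Reshape(dy /* [3, 2] */, Shape(x)), which has shape [2, 3].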
Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "shape: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dshape: int32"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
        {{"dshape"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad);

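// Squeeze removes size-1 dimensions without touching the data, so, as with
// Reshape, the gradient is dy reshaped back to the original shape of x.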
Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Squeeze", SqueezeGrad);

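// Identity passes its input through unchanged; its gradient passes dy
// through unchanged as well.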
Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"dx"}, "Identity", {"dy"}, {{"T", "$T"}}},
      });
  // clang-format on
  VLOG(1) << "IdentityGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Identity", IdentityGrad);

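// Pack stacks N tensors along `axis`, so the gradient unstacks dy back into
// N pieces. Illustrative example (hypothetical shapes): packing three [4, 5]
// tensors with axis=0 yields y of shape [3, 4, 5], and Unpack splits dy of
// shape [3, 4, 5] back into three [4, 5] gradients.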
Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Create(
      "_",
      // Arg defs
      {"x: N*T", "dy: T"},
      // Ret val defs
      {"dx: N*T"},
      // Attr defs
      {"T: type", "N: int", "axis: int"},
      // Nodes
      {
        {
          {"dx"},
          "Unpack",
          {"dy"},
          {{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
        },
      },
      {{"dx", "dx:output"}});
  // clang-format on
  VLOG(1) << "PackGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Pack", PackGrad);

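// Unpack is the inverse of Pack, so the gradient re-packs the per-output
// gradients dy along the same axis to recover the gradient for x.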
Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: num*T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {"T: type", "num: int", "axis: int"},
      // Nodes
      {
        {
          {"dx"},
          "Pack",
          {"dy"},
          {{"T", "$T"}, {"N", "$num"}, {"axis", "$axis"}}
        },
      });
  // clang-format on
  VLOG(1) << "UnpackGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Unpack", UnpackGrad);

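// Concat's gradient slices dy back into one piece per input. Illustrative
// example (hypothetical shapes): concatenating a [2, 3] and a [2, 5] tensor
// along dim=1 yields y of shape [2, 8]; dx_0 is the [2, 3] slice of dy at
// offset [0, 0] and dx_1 is the [2, 5] slice at offset [0, 3].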
Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
                        bool dim_is_last_arg) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));

  std::vector<string> shape_i;
  std::vector<string> offset_i;
  std::vector<string> dx_i;
  for (int i = 0; i < N; ++i) {
    shape_i.push_back(strings::StrCat("shapes:output:", i));
    offset_i.push_back(strings::StrCat("offset:offset:", i));
    dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
  }
  DataTypeVector dtype_list(N, T);

  // ConcatGrad(dim, x, dy):
  //   for i in range(N):
  //     dx[i] = Slice(dy, offset[i], shape[x[i]]),
  // where offset[i] is the offset of x[i] in the output y,
  // which is the same as dx[i]'s offset within dy.
  std::vector<FDH::Node> nodes{
      {{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
      {{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
      {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
      {{"dx"},
       "_ListToArray",
       dx_i,
       {{"T", "$T"}, {"N", "$N"}, {"Tin", dtype_list}}}};

  // For each dx[i], we take a slice of dy. The offset and size of the
  // slice are given by offset[i] and shape[i].
  for (int i = 0; i < N; ++i) {
    nodes.push_back({{strings::StrCat("dx_", i)},
                     "Slice",
                     {"dy", offset_i[i], shape_i[i]},
                     {{"T", "$T"}, {"Index", DT_INT32}}});
  }
  if (dim_is_last_arg) {
    // clang-format off
    *g = FDH::Create(
        "_",
        // Arg defs
        {"x: N*T", "dim: int32", "dy: T"},
        // Return signature
        {"dx: N*T", "d_dim: int32"},
        // Attr defs
        {"T: type", "N: int"},
        // Nodes
        nodes,
        // Return values
        {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
    // clang-format on
  } else {
    // clang-format off
    *g = FDH::Create(
        "_",
        // Arg defs
        {"dim: int32", "x: N*T", "dy: T"},
        // Return signature
        {"d_dim: int32", "dx: N*T"},
        // Attr defs
        {"T: type", "N: int"},
        // Nodes
        nodes,
        // Return values
        {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
    // clang-format on
  }
  VLOG(1) << "ConcatGrad " << DebugString(*g);
  return OkStatus();
}

Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) {
  return ConcatGradHelper(attrs, g, false);
}

Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) {
  return ConcatGradHelper(attrs, g, true);
}

REGISTER_OP_GRADIENT("Concat", ConcatGrad);
REGISTER_OP_GRADIENT("ConcatV2", ConcatGradV2);

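// Split is the inverse of Concat along `dim`, so the gradient concatenates
// the per-output gradients dy back together along the same dimension.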
Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"dim: int32", "x: T", "dy: num_split*T"},
      // Ret val defs
      {"d_dim: int32", "dx: T"},
      // Attr defs
      {"T: type", "num_split: int"},
      // Nodes
      {
        {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
        {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}}
      });
  // clang-format on
  VLOG(1) << "SplitGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Split", SplitGrad);

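// SplitV is Split with explicitly sized pieces; the gradient is again a
// Concat of the dy pieces, with zero gradients for the integer size_splits
// and dim inputs.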
Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"},
      // Ret val defs
      {"dx: T", "d_size_splits: Tlen", "d_dim: int32"},
      // Attr defs
      {"T: type", "Tlen: type", "num_split: int"},
      // Nodes
      {
        {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}},
        {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}},
        {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
      });
  // clang-format on
  VLOG(1) << "SplitVGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("SplitV", SplitVGrad);

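// _ArrayToList and _ListToArray convert between an N*T tensor array and a
// typed tensor list; each op serves as the other's gradient, applied to dy.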
Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
  std::vector<string> dys;
  dys.reserve(N);
  for (int i = 0; i < N; ++i) {
    dys.push_back(strings::StrCat("dy:", i));
  }
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: N*T", "dy: out_types"},
      // Ret val defs
      {"dx: N*T"},
      // Attr defs
      {"T: type", "N: int", "out_types: list(type)"},
      // Nodes
      {
        {{"dx"}, "_ListToArray", dys,
         {{"T", "$T"}, {"N", "$N"}, {"Tin", "$out_types"}}}
      });
  // clang-format on
  VLOG(1) << "ArrayToListGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("_ArrayToList", ArrayToListGrad);

Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: Tin", "dy: N*T"},
      // Ret val defs
      {"dx: Tin"},
      // Attr defs
      {"T: type", "N: int", "Tin: list(type)"},
      // Nodes
      {
        {{"dx"}, "_ArrayToList", {"dy"},
         {{"T", "$T"}, {"N", "$N"}, {"out_types", "$Tin"}}}
      });
  // clang-format on
  VLOG(1) << "ListToArrayGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad);

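// Fill(dims, v) broadcasts the scalar v to a tensor of shape dims, so the
// gradient w.r.t. v is the sum of dy over all of its dimensions.
// Illustrative example (hypothetical shapes): for dims = [2, 3],
// dx = Sum(dy, [0, 1]), a scalar.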
Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FDH::Define(
      // Arg defs
      {"dims: int32", "x: T", "dy: T"},
      // Ret val defs
      {"d_dims: int32", "dx: T"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"d_dims"}, "ZerosLike", {"dims"}, {{"T", DT_INT32}}},
        FDH::Const("zero", 0),
        {{"rank"}, "Rank", {"dy"}, {{"T", "$T"}}},
        FDH::Const("one", 1),
        {{"r"}, "Range", {"zero", "rank", "one"}, {}},
        // dx = sum(dy)
        {{"dx"}, "Sum", {"dy", "r"}, {{"T", "$T"}}},
      });
  VLOG(1) << "FillGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Fill", FillGrad);

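// Transposing by permutation p is undone by the inverse permutation, so
// dx = Transpose(dy, InvertPermutation(p)). Illustrative example: for
// p = [1, 2, 0], InvertPermutation(p) = [2, 0, 1].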
Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FDH::Define(
      // Arg defs
      {"x: T", "p: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dp: int32"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"q"}, "InvertPermutation", {"p"}, {}},
        {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}},
        {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
      });
  VLOG(1) << "TransposeGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Transpose", TransposeGrad);

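// GatherNd reads params at the given indices; the gradient scatters doutput
// back into a zeros tensor of params' shape at those same indices, which is
// exactly what ScatterNd computes. dindices is a zero gradient.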
Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
      // Ret val defs
      {"dparams: Tparams", "dindices: Tindices"},
      // Attr defs
      {"Tparams: type", "Tindices: type"},
      // Nodes
      {
        {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
        {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
         {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
        {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);

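// Like TransposeGrad, except the backward pass must also conjugate, so dx is
// the ConjugateTranspose of dy under the inverse permutation.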
Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FDH::Define(
      // Arg defs
      {"x: T", "p: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dp: int32"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"q"}, "InvertPermutation", {"p"}, {}},
        {{"dx"}, "ConjugateTranspose", {"dy", "q"}, {{"T", "$T"}}},
        {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
      });
  VLOG(1) << "ConjugateTransposeGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("ConjugateTranspose", ConjugateTransposeGrad);

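// Reversing along a fixed set of dimensions is an involution, so the
// gradient simply reverses dy along the same dimensions.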
Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FDH::Define(
      // Arg defs
      {"x: T", "d: bool", "dy: T"},
      // Ret val defs
      {"dx: T", "dd: bool"},
      // Attr defs
      {"T: type"},
      // Nodes
      {
        {{"dx"}, "Reverse", {"dy", "d"}, {{"T", "$T"}}},
        {{"dd"}, "ZerosLike", {"d"}, {{"T", DT_BOOL}}},
      });
  VLOG(1) << "ReverseGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Reverse", ReverseGrad);

Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "ReverseV2Grad for int64 indices is not supported.");
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "d: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dd: int32"},
      // Attr defs
      {"T: type", "Tidx: {int32, int64}"},
      // Nodes
      {
        {{"dx"}, "ReverseV2", {"dy", "d"}, {{"T", "$T"}}},
        {{"dd"}, "ZerosLike", {"d"}, {{"T", "$Tidx"}}},
      });
  VLOG(1) << "ReverseV2Grad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("ReverseV2", ReverseV2Grad);

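// Slice(x, begin, size) keeps a window of x, so the gradient zero-pads dy
// back out to x's shape: for each dimension i, the padding is
// [begin[i], shape(x)[i] - begin[i] - size[i]]. Illustrative example
// (hypothetical shapes): for x of shape [10] with begin = [2] and
// size = [3], paddings = [[2, 5]] and dx = Pad(dy, [[2, 5]]).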
Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "SliceGrad for int64 indices is not supported.");
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "begin: int32", "size: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "begin_grad: int32", "size_grad: int32"},
      // Attr defs
      {"T: type"},
      // Nodes
      {// paddings = concat(1, [begin, shape(x) - begin - size])
       FDH::Const("one", 1),
       {{"b1"}, "ExpandDims", {"begin", "one"}, {{"T", DT_INT32}}},
       {{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
       {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
       {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
       {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
       {{"paddings"},
        "Concat",
        {"one", "b1", "a1"},
        {{"N", 2}, {"T", DT_INT32}}},
       // dx = Pad(dy, paddings)
       {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
       {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
       {{"size_grad"}, "ZerosLike", {"size"}, {{"T", DT_INT32}}}});
  VLOG(1) << "SliceGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("Slice", SliceGrad);

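// The dedicated StridedSliceGrad op places dy into a zeros tensor of x's
// shape, honoring the same begin/end/stride inputs and mask attributes as
// the forward StridedSlice. All integer inputs get zero gradients.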
Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "StridedSliceGrad for int64 indices is not supported.");
  }

  *g = FDH::Define(
      // Arg defs
      {"x: T", "begin: int32", "end: int32", "stride: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "begin_grad: int32", "end_grad: int32", "stride_grad: int32"},
      // Attr defs
      {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
       "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
      {// Nodes
       {{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"dx"},
         "StridedSliceGrad",
         {"xs", "begin", "end", "stride", "dy"},
         {{"T", "$T"},
          {"Index", "$Index"},
          {"begin_mask", "$begin_mask"},
          {"end_mask", "$end_mask"},
          {"ellipsis_mask", "$ellipsis_mask"},
          {"new_axis_mask", "$new_axis_mask"},
          {"shrink_axis_mask", "$shrink_axis_mask"}}},
        {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
        {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
        {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}}}});

  VLOG(1) << "StridedSliceGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad);

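// The gradient of StridedSliceGrad with respect to its dy input is just the
// forward StridedSlice of the incoming gradient; the integer
// shape/begin/end/stride inputs get zero gradients.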
Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "StridedSliceGradGrad for int64 indices is not supported.");
  }

  // TODO(aselle): Shouldn't the int32 tensors return zeros of shape like
  // dy_grad? I'm following slice's behavior for now.
  *g = FDH::Define(
      // Arg defs
      {"shape: int32", "begin: int32", "end: int32", "stride: int32", "dy: T",
       "grad: T"},
      // Ret val defs
      {"shape_grad: int32", "begin_grad: int32", "end_grad: int32",
       "stride_grad: int32", "dy_grad: T"},
      // Attr defs
      {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
       "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
      {// Nodes
       {{{"shape_grad"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
        {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
        {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
        {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}},
        {{"dy_grad"},
         "StridedSlice",
         {"grad", "begin", "end", "stride"},
         {{"T", "$T"},
          {"Index", "$Index"},
          {"begin_mask", "$begin_mask"},
          {"end_mask", "$end_mask"},
          {"ellipsis_mask", "$ellipsis_mask"},
          {"new_axis_mask", "$new_axis_mask"},
          {"shrink_axis_mask", "$shrink_axis_mask"}}}}});

  VLOG(1) << "StridedSliceGradGrad " << DebugString(*g);
  return OkStatus();
}
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);

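// BroadcastTo replicates x along the broadcast dimensions, so the gradient
// sums dy over those dimensions (computed by BroadcastGradientArgs) and
// reshapes the result back to x's shape. Illustrative example (hypothetical
// shapes): broadcasting x of shape [1, 3] to [4, 3] gives
// dx = Reshape(Sum(dy, [0]), [1, 3]).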
Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "BroadcastToGrad for int64 indices is not supported.");
  }
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
      {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
      {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
  *g = FDH::Define(
      // Arg defs
      {"x: T", "shape: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dshape: Tidx"},
      // Attr defs
      {{"T: type"}, {"Tidx: {int32, int64}"}},
      // Nodes
      nodes);
  return OkStatus();
}
REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);

}  // end namespace tensorflow