1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | #include "tensorflow/core/framework/function.h" |
18 | #include "tensorflow/core/lib/core/errors.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | typedef FunctionDefHelper FDH; |
23 | |
24 | REGISTER_OP_NO_GRADIENT("Shape" ); |
25 | REGISTER_OP_NO_GRADIENT("Rank" ); |
26 | REGISTER_OP_NO_GRADIENT("Size" ); |
27 | REGISTER_OP_NO_GRADIENT("ZerosLike" ); |
28 | REGISTER_OP_NO_GRADIENT("OnesLike" ); |
29 | REGISTER_OP_NO_GRADIENT("Const" ); |
30 | REGISTER_OP_NO_GRADIENT("EditDistance" ); |
31 | REGISTER_OP_NO_GRADIENT("StopGradient" ); |
32 | |
33 | Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { |
34 | // clang-format off |
35 | *g = FDH::Define( |
36 | // Arg defs |
37 | {"x: T" , "shape: int32" , "dy: T" }, |
38 | // Ret val defs |
39 | {"dx: T" , "dshape: int32" }, |
40 | // Attr defs |
41 | {"T: type" }, |
42 | // Nodes |
43 | { |
44 | {{"x_shape" }, "Shape" , {"x" }, {{"T" , "$T" }}}, |
45 | {{"dx" }, "Reshape" , {"dy" , "x_shape" }, {{"T" , "$T" }}}, |
46 | {{"dshape" }, "ZerosLike" , {"shape" }, {{"T" , DT_INT32}}}, |
47 | }); |
48 | // clang-format on |
49 | return OkStatus(); |
50 | } |
51 | REGISTER_OP_GRADIENT("Reshape" , ReshapeGrad); |
52 | REGISTER_OP_GRADIENT("ExpandDims" , ReshapeGrad); |
53 | |
54 | Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) { |
55 | // clang-format off |
56 | *g = FDH::Define( |
57 | // Arg defs |
58 | {"x: T" , "dy: T" }, |
59 | // Ret val defs |
60 | {"dx: T" }, |
61 | // Attr defs |
62 | {"T: type" }, |
63 | // Nodes |
64 | { |
65 | {{"x_shape" }, "Shape" , {"x" }, {{"T" , "$T" }}}, |
66 | {{"dx" }, "Reshape" , {"dy" , "x_shape" }, {{"T" , "$T" }}}, |
67 | }); |
68 | // clang-format on |
69 | return OkStatus(); |
70 | } |
71 | REGISTER_OP_GRADIENT("Squeeze" , SqueezeGrad); |
72 | |
73 | Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) { |
74 | // clang-format off |
75 | *g = FDH::Define( |
76 | // Arg defs |
77 | {"x: T" , "dy: T" }, |
78 | // Ret val defs |
79 | {"dx: T" }, |
80 | // Attr defs |
81 | {"T: type" }, |
82 | // Nodes |
83 | { |
84 | {{"dx" }, "Identity" , {"dy" }, {{"T" , "$T" }}}, |
85 | }); |
86 | // clang-format on |
87 | VLOG(1) << "IdentityGrad " << DebugString(*g); |
88 | return OkStatus(); |
89 | } |
90 | REGISTER_OP_GRADIENT("Identity" , IdentityGrad); |
91 | |
92 | Status PackGrad(const AttrSlice& attrs, FunctionDef* g) { |
93 | // clang-format off |
94 | *g = FDH::Create( |
95 | "_" , |
96 | // Arg defs |
97 | {"x: N*T" , "dy: T" }, |
98 | // Ret val defs |
99 | {"dx: N*T" }, |
100 | // Attr defs |
101 | {"T: type" , "N: int" , "axis: int" }, |
102 | // Nodes |
103 | { |
104 | { |
105 | {"dx" }, |
106 | "Unpack" , |
107 | {"dy" }, |
108 | {{"T" , "$T" }, {"num" , "$N" }, {"axis" , "$axis" }} |
109 | }, |
110 | }, |
111 | {{"dx" , "dx:output" }}); |
112 | // clang-format on |
113 | VLOG(1) << "PackGrad " << DebugString(*g); |
114 | return OkStatus(); |
115 | } |
116 | REGISTER_OP_GRADIENT("Pack" , PackGrad); |
117 | |
118 | Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) { |
119 | // clang-format off |
120 | *g = FDH::Define( |
121 | // Arg defs |
122 | {"x: T" , "dy: num*T" }, |
123 | // Ret val defs |
124 | {"dx: T" }, |
125 | // Attr defs |
126 | {"T: type" , "num: int" , "axis: int" }, |
127 | // Nodes |
128 | { |
129 | { |
130 | {"dx" }, |
131 | "Pack" , |
132 | {"dy" }, |
133 | {{"T" , "$T" }, {"N" , "$num" }, {"axis" , "$axis" }} |
134 | }, |
135 | }); |
136 | // clang-format on |
137 | VLOG(1) << "UnpackGrad " << DebugString(*g); |
138 | return OkStatus(); |
139 | } |
140 | REGISTER_OP_GRADIENT("Unpack" , UnpackGrad); |
141 | |
142 | Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, |
143 | bool dim_is_last_arg) { |
144 | int N; |
145 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N" , &N)); |
146 | DataType T; |
147 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T" , &T)); |
148 | |
149 | std::vector<string> shape_i; |
150 | std::vector<string> offset_i; |
151 | std::vector<string> dx_i; |
152 | for (int i = 0; i < N; ++i) { |
153 | shape_i.push_back(strings::StrCat("shapes:output:" , i)); |
154 | offset_i.push_back(strings::StrCat("offset:offset:" , i)); |
155 | dx_i.push_back(strings::StrCat("dx_" , i, ":output:0" )); |
156 | } |
157 | DataTypeVector dtype_list(N, T); |
158 | |
159 | // ConcatGrad(dim, x, dy): |
160 | // for i in range(N): |
161 | // dx[i] = Slice(dy, offset[i], shape[x[i]]), |
162 | // where offset[i] is the offset of x[i] in the output y, |
163 | // which is the same as dx[i]'s offset within dy. |
164 | std::vector<FDH::Node> nodes{ |
165 | {{"shapes" }, "ShapeN" , {"x" }, {{"T" , "$T" }, {"N" , "$N" }}}, |
166 | {{"offset" }, "ConcatOffset" , {"dim" , "shapes:output" }, {{"N" , "$N" }}}, |
167 | {{"d_dim" }, "ZerosLike" , {"dim" }, {{"T" , DT_INT32}}}, |
168 | {{"dx" }, |
169 | "_ListToArray" , |
170 | dx_i, |
171 | {{"T" , "$T" }, {"N" , "$N" }, {"Tin" , DataTypeVector(N, T)}}}}; |
172 | |
173 | // For each dx[i], we take a slice of dy. The offset and size of the |
174 | // slice is given by offset[i] and shape[i]. |
175 | for (int i = 0; i < N; ++i) { |
176 | nodes.push_back({{strings::StrCat("dx_" , i)}, |
177 | "Slice" , |
178 | {"dy" , offset_i[i], shape_i[i]}, |
179 | {{"T" , "$T" }, {"Index" , DT_INT32}}}); |
180 | } |
181 | if (dim_is_last_arg) { |
182 | // clang-format off |
183 | *g = FDH::Create( |
184 | "_" , |
185 | // Arg defs |
186 | {"x: N*T" , "dim: int32" , "dy: T" }, |
187 | // Return signature |
188 | {"dx: N*T" , "d_dim: int32" }, |
189 | // Attr defs |
190 | {"T: type" , "N: int" }, |
191 | // Nodes |
192 | nodes, |
193 | // Return values |
194 | {{"dx" , "dx:output" }, {"d_dim" , "d_dim:y:0" }}); |
195 | // clang-format on |
196 | } else { |
197 | // clang-format off |
198 | *g = FDH::Create( |
199 | "_" , |
200 | // Arg defs |
201 | {"dim: int32" , "x: N*T" , "dy: T" }, |
202 | // Return signature |
203 | {"d_dim: int32" , "dx: N*T" }, |
204 | // Attr defs |
205 | {"T: type" , "N: int" }, |
206 | // Nodes |
207 | nodes, |
208 | // Return values |
209 | {{"dx" , "dx:output" }, {"d_dim" , "d_dim:y:0" }}); |
210 | // clang-format on |
211 | } |
212 | VLOG(1) << "ConcatGrad " << DebugString(*g); |
213 | return OkStatus(); |
214 | } |
215 | |
216 | Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) { |
217 | return ConcatGradHelper(attrs, g, false); |
218 | } |
219 | |
220 | Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) { |
221 | return ConcatGradHelper(attrs, g, true); |
222 | } |
223 | |
224 | REGISTER_OP_GRADIENT("Concat" , ConcatGrad); |
225 | REGISTER_OP_GRADIENT("ConcatV2" , ConcatGradV2); |
226 | |
227 | Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { |
228 | // clang-format off |
229 | *g = FDH::Define( |
230 | // Arg defs |
231 | {"dim: int32" , "x: T" , "dy: num_split*T" }, |
232 | // Ret val defs |
233 | {"d_dim: int32" , "dx: T" }, |
234 | // Attr defs |
235 | {"T: type" , "num_split: int" }, |
236 | // Nodes |
237 | { |
238 | {{"d_dim" }, "ZerosLike" , {"dim" }, {{"T" , DT_INT32}}}, |
239 | {{"dx" }, "Concat" , {"dim" , "dy" }, {{"T" , "$T" }, {"N" , "$num_split" }}} |
240 | }); |
241 | // clang-format on |
242 | VLOG(1) << "SplitGrad " << DebugString(*g); |
243 | return OkStatus(); |
244 | } |
245 | REGISTER_OP_GRADIENT("Split" , SplitGrad); |
246 | |
247 | Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { |
248 | // clang-format off |
249 | *g = FDH::Define( |
250 | // Arg defs |
251 | {"x: T" , "size_splits: Tlen" , "dim: int32" , "dy: num_split*T" }, |
252 | // Ret val defs |
253 | {"dx: T" , "d_size_splits: Tlen" , "d_dim: int32" }, |
254 | // Attr defs |
255 | {"T: type" , "Tlen: type" , "num_split: int" }, |
256 | // Nodes |
257 | { |
258 | {{"dx" }, "Concat" , {"dim" , "dy" }, {{"T" , "$T" }, {"N" , "$num_split" }}}, |
259 | {{"d_size_splits" }, "ZerosLike" , {"size_splits" }, {{"T" , "$Tlen" }}}, |
260 | {{"d_dim" }, "ZerosLike" , {"dim" }, {{"T" , DT_INT32}}}, |
261 | }); |
262 | // clang-format on |
263 | VLOG(1) << "SplitVGrad " << DebugString(*g); |
264 | return OkStatus(); |
265 | } |
266 | REGISTER_OP_GRADIENT("SplitV" , SplitVGrad); |
267 | |
268 | Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { |
269 | int N; |
270 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N" , &N)); |
271 | std::vector<string> dys; |
272 | dys.reserve(N); |
273 | for (int i = 0; i < N; ++i) { |
274 | dys.push_back(strings::StrCat("dy:" , i)); |
275 | } |
276 | // clang-format off |
277 | *g = FDH::Define( |
278 | // Arg defs |
279 | {"x: N*T" , "dy: out_types" }, |
280 | // Ret val defs |
281 | {"dx: N*T" }, |
282 | // Attr defs |
283 | {"T: type" , "N: int" , "out_types: list(type)" }, |
284 | // Nodes |
285 | { |
286 | {{"dx" }, "_ListToArray" , dys, |
287 | {{"T" , "$T" }, {"N" , "$N" }, {"Tin" , "$out_types" }}} |
288 | }); |
289 | // clang-format on |
290 | VLOG(1) << "ArrayToListGrad " << DebugString(*g); |
291 | return OkStatus(); |
292 | } |
293 | REGISTER_OP_GRADIENT("_ArrayToList" , ArrayToListGrad); |
294 | |
295 | Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) { |
296 | // clang-format off |
297 | *g = FDH::Define( |
298 | // Arg defs |
299 | {"x: Tin" , "dy: N*T" }, |
300 | // Ret val defs |
301 | {"dx: Tin" }, |
302 | // Attr defs |
303 | {"T: type" , "N: int" , "Tin: list(type)" }, |
304 | // Nodes |
305 | { |
306 | {{"dx" }, "_ArrayToList" , {"dy" }, |
307 | {{"T" , "$T" }, {"N" , "$N" }, {"out_types" , "$Tin" }}} |
308 | }); |
309 | // clang-format on |
310 | VLOG(1) << "ListToArrayGrad " << DebugString(*g); |
311 | return OkStatus(); |
312 | } |
313 | REGISTER_OP_GRADIENT("_ListToArray" , ListToArrayGrad); |
314 | |
315 | Status FillGrad(const AttrSlice& attrs, FunctionDef* g) { |
316 | *g = FDH::Define( |
317 | // Arg defs |
318 | {"dims: int32" , "x: T" , "dy: T" }, |
319 | // Ret val defs |
320 | {"d_dims: int32" , "dx: T" }, |
321 | // Attr defs |
322 | {"T: type" }, |
323 | // Nodes |
324 | { |
325 | {{"d_dims" }, "ZerosLike" , {"dims" }, {{"T" , DT_INT32}}}, |
326 | FDH::Const("zero" , 0), |
327 | {{"rank" }, "Rank" , {"dy" }, {{"T" , "$T" }}}, |
328 | FDH::Const("one" , 1), |
329 | {{"r" }, "Range" , {"zero" , "rank" , "one" }, {}}, |
330 | // dx = sum(dy) |
331 | {{"dx" }, "Sum" , {"dy" , "r" }, {{"T" , "$T" }}}, |
332 | }); |
333 | VLOG(1) << "FillGrad " << DebugString(*g); |
334 | return OkStatus(); |
335 | } |
336 | REGISTER_OP_GRADIENT("Fill" , FillGrad); |
337 | |
338 | Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { |
339 | *g = FDH::Define( |
340 | // Arg defs |
341 | {"x: T" , "p: int32" , "dy: T" }, |
342 | // Ret val defs |
343 | {"dx: T" , "dp: int32" }, |
344 | // Attr defs |
345 | {"T: type" }, |
346 | // Nodes |
347 | { |
348 | {{"q" }, "InvertPermutation" , {"p" }, {}}, |
349 | {{"dx" }, "Transpose" , {"dy" , "q" }, {{"T" , "$T" }}}, |
350 | {{"dp" }, "ZerosLike" , {"p" }, {{"T" , DT_INT32}}}, |
351 | }); |
352 | VLOG(1) << "TransposeGrad " << DebugString(*g); |
353 | return OkStatus(); |
354 | } |
355 | REGISTER_OP_GRADIENT("Transpose" , TransposeGrad); |
356 | |
357 | Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { |
358 | // clang-format off |
359 | *g = FDH::Define( |
360 | // Arg defs |
361 | {"params: Tparams" , "indices: Tindices" , "doutput: Tparams" }, |
362 | // Ret val defs |
363 | {"dparams: Tparams" , "dindices: Tindices" }, |
364 | // Attr defs |
365 | {"Tparams: type" , "Tindices: type" }, |
366 | // Nodes |
367 | { |
368 | {{"x_shape" }, "Shape" , {"params" }, {{"T" , "$Tparams" }}}, |
369 | {{"dparams" }, "ScatterNd" , {"indices" , "doutput" , "x_shape" }, |
370 | {{"T" , "$Tparams" }, {"Tindices" , "$Tindices" }}}, |
371 | {{"dindices" }, "ZerosLike" , {"indices" }, {{"T" , "$Tindices" }}}, |
372 | }); |
373 | // clang-format on |
374 | return OkStatus(); |
375 | } |
376 | REGISTER_OP_GRADIENT("GatherNd" , GatherNdGrad); |
377 | |
378 | Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { |
379 | *g = FDH::Define( |
380 | // Arg defs |
381 | {"x: T" , "p: int32" , "dy: T" }, |
382 | // Ret val defs |
383 | {"dx: T" , "dp: int32" }, |
384 | // Attr defs |
385 | {"T: type" }, |
386 | // Nodes |
387 | { |
388 | {{"q" }, "InvertPermutation" , {"p" }, {}}, |
389 | {{"dx" }, "ConjugateTranspose" , {"dy" , "q" }, {{"T" , "$T" }}}, |
390 | {{"dp" }, "ZerosLike" , {"p" }, {{"T" , DT_INT32}}}, |
391 | }); |
392 | VLOG(1) << "ConjugateTransposeGrad " << DebugString(*g); |
393 | return OkStatus(); |
394 | } |
395 | REGISTER_OP_GRADIENT("ConjugateTranspose" , ConjugateTransposeGrad); |
396 | |
397 | Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) { |
398 | *g = FDH::Define( |
399 | // Arg defs |
400 | {"x: T" , "d: bool" , "dy: T" }, |
401 | // Ret val defs |
402 | {"dx: T" , "dd: bool" }, |
403 | // Attr defs |
404 | {"T: type" }, |
405 | // Nodes |
406 | { |
407 | {{"dx" }, "Reverse" , {"dy" , "d" }, {{"T" , "$T" }}}, |
408 | {{"dd" }, "ZerosLike" , {"d" }, {{"T" , DT_BOOL}}}, |
409 | }); |
410 | VLOG(1) << "ReverseGrad " << DebugString(*g); |
411 | return OkStatus(); |
412 | } |
413 | REGISTER_OP_GRADIENT("Reverse" , ReverseGrad); |
414 | |
415 | Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) { |
416 | DataType itype; |
417 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx" , &itype)); |
418 | if (itype != DT_INT32) { |
419 | return errors::Unimplemented( |
420 | "ReverseV2Grad for int64 index are not supported." ); |
421 | } |
422 | *g = FDH::Define( |
423 | // Arg defs |
424 | {"x: T" , "d: int32" , "dy: T" }, |
425 | // Ret val defs |
426 | {"dx: T" , "dd: int32" }, |
427 | // Attr defs |
428 | {"T: type" , "Tidx: {int32, int64}" }, |
429 | // Nodes |
430 | { |
431 | {{"dx" }, "ReverseV2" , {"dy" , "d" }, {{"T" , "$T" }}}, |
432 | {{"dd" }, "ZerosLike" , {"d" }, {{"T" , "$Tidx" }}}, |
433 | }); |
434 | VLOG(1) << "ReverseGrad " << DebugString(*g); |
435 | return OkStatus(); |
436 | } |
437 | REGISTER_OP_GRADIENT("ReverseV2" , ReverseV2Grad); |
438 | |
439 | Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { |
440 | DataType itype; |
441 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index" , &itype)); |
442 | if (itype != DT_INT32) { |
443 | return errors::Unimplemented( |
444 | "SliceGrad for int64 index are not supported." ); |
445 | } |
446 | *g = FDH::Define( |
447 | // Arg defs |
448 | {"x: T" , "begin: int32" , "size: int32" , "dy: T" }, |
449 | // Ret val defs |
450 | {"dx: T" , "begin_grad: int32" , "size_grad: int32" }, |
451 | // Attr defs |
452 | {"T: type" }, |
453 | // Nodes |
454 | {// paddings = concat(1, [begin, shape(x) - begin - size]) |
455 | FDH::Const("one" , 1), |
456 | {{"b1" }, "ExpandDims" , {"begin" , "one" }, {{"T" , DT_INT32}}}, |
457 | {{"xs" }, "Shape" , {"x" }, {{"T" , "$T" }}}, |
458 | {{"xs_b" }, "Sub" , {"xs" , "begin" }, {{"T" , DT_INT32}}}, |
459 | {{"xs_b_s" }, "Sub" , {"xs_b" , "size" }, {{"T" , DT_INT32}}}, |
460 | {{"a1" }, "ExpandDims" , {"xs_b_s" , "one" }, {{"T" , DT_INT32}}}, |
461 | {{"paddings" }, |
462 | "Concat" , |
463 | {"one" , "b1" , "a1" }, |
464 | {{"N" , 2}, {"T" , DT_INT32}}}, |
465 | // dx = Pad(dy, paddings) |
466 | {{"dx" }, "Pad" , {"dy" , "paddings" }, {{"T" , "$T" }}}, |
467 | {{"begin_grad" }, "ZerosLike" , {"begin" }, {{"T" , DT_INT32}}}, |
468 | {{"size_grad" }, "ZerosLike" , {"size" }, {{"T" , DT_INT32}}}}); |
469 | VLOG(1) << "SliceGrad " << DebugString(*g); |
470 | return OkStatus(); |
471 | } |
472 | REGISTER_OP_GRADIENT("Slice" , SliceGrad); |
473 | |
474 | Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) { |
475 | DataType itype; |
476 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index" , &itype)); |
477 | if (itype != DT_INT32) { |
478 | return errors::Unimplemented( |
479 | "SliceGrad for int64 index are not supported." ); |
480 | } |
481 | |
482 | *g = FDH::Define( |
483 | // Arg defs |
484 | {"x: T" , "begin: int32" , "end: int32" , "stride: int32" , "dy: T" }, |
485 | // Ret val defs |
486 | {"dx: T" , "begin_grad: int32" , "end_grad: int32" , "stride_grad: int32" }, |
487 | // Attr defs |
488 | {"T: type" , "Index: {int32, int64}" , "begin_mask: int" , "end_mask: int" , |
489 | "ellipsis_mask: int" , "new_axis_mask: int" , "shrink_axis_mask: int" }, |
490 | {// Nodes |
491 | {{{"xs" }, "Shape" , {"x" }, {{"T" , "$T" }}}, |
492 | {{"dx" }, |
493 | "StridedSliceGrad" , |
494 | {"xs" , "begin" , "end" , "stride" , "dy" }, |
495 | {{"T" , "$T" }, |
496 | {"Index" , "$Index" }, |
497 | {"begin_mask" , "$begin_mask" }, |
498 | {"end_mask" , "$end_mask" }, |
499 | {"ellipsis_mask" , "$ellipsis_mask" }, |
500 | {"new_axis_mask" , "$new_axis_mask" }, |
501 | {"shrink_axis_mask" , "$shrink_axis_mask" }}}, |
502 | {{"begin_grad" }, "ZerosLike" , {"begin" }, {{"T" , DT_INT32}}}, |
503 | {{"end_grad" }, "ZerosLike" , {"end" }, {{"T" , DT_INT32}}}, |
504 | {{"stride_grad" }, "ZerosLike" , {"stride" }, {{"T" , DT_INT32}}}}}); |
505 | |
506 | VLOG(1) << "StridedSliceGrad " << DebugString(*g); |
507 | return OkStatus(); |
508 | } |
509 | REGISTER_OP_GRADIENT("StridedSlice" , StridedSliceGrad); |
510 | |
511 | Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { |
512 | DataType itype; |
513 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index" , &itype)); |
514 | if (itype != DT_INT32) { |
515 | return errors::Unimplemented( |
516 | "SliceGrad for int64 index are not supported." ); |
517 | } |
518 | |
519 | // TODO(aselle): Shouldn't the int32 tensors return zeros of shape like |
520 | // dy_grad? |
521 | // I'm following slice's behavior for now. |
522 | *g = FDH::Define( |
523 | // Arg defs |
524 | {"shape: int32" , "begin: int32" , "end: int32" , "stride: int32" , "dy: T" , |
525 | "grad: T" }, |
526 | // Ret val defs |
527 | {"shape_grad: int32" , "begin_grad: int32" , "end_grad: int32" , |
528 | "stride_grad: int32" , "dy_grad: T" }, |
529 | // Attr defs |
530 | {"T: type" , "Index: {int32, int64}" , "begin_mask: int" , "end_mask: int" , |
531 | "ellipsis_mask: int" , "new_axis_mask: int" , "shrink_axis_mask: int" }, |
532 | {// Nodes |
533 | {{{"shape_grad" }, "ZerosLike" , {"shape" }, {{"T" , DT_INT32}}}, |
534 | {{"begin_grad" }, "ZerosLike" , {"begin" }, {{"T" , DT_INT32}}}, |
535 | {{"end_grad" }, "ZerosLike" , {"end" }, {{"T" , DT_INT32}}}, |
536 | {{"stride_grad" }, "ZerosLike" , {"stride" }, {{"T" , DT_INT32}}}, |
537 | {{"dy_grad" }, |
538 | "StridedSlice" , |
539 | {"grad" , "begin" , "end" , "stride" }, |
540 | {{"T" , "$T" }, |
541 | {"Index" , "$Index" }, |
542 | {"begin_mask" , "$begin_mask" }, |
543 | {"end_mask" , "$end_mask" }, |
544 | {"ellipsis_mask" , "$ellipsis_mask" }, |
545 | {"new_axis_mask" , "$new_axis_mask" }, |
546 | {"shrink_axis_mask" , "$shrink_axis_mask" }}}}}); |
547 | |
548 | VLOG(1) << "StridedSliceGrad " << DebugString(*g); |
549 | return OkStatus(); |
550 | } |
551 | REGISTER_OP_GRADIENT("StridedSliceGrad" , StridedSliceGradGrad); |
552 | |
553 | Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) { |
554 | DataType itype; |
555 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx" , &itype)); |
556 | if (itype != DT_INT32) { |
557 | return errors::Unimplemented( |
558 | "BroadcastToGrad for int64 index are not supported." ); |
559 | } |
560 | std::vector<FDH::Node> nodes = { |
561 | {{"sx" }, "Shape" , {"x" }, {{"T" , "$T" }}}, |
562 | {{"rx" , "ry" }, "BroadcastGradientArgs" , {"sx" , "shape" }}, |
563 | {{"sum_gx" }, "Sum" , {"dy" , "rx" }, {{"T" , "$T" }}}, |
564 | {{"dx" }, "Reshape" , {"sum_gx" , "sx" }, {{"T" , "$T" }}}, |
565 | {{"dshape" }, "ZerosLike" , {"shape" }, {{"T" , "$Tidx" }}}}; |
566 | *g = FDH::Define( |
567 | // Arg defs |
568 | {"x: T" , "shape: int32" , "dy: T" }, |
569 | // Ret val defs |
570 | {"dx: T" , "dshape: Tidx" }, |
571 | // Attr defs |
572 | {{"T: type" }, {"Tidx: {int32, int64}" }}, |
573 | // Nodes |
574 | nodes); |
575 | return OkStatus(); |
576 | } |
577 | REGISTER_OP_GRADIENT("BroadcastTo" , BroadcastToGrad); |
578 | |
579 | } // end namespace tensorflow |
580 | |