/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;

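// Gradient for the Softmax op. The nodes below compute, in NumPy-like
// pseudocode:
//   softmax = Softmax(x)
//   sum_dy_s = sum(grad_softmax * softmax, axis=-1, keepdims=True)
//   grad_x = (grad_softmax - sum_dy_s) * softmax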
Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      "SoftmaxGrad",
      // Arg defs
      {"x: T", "grad_softmax: T"},
      // Ret val defs
      {"grad_x: T"},
      // Attr defs
      {{"T: {float, double, bfloat16}"}},
      // Nodes
      // Based on _SoftmaxGrad in nn_grad.py.
      {
        {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}},
        {{"n0"}, "Mul", {"grad_softmax", "softmax"}, {{"T", "$T"}}},
        FDH::Const<int32>("indices", {-1}),
        {{"n1"}, "Sum", {"n0", "indices"}, {{"keep_dims", true}, {"T", "$T"}}},
        {{"n2"}, "Sub", {"grad_softmax", "n1"}, {{"T", "$T"}}},
        {{"grad_x"}, "Mul", {"n2", "softmax"}, {{"T", "$T"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Softmax", SoftmaxGrad);

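// Gradient for the LogSoftmax op. The nodes below compute, in NumPy-like
// pseudocode:
//   softmax = Softmax(x)
//   sum_dy = sum(grad_logsoftmax, axis=-1, keepdims=True)
//   grad_x = grad_logsoftmax - sum_dy * softmax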
Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      "LogSoftmaxGrad",
      // Arg defs
      {"x: T", "grad_logsoftmax: T"},
      // Ret val defs
      {"grad_x: T"},
      // Attr defs
      {{"T: {float, double}"}},
      // Nodes
      // Based on _LogSoftmaxGrad in nn_grad.py.
      {
        {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}},
        FDH::Const<int32>("indices", {-1}),
        {{"n0"}, "Sum", {"grad_logsoftmax", "indices"},
         {{"keep_dims", true}, {"T", "$T"}}},
        {{"n1"}, "Mul", {"n0", "softmax"}, {{"T", "$T"}}},
        {{"grad_x"}, "Sub", {"grad_logsoftmax", "n1"}, {{"T", "$T"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("LogSoftmax", LogSoftmaxGrad);

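// Gradient for the Relu op: delegates to the ReluGrad kernel, which passes
// dy through where x > 0 and emits 0 elsewhere.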
Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {float, double}"}},
      // Nodes
      {
        {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Relu", ReluGrad);

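// Gradient for the Relu6 op: delegates to the Relu6Grad kernel, which passes
// dy through where 0 < x < 6 and emits 0 elsewhere.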
Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {float, double}"}},
      // Nodes
      {
        {{"dx"}, "Relu6Grad", {"dy", "x"}, {{"T", "$T"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Relu6", Relu6Grad);

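// Gradient for the CrossEntropy op. The forward op already emits
// dloss_dfeatures as its second output, so the gradient only scales it by
// dcost_dloss (expanded to shape [batch_size, 1] so it broadcasts across the
// feature dimension). Labels are treated as constants, so their gradient is
// all zeros.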
Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"},
      // Ret val defs
      {"dcost_dfeatures: T", "dcost_dlabels: T"},
      // Attr defs
      {{"T: {float, double}"}},
      // Nodes
      {
        // _, dloss_dfeatures = CrossEntropy(features, labels)
        {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy",
         {"features", "labels"}, {{"T", "$T"}}},
        // dcost_dloss is of shape [batch_size].
        // dcost_dloss_mat is of shape [batch_size, 1].
        FDH::Const("neg1", -1),
        {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"},
         {{"T", "$T"}}},
        // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures
        {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"},
         {{"T", "$T"}}},
        {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);

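// Gradient for the Conv2D op: the gradient w.r.t. the input is computed by
// Conv2DBackpropInput (which needs the input's shape), and the gradient
// w.r.t. the filter by Conv2DBackpropFilter (which needs the filter's shape).
// Both reuse the forward op's strides/padding/data_format attrs.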
Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"input: T", "filter: T", "grad: T"},
      // Ret val defs
      {"input_grad: T", "filter_grad: T"},
      // Attr defs
      {"T: {float, double}",
       "strides: list(int)",
       "use_cudnn_on_gpu: bool = true",
       GetPaddingAttrString(),
       GetConvnetDataFormatAttrString()},
      // Nodes
      {
        {{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
        {{"input_grad"}, "Conv2DBackpropInput", {"i_shape", "filter", "grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"strides", "$strides"},
                    {"padding", "$padding"},
                    {"data_format", "$data_format"},
                    {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},

        {{"f_shape"}, "Shape", {"filter"}, {{"T", "$T"}}},
        {{"filter_grad"}, "Conv2DBackpropFilter", {"input", "f_shape", "grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"strides", "$strides"},
                    {"padding", "$padding"},
                    {"data_format", "$data_format"},
                    {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad);

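// Gradient for the MaxPool op. The MaxPoolGrad kernel needs the forward
// output to locate the max element of each pooling window, so the forward
// MaxPool is recomputed here rather than passed in.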
Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"input: T", "grad: T"},
      // Ret val defs
      {"output: T"},
      // Attr defs
      {"T: {float, half} = DT_FLOAT",
       "ksize: list(int) >= 4",
       "strides: list(int) >= 4",
       GetPaddingAttrString()},
      // Nodes
      {
        // Invoke MaxPool again to recompute the outputs (removed by CSE?).
        {{"maxpool"}, "MaxPool", {"input"},
         /*Attrs=*/{{"T", "$T"},
                    {"ksize", "$ksize"},
                    {"strides", "$strides"},
                    {"padding", "$padding"}}},
        {{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"ksize", "$ksize"},
                    {"strides", "$strides"},
                    {"padding", "$padding"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);

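// Gradient for the AvgPool op. The AvgPoolGrad kernel only needs the original
// input shape (to size its output) and the incoming gradient, which it
// distributes evenly over each pooling window.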
Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"input: T", "grad: T"},
      // Ret val defs
      {"output: T"},
      // Attr defs
      {"T: {float, half} = DT_FLOAT",
       "ksize: list(int) >= 4",
       "strides: list(int) >= 4",
       GetPaddingAttrString()},
      // Nodes
      {
        {{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
        {{"output"}, "AvgPoolGrad", {"i_shape", "grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"ksize", "$ksize"},
                    {"strides", "$strides"},
                    {"padding", "$padding"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("AvgPool", AvgPoolGrad);

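// Gradient of the MaxPoolGrad op itself (i.e. the second-order gradient of
// MaxPool). As with MaxPoolGrad above, the forward MaxPool output is
// recomputed and fed to the MaxPoolGradGrad kernel.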
Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"input: T", "grad: T"},
      // Ret val defs
      {"output: T"},
      // Attr defs
      {"T: {float, half} = DT_FLOAT",
       "ksize: list(int) >= 4",
       "strides: list(int) >= 4",
       GetPaddingAttrString()},
      // Nodes
      {
        // Invoke MaxPool again to recompute the outputs (removed by CSE?).
        {{"maxpool"}, "MaxPool", {"input"},
         /*Attrs=*/{{"T", "$T"},
                    {"ksize", "$ksize"},
                    {"strides", "$strides"},
                    {"padding", "$padding"}}},
        {{"output"}, "MaxPoolGradGrad", {"input", "maxpool", "grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"ksize", "$ksize"},
                    {"strides", "$strides"},
                    {"padding", "$padding"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("MaxPoolGrad", MaxPoolGradGrad);

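// Gradient for the BiasAdd op. The gradient w.r.t. the input is the incoming
// grad forwarded unchanged (the first return value binds to the "grad"
// argument by name), and the gradient w.r.t. the bias is computed by the
// BiasAddGrad kernel, which sums grad over all non-bias dimensions.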
Status BiasAddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"input: T", "bias: T", "grad: T"},
      // Ret val defs
      {"grad: T", "bias_grad: T"},
      // Attr defs
      {{"T: {float, double}"},
       GetConvnetDataFormatAttrString()},
      // Nodes
      {
        {{"bias_grad"}, "BiasAddGrad", {"grad"},
         /*Attrs=*/{{"T", "$T"},
                    {"data_format", "$data_format"}}}
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("BiasAdd", BiasAddGrad);

}  // end namespace tensorflow