1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_ |
19 | #define TENSORFLOW_CORE_KERNELS_RELU_OP_H_ |
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
24 | #include "tensorflow/core/framework/numeric_op.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/kernels/relu_op_functor.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | template <typename Device, typename T> |
34 | class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> { |
35 | public: |
36 | using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp; |
37 | |
38 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
39 | functor::Relu<Device, T> functor; |
40 | functor(context->eigen_device<Device>(), input.flat<T>(), |
41 | output->flat<T>()); |
42 | } |
43 | }; |
44 | |
45 | // Out of line check to save code space (we have this code once, rather |
46 | // than once for every NDIMS * NumTypes * Num_different_relu_variants |
47 | // functions. |
48 | struct ReluHelpers { |
49 | static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g, |
50 | const Tensor& a) { |
51 | OP_REQUIRES(context, a.IsSameSize(g), |
52 | errors::InvalidArgument("g and a must be the same size" )); |
53 | } |
54 | static bool ValidateSameSize(OpKernelContext* context, const Tensor& g, |
55 | const Tensor& a) { |
56 | ValidateSameSizeHelper(context, g, a); |
57 | return context->status().ok(); |
58 | } |
59 | }; |
60 | |
61 | template <typename Device, typename T> |
62 | class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> { |
63 | public: |
64 | using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp; |
65 | |
66 | void OperateNoTemplate(OpKernelContext* context, const Tensor& g, |
67 | const Tensor& a, Tensor* output); |
68 | |
69 | // INPUTS: |
70 | // g (gradients): backpropagated gradients |
71 | // a (inputs): either the inputs that were passed to ReluOp(), or its |
72 | // outputs (using either one yields the same result here). |
73 | // OUTPUT: |
74 | // gradients to backprop |
75 | template <int NDIMS> |
76 | void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, |
77 | Tensor* output) { |
78 | OperateNoTemplate(context, g, a, output); |
79 | } |
80 | }; |
81 | |
82 | template <typename Device, typename T> |
83 | void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, |
84 | const Tensor& g, const Tensor& a, |
85 | Tensor* output) { |
86 | if (!ReluHelpers::ValidateSameSize(context, g, a)) return; |
87 | functor::ReluGrad<Device, T> functor; |
88 | functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), |
89 | output->flat<T>()); |
90 | } |
91 | |
92 | template <typename Device, typename T> |
93 | class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> { |
94 | public: |
95 | using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp; |
96 | |
97 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
98 | functor::Relu6<Device, T> functor; |
99 | functor(context->eigen_device<Device>(), input.flat<T>(), |
100 | output->flat<T>()); |
101 | } |
102 | }; |
103 | |
104 | template <typename Device, typename T> |
105 | class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> { |
106 | public: |
107 | using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp; |
108 | |
109 | void OperateNoTemplate(OpKernelContext* context, const Tensor& g, |
110 | const Tensor& a, Tensor* output); |
111 | |
112 | // INPUTS: |
113 | // g (gradients): backpropagated gradients |
114 | // a (inputs): inputs that were passed to Relu6Op() |
115 | // OUTPUT: |
116 | // gradients to backprop |
117 | template <int NDIMS> |
118 | void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, |
119 | Tensor* output) { |
120 | OperateNoTemplate(context, g, a, output); |
121 | } |
122 | }; |
123 | |
124 | template <typename Device, typename T> |
125 | void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, |
126 | const Tensor& g, const Tensor& a, |
127 | Tensor* output) { |
128 | if (!ReluHelpers::ValidateSameSize(context, g, a)) return; |
129 | functor::Relu6Grad<Device, T> functor; |
130 | functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), |
131 | output->flat<T>()); |
132 | } |
133 | |
134 | template <typename Device, typename T> |
135 | class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> { |
136 | public: |
137 | explicit LeakyReluOp(OpKernelConstruction* context) |
138 | : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) { |
139 | float alpha_tmp; |
140 | OP_REQUIRES_OK(context, context->GetAttr("alpha" , &alpha_tmp)); |
141 | alpha_ = T(alpha_tmp); |
142 | } |
143 | |
144 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
145 | functor::LeakyRelu<Device, T> functor; |
146 | functor({context->eigen_device<Device>(), input.flat<T>(), alpha_, |
147 | output->flat<T>()}); |
148 | } |
149 | |
150 | private: |
151 | T alpha_; |
152 | }; |
153 | |
154 | template <typename Device, typename T> |
155 | class LeakyReluGradOp |
156 | : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> { |
157 | public: |
158 | explicit LeakyReluGradOp(OpKernelConstruction* context) |
159 | : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) { |
160 | float alpha_tmp; |
161 | OP_REQUIRES_OK(context, context->GetAttr("alpha" , &alpha_tmp)); |
162 | alpha_ = T(alpha_tmp); |
163 | } |
164 | |
165 | void OperateNoTemplate(OpKernelContext* context, const Tensor& g, |
166 | const Tensor& a, T alpha, Tensor* output); |
167 | |
168 | // INPUTS: |
169 | // g (gradients): backpropagated gradients |
170 | // a (inputs): either the inputs that were passed to LeakyReluOp(), or its |
171 | // outputs (using either one yields the same result here). |
172 | // OUTPUT: |
173 | // gradients to backprop |
174 | template <int NDIMS> |
175 | void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, |
176 | Tensor* output) { |
177 | OperateNoTemplate(context, g, a, alpha_, output); |
178 | } |
179 | |
180 | private: |
181 | T alpha_; |
182 | }; |
183 | |
184 | template <typename Device, typename T> |
185 | void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, |
186 | const Tensor& g, |
187 | const Tensor& a, T alpha, |
188 | Tensor* output) { |
189 | if (!ReluHelpers::ValidateSameSize(context, g, a)) return; |
190 | functor::LeakyReluGrad<Device, T> functor; |
191 | functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha, |
192 | output->flat<T>()); |
193 | }; |
194 | |
195 | template <typename Device, typename T> |
196 | class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> { |
197 | public: |
198 | using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp; |
199 | |
200 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
201 | functor::Elu<Device, T> functor; |
202 | functor(context->eigen_device<Device>(), input.flat<T>(), |
203 | output->flat<T>()); |
204 | } |
205 | }; |
206 | |
207 | template <typename Device, typename T> |
208 | class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> { |
209 | public: |
210 | using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp; |
211 | |
212 | void OperateNoTemplate(OpKernelContext* context, const Tensor& g, |
213 | const Tensor& a, Tensor* output); |
214 | |
215 | // INPUTS: |
216 | // g (gradients): backpropagated gradients |
217 | // a (outputs): outputs of the EluOp() |
218 | // OUTPUT: |
219 | // gradients to backprop |
220 | template <int NDIMS> |
221 | void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, |
222 | Tensor* output) { |
223 | OperateNoTemplate(context, g, a, output); |
224 | } |
225 | }; |
226 | |
227 | template <typename Device, typename T> |
228 | void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, |
229 | const Tensor& g, const Tensor& a, |
230 | Tensor* output) { |
231 | if (!ReluHelpers::ValidateSameSize(context, g, a)) return; |
232 | functor::EluGrad<Device, T> functor; |
233 | functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), |
234 | output->flat<T>()); |
235 | } |
236 | |
237 | template <typename Device, typename T> |
238 | class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> { |
239 | public: |
240 | using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp; |
241 | |
242 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
243 | functor::Selu<Device, T> functor; |
244 | functor(context->eigen_device<Device>(), input.flat<T>(), |
245 | output->flat<T>()); |
246 | } |
247 | }; |
248 | |
249 | template <typename Device, typename T> |
250 | class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> { |
251 | public: |
252 | using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp; |
253 | |
254 | void OperateNoTemplate(OpKernelContext* context, const Tensor& g, |
255 | const Tensor& a, Tensor* output); |
256 | |
257 | // INPUTS: |
258 | // g (gradients): backpropagated gradients |
259 | // a (outputs): outputs of the SeluOp() |
260 | // OUTPUT: |
261 | // gradients to backprop |
262 | template <int NDIMS> |
263 | void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, |
264 | Tensor* output) { |
265 | OperateNoTemplate(context, g, a, output); |
266 | } |
267 | }; |
268 | |
269 | template <typename Device, typename T> |
270 | void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, |
271 | const Tensor& g, const Tensor& a, |
272 | Tensor* output) { |
273 | if (!ReluHelpers::ValidateSameSize(context, g, a)) return; |
274 | functor::SeluGrad<Device, T> functor; |
275 | functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), |
276 | output->flat<T>()); |
277 | } |
278 | |
279 | } // namespace tensorflow |
280 | |
281 | #undef EIGEN_USE_THREADS |
282 | |
283 | #endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_ |
284 | |