1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_
19#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_
20
21#define EIGEN_USE_THREADS
22
23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24#include "tensorflow/core/framework/numeric_op.h"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/register_types.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/kernels/relu_op_functor.h"
29#include "tensorflow/core/lib/core/errors.h"
30
31namespace tensorflow {
32
33template <typename Device, typename T>
34class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
35 public:
36 using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;
37
38 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
39 functor::Relu<Device, T> functor;
40 functor(context->eigen_device<Device>(), input.flat<T>(),
41 output->flat<T>());
42 }
43};
44
// Out-of-line check to save code space: this code exists once, rather
// than once for every NDIMS * NumTypes * Num_different_relu_variants
// function instantiation.
48struct ReluHelpers {
49 static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
50 const Tensor& a) {
51 OP_REQUIRES(context, a.IsSameSize(g),
52 errors::InvalidArgument("g and a must be the same size"));
53 }
54 static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
55 const Tensor& a) {
56 ValidateSameSizeHelper(context, g, a);
57 return context->status().ok();
58 }
59};
60
61template <typename Device, typename T>
62class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
63 public:
64 using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
65
66 void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
67 const Tensor& a, Tensor* output);
68
69 // INPUTS:
70 // g (gradients): backpropagated gradients
71 // a (inputs): either the inputs that were passed to ReluOp(), or its
72 // outputs (using either one yields the same result here).
73 // OUTPUT:
74 // gradients to backprop
75 template <int NDIMS>
76 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
77 Tensor* output) {
78 OperateNoTemplate(context, g, a, output);
79 }
80};
81
82template <typename Device, typename T>
83void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
84 const Tensor& g, const Tensor& a,
85 Tensor* output) {
86 if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
87 functor::ReluGrad<Device, T> functor;
88 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
89 output->flat<T>());
90}
91
92template <typename Device, typename T>
93class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
94 public:
95 using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;
96
97 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
98 functor::Relu6<Device, T> functor;
99 functor(context->eigen_device<Device>(), input.flat<T>(),
100 output->flat<T>());
101 }
102};
103
104template <typename Device, typename T>
105class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
106 public:
107 using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
108
109 void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
110 const Tensor& a, Tensor* output);
111
112 // INPUTS:
113 // g (gradients): backpropagated gradients
114 // a (inputs): inputs that were passed to Relu6Op()
115 // OUTPUT:
116 // gradients to backprop
117 template <int NDIMS>
118 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
119 Tensor* output) {
120 OperateNoTemplate(context, g, a, output);
121 }
122};
123
124template <typename Device, typename T>
125void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
126 const Tensor& g, const Tensor& a,
127 Tensor* output) {
128 if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
129 functor::Relu6Grad<Device, T> functor;
130 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
131 output->flat<T>());
132}
133
134template <typename Device, typename T>
135class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
136 public:
137 explicit LeakyReluOp(OpKernelConstruction* context)
138 : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
139 float alpha_tmp;
140 OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
141 alpha_ = T(alpha_tmp);
142 }
143
144 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
145 functor::LeakyRelu<Device, T> functor;
146 functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
147 output->flat<T>()});
148 }
149
150 private:
151 T alpha_;
152};
153
154template <typename Device, typename T>
155class LeakyReluGradOp
156 : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
157 public:
158 explicit LeakyReluGradOp(OpKernelConstruction* context)
159 : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
160 float alpha_tmp;
161 OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
162 alpha_ = T(alpha_tmp);
163 }
164
165 void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
166 const Tensor& a, T alpha, Tensor* output);
167
168 // INPUTS:
169 // g (gradients): backpropagated gradients
170 // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
171 // outputs (using either one yields the same result here).
172 // OUTPUT:
173 // gradients to backprop
174 template <int NDIMS>
175 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
176 Tensor* output) {
177 OperateNoTemplate(context, g, a, alpha_, output);
178 }
179
180 private:
181 T alpha_;
182};
183
184template <typename Device, typename T>
185void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
186 const Tensor& g,
187 const Tensor& a, T alpha,
188 Tensor* output) {
189 if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
190 functor::LeakyReluGrad<Device, T> functor;
191 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
192 output->flat<T>());
193};
194
195template <typename Device, typename T>
196class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
197 public:
198 using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
199
200 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
201 functor::Elu<Device, T> functor;
202 functor(context->eigen_device<Device>(), input.flat<T>(),
203 output->flat<T>());
204 }
205};
206
207template <typename Device, typename T>
208class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
209 public:
210 using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
211
212 void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
213 const Tensor& a, Tensor* output);
214
215 // INPUTS:
216 // g (gradients): backpropagated gradients
217 // a (outputs): outputs of the EluOp()
218 // OUTPUT:
219 // gradients to backprop
220 template <int NDIMS>
221 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
222 Tensor* output) {
223 OperateNoTemplate(context, g, a, output);
224 }
225};
226
227template <typename Device, typename T>
228void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
229 const Tensor& g, const Tensor& a,
230 Tensor* output) {
231 if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
232 functor::EluGrad<Device, T> functor;
233 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
234 output->flat<T>());
235}
236
237template <typename Device, typename T>
238class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
239 public:
240 using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;
241
242 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
243 functor::Selu<Device, T> functor;
244 functor(context->eigen_device<Device>(), input.flat<T>(),
245 output->flat<T>());
246 }
247};
248
249template <typename Device, typename T>
250class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
251 public:
252 using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;
253
254 void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
255 const Tensor& a, Tensor* output);
256
257 // INPUTS:
258 // g (gradients): backpropagated gradients
259 // a (outputs): outputs of the SeluOp()
260 // OUTPUT:
261 // gradients to backprop
262 template <int NDIMS>
263 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
264 Tensor* output) {
265 OperateNoTemplate(context, g, a, output);
266 }
267};
268
269template <typename Device, typename T>
270void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
271 const Tensor& g, const Tensor& a,
272 Tensor* output) {
273 if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
274 functor::SeluGrad<Device, T> functor;
275 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
276 output->flat<T>());
277}
278
279} // namespace tensorflow
280
281#undef EIGEN_USE_THREADS
282
283#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_
284