/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
// Functor definitions for ReluOp, ReluGradOp, and the related Relu6, LeakyRelu,
// Elu, and Selu ops; must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by ReluOp to do the computations.
template <typename Device, typename T>
struct Relu {
  // Computes Relu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) =
        features.template cwiseMax<Eigen::PropagateNaN>(static_cast<T>(0));
  }
};
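
// Example (illustrative sketch only; the op kernels that actually instantiate
// these functors live elsewhere in tensorflow/core/kernels): given an Eigen
// device "d", an input tensor "in", and a same-shaped, pre-allocated output
// tensor "out", a kernel would invoke the functor roughly as
//
//   functor::Relu<Device, T>()(d, in.flat<T>(), out.flat<T>());
//
// where flat<T>() views the tensor as a rank-1 Eigen tensor map.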

// Functor used by ReluGradOp to do the computations.
template <typename Device, typename T>
struct ReluGrad {
  // Computes ReluGrad backprops.
  //
  // gradients: gradients backpropagated to the Relu op.
  // features: either the inputs that were passed to the Relu, or its
  //   outputs (using either one yields the same result here).
  // backprops: gradients to backpropagate to the Relu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    // NOTE: When the activation is exactly zero, we do not propagate the
    // associated gradient value. This allows "features" to be either the
    // input or the output of the Relu.
    backprops.device(d) =
        gradients * (features > static_cast<T>(0)).template cast<T>();
  }
};

// Functor used by Relu6Op to do the computations.
template <typename Device, typename T>
struct Relu6 {
  // Computes Relu6 activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) =
        features.template cwiseMax<Eigen::PropagateNaN>(static_cast<T>(0))
            .template cwiseMin<Eigen::PropagateNaN>(static_cast<T>(6));
  }
};

// Functor used by Relu6GradOp to do the computations.
template <typename Device, typename T>
struct Relu6Grad {
  // Computes Relu6Grad backprops.
  //
  // gradients: gradients backpropagated to the Relu6 op.
  // features: inputs that were passed to the Relu6 op, or its outputs.
  // backprops: gradients to backpropagate to the Relu6 inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    // NOTE: When the activation is exactly zero or six, we
    // make sure not to propagate the associated gradient
    // value. This allows "features" to be either the input or the output of
    // the relu6.
    backprops.device(d) = gradients * ((features > static_cast<T>(0)) *
                                       (features < static_cast<T>(6)))
                                          .template cast<T>();
  }
};

// Functor used by LeakyReluOp to do the computations.
template <typename Device, typename T>
struct LeakyRelu {
  // Computes LeakyRelu activation.
  //
  // features: any shape.
  // activations: same shape as "features".

  // The args to the LeakyRelu functor need to be bundled within a struct;
  // not doing so leads to the Eigen kernel args not getting populated
  // correctly for the Eigen::half type (when building on the ROCm platform).
  struct LeakyReluArgs {
    const Device& d;
    typename TTypes<T>::ConstTensor features;
    T alpha;
    typename TTypes<T>::Tensor activations;
  };
  void operator()(LeakyReluArgs args) {
    // Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
    args.activations.device(args.d) =
        (args.features > static_cast<T>(0))
            .select(args.features, args.features * args.alpha);
  }
};
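
// Example (illustrative sketch only): unlike the other functors in this file,
// LeakyRelu takes its arguments bundled in a LeakyReluArgs struct, so a
// caller would invoke it roughly as
//
//   functor::LeakyRelu<Device, T> functor;
//   functor({d, in.flat<T>(), alpha, out.flat<T>()});
//
// with "alpha" already converted to type T.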

// Functor used by LeakyReluGradOp to do the computations.
template <typename Device, typename T>
struct LeakyReluGrad {
  // Computes LeakyReluGrad backprops.
  //
  // gradients: gradients backpropagated to the LeakyRelu op.
  // features: either the inputs that were passed to the LeakyRelu, or its
  //   outputs (using either one yields the same result here).
  // backprops: gradients to backpropagate to the LeakyRelu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features, T alpha,
                  typename TTypes<T>::Tensor backprops) {
    backprops.device(d) =
        (features > static_cast<T>(0)).select(gradients, gradients * alpha);
  }
};

// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
  // Computes Elu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) =
        (features < static_cast<T>(0))
            .select(features.exp() - features.constant(static_cast<T>(1)),
                    features);
  }
};

// Functor used by EluGradOp to do the computations.
template <typename Device, typename T>
struct EluGrad {
  // Computes EluGrad backprops.
  //
  // gradients: gradients backpropagated to the Elu op.
  // activations: outputs of the Elu op.
  // backprops: gradients to backpropagate to the Elu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor activations,
                  typename TTypes<T>::Tensor backprops) {
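    // For x < 0 the Elu output is exp(x) - 1, so its derivative exp(x) can be
    // recovered from the activation itself as (activation + 1); for x >= 0 the
    // derivative is 1 and the gradient passes through unchanged.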
    backprops.device(d) =
        (activations < static_cast<T>(0))
            .select((activations + static_cast<T>(1)) * gradients, gradients);
  }
};

// Functor used by SeluOp to do the computations.
template <typename Device, typename T>
struct Selu {
  // Computes Selu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
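    // The constants below are the SELU scale (lambda ~= 1.0507) and the
    // product scale * alpha (with alpha ~= 1.6733) from Klambauer et al.,
    // "Self-Normalizing Neural Networks" (2017).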
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
    const auto one = static_cast<T>(1);
    const auto zero = static_cast<T>(0);
    activations.device(d) =
        (features < zero)
            .select(scale_alpha * (features.exp() - features.constant(one)),
                    scale * features);
  }
};

// Functor used by SeluGradOp to do the computations.
template <typename Device, typename T>
struct SeluGrad {
  // Computes SeluGrad backprops.
  //
  // gradients: gradients backpropagated to the Selu op.
  // activations: outputs of the Selu op.
  // backprops: gradients to backpropagate to the Selu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor activations,
                  typename TTypes<T>::Tensor backprops) {
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
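    // For x < 0 the Selu output is scale_alpha * (exp(x) - 1), so its
    // derivative scale_alpha * exp(x) can be recovered from the activation as
    // (activation + scale_alpha); for x >= 0 the derivative is simply "scale".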
    backprops.device(d) =
        (activations < static_cast<T>(0))
            .select(gradients * (activations + scale_alpha), gradients * scale);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_