1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
17#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
18
19#define EIGEN_USE_THREADS
20#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21#include "tensorflow/core/kernels/cwise_ops.h"
22
23namespace Eigen {
24namespace internal {
25
26// Gradient for the tanh function
27template <typename T>
28struct scalar_tanh_gradient_op {
29 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
30 operator()(const T& output, const T& output_gradient) const {
31 return output_gradient * (T(1) - output * output);
32 }
33 template <typename Packet>
34 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
35 packetOp(const Packet& output, const Packet& output_gradient) const {
36 return pmul(output_gradient,
37 psub(pset1<Packet>(T(1)), pmul(output, output)));
38 }
39};
// Cost/vectorization metadata Eigen uses when evaluating expressions built
// from scalar_tanh_gradient_op: one subtraction plus two multiplications,
// vectorizable whenever the type has packet sub and mul.
template <typename T>
struct functor_traits<scalar_tanh_gradient_op<T>> {
  enum {
    Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
    PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
  };
};
47
48// Gradient for the sigmoid function
49template <typename T>
50struct scalar_sigmoid_gradient_op {
51 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
52 operator()(const T& output, const T& output_gradient) const {
53 return output_gradient * output * (T(1) - output);
54 }
55 template <typename Packet>
56 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
57 packetOp(const Packet& output, const Packet& output_gradient) const {
58 return pmul(output_gradient,
59 pmul(output, psub(pset1<Packet>(T(1)), output)));
60 }
61};
// Cost/vectorization metadata for scalar_sigmoid_gradient_op: one
// subtraction plus two multiplications, vectorizable whenever the type has
// packet sub and mul.
template <typename T>
struct functor_traits<scalar_sigmoid_gradient_op<T>> {
  enum {
    Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
    PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
  };
};
69
70// Gradient for the inverse function
71template <typename T>
72struct scalar_inverse_gradient_op {
73 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
74 operator()(const T& output, const T& output_gradient) const {
75 if (output_gradient == T(0)) {
76 return T(0);
77 } else {
78 const T out_conj = numext::conj(output);
79 return -out_conj * out_conj * output_gradient;
80 }
81 }
82 template <typename Packet>
83 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
84 packetOp(const Packet& output, const Packet& output_gradient) const {
85 const Packet out_conj = pconj(output);
86 return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)),
87 output_gradient);
88 }
89};
// Cost/vectorization metadata for scalar_inverse_gradient_op; vectorizable
// whenever the type supports packet multiplication.
template <typename T>
struct functor_traits<scalar_inverse_gradient_op<T>> {
  enum {
    Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
    PacketAccess = packet_traits<T>::HasMul,
  };
};
97
98// Gradient for the sqrt function
99template <typename T>
100struct scalar_sqrt_gradient_op {
101 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
102 operator()(const T& output, const T& output_gradient) const {
103 if (output_gradient == T(0)) {
104 return T(0);
105 } else {
106 const T out_conj = numext::conj(output);
107 return (static_cast<T>(0.5) * output_gradient) / out_conj;
108 }
109 }
110 template <typename Packet>
111 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
112 packetOp(const Packet& output, const Packet& output_gradient) const {
113 const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
114 const Packet out_conj = pconj(output);
115 return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj),
116 output_gradient);
117 }
118};
119template <typename T>
120struct functor_traits<scalar_sqrt_gradient_op<T>> {
121 enum {
122 PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv,
123 Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
124 };
125};
126
// Gradient for the rsqrt function, expressed via the forward output
// y = x^(-1/2):  dL/dx = dL/dy * (-0.5 * y^3)  (with conj(y) in the complex
// case).
template <typename T>
struct scalar_rsqrt_gradient_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
  operator()(const T& output, const T& output_gradient) const {
    // Short-circuit a zero incoming gradient so Inf/NaN outputs (e.g. from
    // rsqrt(0)) do not turn a zero gradient into NaN.
    if (output_gradient == T(0)) {
      return T(0);
    } else {
      const T out_conj = numext::conj(output);
      // -0.5 * (dL/dy * conj(y)) * conj(y)^2 == -0.5 * dL/dy * conj(y)^3.
      return static_cast<T>(-0.5) * (output_gradient * out_conj) *
             (out_conj * out_conj);
    }
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
  packetOp(const Packet& output, const Packet& output_gradient) const {
    const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
    const Packet out_conj = pconj(output);
    // mul_no_nan-based multiply: yields 0 for lanes whose second operand is
    // 0, matching the scalar path's zero-gradient short-circuit.
    auto safe_pmul = [](const Packet& a, const Packet& b) {
      return mul_no_nan_op<T>().packetOp(a, b);
    };
    // The inner safe_pmul zeroes lanes with a zero gradient; the outer one
    // then keeps them zero even when -0.5 * conj(y)^2 is Inf/NaN.
    return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)),
                     safe_pmul(out_conj, output_gradient));
  }
};
// Cost/vectorization metadata for scalar_rsqrt_gradient_op: four
// multiplications; vectorizable whenever the type supports packet mul.
template <typename T>
struct functor_traits<scalar_rsqrt_gradient_op<T>> {
  enum {
    Cost = 4 * NumTraits<T>::MulCost,
    PacketAccess = packet_traits<T>::HasMul,
  };
};
159
160} // end namespace internal
161} // end namespace Eigen
162
163namespace tensorflow {
164
165namespace functor {
166
// Applies the element-wise binary op Functor::func to (in0, in1), writing
// into out. The primary template is only declared; each Device gets its own
// (partial) specialization — see the CPU one below.
template <typename Device, typename Functor>
struct SimpleBinaryFunctor {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1);
};
173
// Partial specialization of BinaryFunctor for CPU devices
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Functor>
struct SimpleBinaryFunctor<CPUDevice, Functor> {
  // Evaluates out = func(in0, in1) element-wise on the Eigen thread-pool
  // device; binaryExpr builds the lazy expression and device(d) runs it.
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  }
};
185
186
// Thin wrappers that adapt the Eigen scalar gradient ops above (plus
// scalar_igamma_der_a_op from Eigen) to the `base` functor helper declared
// in cwise_ops.h, which supplies the tensor typedefs the kernels expect.

// Gradient of tanh: grad * (1 - y^2).
template <typename T>
struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};

// Gradient of sigmoid: grad * y * (1 - y).
template <typename T>
struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
};

// Gradient of reciprocal: grad * (-conj(y)^2).
template <typename T>
struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> {
};

// Gradient of sqrt: grad * 0.5 / conj(y).
template <typename T>
struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};

// Gradient of rsqrt: grad * (-0.5 * conj(y)^3).
template <typename T>
struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};

// Derivative of the (lower) incomplete gamma function w.r.t. parameter a;
// the scalar op is provided by Eigen.
template <typename T>
struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
206
207} // end namespace functor
208
209} // end namespace tensorflow
210#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
211