1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/kernels/cwise_ops.h" |
20 | #include "tensorflow/core/kernels/cwise_ops_common.h" |
21 | |
22 | #define _USE_MATH_DEFINES |
23 | #include <cmath> |
24 | #include <functional> |
25 | #include <type_traits> |
26 | |
27 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
28 | #include "tensorflow/core/framework/bounds_check.h" |
29 | #include "tensorflow/core/framework/numeric_types.h" |
30 | #include "tensorflow/core/framework/tensor_types.h" |
31 | |
// All new LeakyRelu changes are kept together in this one file.
// The definitions below are similar to the changes made in cwise_ops.h.
34 | namespace Eigen { |
35 | namespace internal { |
36 | |
37 | template <typename Scalar> |
38 | struct leakyrelu_op { |
39 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit leakyrelu_op(float val = 0.2f) |
40 | EIGEN_NO_THROW { |
41 | m_alpha = Scalar(val); |
42 | } |
43 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
44 | operator()(const Scalar& x) const { |
45 | return x > Scalar(0) ? x : x * Scalar(m_alpha); |
46 | } |
47 | template <typename Packet> |
48 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
49 | Packet alpha = pset1<Packet>(m_alpha); |
50 | return pselect(pcmp_le(x, pzero(x)), pmul(x, alpha), x); |
51 | } |
52 | Scalar m_alpha; |
53 | }; |
54 | |
55 | template <typename Scalar> |
56 | struct functor_traits<leakyrelu_op<Scalar>> { |
57 | enum { |
58 | Cost = |
59 | Eigen::NumTraits<Scalar>::AddCost + Eigen::NumTraits<Scalar>::MulCost, |
60 | PacketAccess = |
61 | packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasCmp, |
62 | }; |
63 | }; |
64 | |
65 | } // namespace internal |
66 | } // namespace Eigen |
67 | |
68 | namespace tensorflow { |
69 | |
70 | namespace functor { |
71 | template <typename T> |
72 | struct leakyrelu : base<T, Eigen::internal::leakyrelu_op<T>> {}; |
73 | } // namespace functor |
74 | |
75 | template <typename Device, typename Functor> |
76 | class LeakyReluOp : public OpKernel { |
77 | public: |
78 | typedef typename Functor::in_type Tin; // Input scalar data type. |
79 | typedef typename Functor::out_type Tout; // Output scalar data type. |
80 | // Tin may be different from Tout. E.g., abs: complex64 -> float |
81 | |
82 | explicit LeakyReluOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
83 | auto in = DataTypeToEnum<Tin>::v(); |
84 | auto out = DataTypeToEnum<Tout>::v(); |
85 | OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out})); |
86 | |
87 | float alpha; |
88 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha" , &alpha)); |
89 | alpha_ = alpha; |
90 | } |
91 | |
92 | void Compute(OpKernelContext* ctx) override { |
93 | const Tensor& inp = ctx->input(0); |
94 | Tensor* out = nullptr; |
95 | if (std::is_same<Tin, Tout>::value) { |
96 | OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( |
97 | {0}, 0, inp.shape(), &out)); |
98 | } else { |
99 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); |
100 | } |
101 | functor::UnaryFunctorWithArg<Device, Functor, float>()( |
102 | ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>(), |
103 | alpha_); |
104 | } |
105 | |
106 | private: |
107 | float alpha_; |
108 | }; |
109 | |
// Instantiates LeakyReluOp with functor::leakyrelu for bfloat16 on CPU under
// the op name "LeakyRelu". The REGISTER macro is not defined in this file;
// presumably it comes from cwise_ops_common.h -- confirm its expansion there.
REGISTER(LeakyReluOp, CPU, "LeakyRelu" , functor::leakyrelu, bfloat16);
111 | } // namespace tensorflow |
112 | |