1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/framework/op_kernel.h"
19#include "tensorflow/core/kernels/cwise_ops.h"
20#include "tensorflow/core/kernels/cwise_ops_common.h"
21
22#define _USE_MATH_DEFINES
23#include <cmath>
24#include <functional>
25#include <type_traits>
26
27#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28#include "tensorflow/core/framework/bounds_check.h"
29#include "tensorflow/core/framework/numeric_types.h"
30#include "tensorflow/core/framework/tensor_types.h"
31
// All new LeakyRelu functor definitions are kept together in this one file;
// they mirror the corresponding functor patterns in cwise_ops.h.
34namespace Eigen {
35namespace internal {
36
37template <typename Scalar>
38struct leakyrelu_op {
39 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit leakyrelu_op(float val = 0.2f)
40 EIGEN_NO_THROW {
41 m_alpha = Scalar(val);
42 }
43 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
44 operator()(const Scalar& x) const {
45 return x > Scalar(0) ? x : x * Scalar(m_alpha);
46 }
47 template <typename Packet>
48 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
49 Packet alpha = pset1<Packet>(m_alpha);
50 return pselect(pcmp_le(x, pzero(x)), pmul(x, alpha), x);
51 }
52 Scalar m_alpha;
53};
54
// Cost/vectorization traits for leakyrelu_op, consumed by Eigen's expression
// evaluator when scheduling/vectorizing the element-wise computation.
template <typename Scalar>
struct functor_traits<leakyrelu_op<Scalar>> {
  enum {
    // Per-element cost estimate: one multiply plus one add-equivalent for the
    // compare/select.
    Cost =
        Eigen::NumTraits<Scalar>::AddCost + Eigen::NumTraits<Scalar>::MulCost,
    // The packet path (packetOp) needs packet multiply and packet compare
    // support for this Scalar type.
    PacketAccess =
        packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasCmp,
  };
};
64
65} // namespace internal
66} // namespace Eigen
67
68namespace tensorflow {
69
namespace functor {
// Adapter that plugs leakyrelu_op into the cwise-op machinery; `base<>` (from
// cwise_ops.h) supplies the in_type/out_type typedefs used by LeakyReluOp.
template <typename T>
struct leakyrelu : base<T, Eigen::internal::leakyrelu_op<T>> {};
}  // namespace functor
74
75template <typename Device, typename Functor>
76class LeakyReluOp : public OpKernel {
77 public:
78 typedef typename Functor::in_type Tin; // Input scalar data type.
79 typedef typename Functor::out_type Tout; // Output scalar data type.
80 // Tin may be different from Tout. E.g., abs: complex64 -> float
81
82 explicit LeakyReluOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
83 auto in = DataTypeToEnum<Tin>::v();
84 auto out = DataTypeToEnum<Tout>::v();
85 OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
86
87 float alpha;
88 OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha));
89 alpha_ = alpha;
90 }
91
92 void Compute(OpKernelContext* ctx) override {
93 const Tensor& inp = ctx->input(0);
94 Tensor* out = nullptr;
95 if (std::is_same<Tin, Tout>::value) {
96 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
97 {0}, 0, inp.shape(), &out));
98 } else {
99 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
100 }
101 functor::UnaryFunctorWithArg<Device, Functor, float>()(
102 ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>(),
103 alpha_);
104 }
105
106 private:
107 float alpha_;
108};
109
// Register the CPU kernel for the "LeakyRelu" op with bfloat16 inputs; other
// dtypes are presumably registered elsewhere — TODO confirm.
REGISTER(LeakyReluOp, CPU, "LeakyRelu", functor::leakyrelu, bfloat16);
111} // namespace tensorflow
112