1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/tensor_types.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace functor { |
24 | |
25 | typedef Eigen::Index Index; |
26 | |
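// Computes an inclusive or exclusive scan along dimension 1 of a rank-3
// tensor `in` (viewed as [outer, scan_axis, inner]), combining elements with
// `reducer` and writing the result to `out`. When `reverse` is true the scan
// runs from the back of the axis toward the front.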
27 | // TODO(b/154339590): Needs to be vectorized. |
28 | template <typename Device, typename Reducer, typename T> |
29 | struct Scan { |
30 | void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in, |
31 | typename TTypes<T, 3>::Tensor out, const Reducer& reducer, |
32 | const bool reverse, const bool exclusive) { |
    // Perform the reverse ops on the scan axis directly within the Eigen
    // expression, which avoids copying the tensor twice compared to using
    // separate reverse ops.
35 | Eigen::array<bool, 3> dims; |
36 | dims[0] = false; |
37 | dims[1] = reverse; |
38 | dims[2] = false; |
39 | MaybeWith32BitIndexing<Device>( |
40 | [&](auto in32, auto out32) { |
41 | out32.device(d) = |
42 | in32.reverse(dims).scan(1, reducer, exclusive).reverse(dims); |
43 | }, |
44 | in, out); |
45 | } |
46 | }; |
47 | |
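// Numerically stable pairwise log-sum-exp:
//   logsumexp(a, b) = log(exp(a) + exp(b))
//                   = max(a, b) + log1p(exp(min(a, b) - max(a, b))).
// Subtracting the larger operand before exponentiating avoids overflow. If
// both operands are -infinity the formula would produce NaN, so in that case
// (detected by comparing the max against NumTraits<T>::lowest()) the max
// itself is returned.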
48 | template <typename T> |
49 | struct LogSumExp { |
50 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, |
51 | const T& b) const { |
52 | auto mi = Eigen::internal::scalar_min_op<T>()(a, b); |
53 | auto ma = Eigen::internal::scalar_max_op<T>()(a, b); |
54 | |
55 | auto sub = Eigen::internal::scalar_difference_op<T>(); |
56 | auto add = Eigen::internal::scalar_sum_op<T>(); |
57 | auto exp = Eigen::internal::scalar_exp_op<T>(); |
58 | auto log1p = Eigen::internal::scalar_log1p_op<T>(); |
59 | auto cmp_lt = |
60 | Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>(); |
61 | |
62 | auto logsumexp = add(log1p(exp(sub(mi, ma))), ma); |
63 | return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp; |
64 | } |
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
                                                        const Packet& b) const {
67 | auto mi = Eigen::internal::pmin(a, b); |
68 | auto ma = Eigen::internal::pmax(a, b); |
69 | using Eigen::internal::padd; |
70 | using Eigen::internal::pcmp_lt; |
71 | using Eigen::internal::pexp; |
72 | using Eigen::internal::plog1p; |
    using Eigen::internal::pselect;
    using Eigen::internal::pset1;
74 | using Eigen::internal::psub; |
75 | |
76 | auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma); |
    return pselect(pcmp_lt(ma, pset1<Packet>(Eigen::NumTraits<T>::lowest())),
                   ma, logsumexp);
79 | } |
80 | }; |
81 | |
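// Reducer that combines elements with LogSumExp above. It provides the
// interface Eigen expects of a reducer (initialize / reduce / reducePacket /
// finalize / finalizePacket / finalizeBoth), so it can be plugged into the
// Scan functor as `Reducer` to compute a cumulative log-sum-exp.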
82 | template <typename T> |
83 | struct LogSumExpReducer { |
84 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { |
85 | LogSumExp<T> logsumexp; |
86 | *accum = logsumexp(*accum, t); |
87 | } |
88 | |
89 | template <typename Packet> |
90 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, |
91 | Packet* accum) const { |
92 | LogSumExp<T> logsumexp; |
93 | *accum = logsumexp.packetOp(*accum, p); |
94 | } |
95 | |
96 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { |
97 | return -Eigen::NumTraits<T>::infinity(); |
98 | } |
99 | |
100 | template <typename Packet> |
101 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { |
    return Eigen::internal::pset1<Packet>(initialize());
103 | } |
104 | |
105 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { |
106 | return accum; |
107 | } |
108 | |
109 | template <typename Packet> |
110 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet |
111 | finalizePacket(const Packet& vaccum) const { |
112 | return vaccum; |
113 | } |
114 | |
115 | template <typename Packet> |
116 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T |
117 | finalizeBoth(const T saccum, const Packet& vaccum) const { |
118 | auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>(); |
119 | auto sum_reducer = Eigen::internal::SumReducer<T>(); |
120 | auto exp = Eigen::internal::scalar_exp_op<T>(); |
121 | auto cmp_lt = |
122 | Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>(); |
123 | auto log = Eigen::internal::scalar_log_op<T>(); |
124 | auto add = Eigen::internal::scalar_sum_op<T>(); |
125 | |
    using Eigen::internal::pexp;
    using Eigen::internal::pset1;
    using Eigen::internal::psub;
128 | |
129 | // `ma = max(x1, ..., xn)` |
130 | // If the max of all of the `xi` is `-infinity` then the result is |
131 | // -infinity. If the max is larger than `-infinity` then it's safe to use |
132 | // for normalization even if the other elements are `-infinity`. |
133 | // |
134 | // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))` |
135 | auto ma = max_reducer.finalizeBoth(saccum, vaccum); |
136 | auto logsumexp = add(log(sum_reducer.finalizeBoth( |
                             exp(saccum - ma), pexp(psub(vaccum, pset1<Packet>(ma))))),
138 | ma); |
139 | return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp; |
140 | } |
141 | }; |
142 | |
143 | } // namespace functor |
144 | } // namespace tensorflow |
145 | |
146 | #endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ |
147 | |