1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ |
18 | |
19 | // Functor definitions for Reduction ops, must be compilable by nvcc. |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor_types.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace functor { |
27 | |
// Compile-time traits describing a reducer. IsScalarIdentity is true when the
// reducer's identity can be expressed as a single scalar of the output type
// (true for most reducers; see the EuclideanNormReducer specialization below
// for the exception).
template <typename Reducer>
struct ReducerTraits {
  enum { IsScalarIdentity = true };
};
32 | |
// Marker type used only to select the mean-reduction specializations of
// ReduceEigenImpl below. The actual mean is computed there via a SumReducer
// followed by division by the reduction factor.
template <typename Scalar>
struct MeanReducer {
  // Identity element of the underlying sum-based reduction.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
39 | |
// Marker type used only to select the l2-norm (Euclidean norm) reduction
// specializations of ReduceEigenImpl below; it carries no reduction logic of
// its own.
template <typename Scalar>
struct EuclideanNormReducer {
  // Identity element of the underlying sum-of-squares reduction.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
45 | |
// The Euclidean norm applies a final sqrt after the sum-of-squares, so its
// identity is not a plain scalar identity of the output type; flag it so
// callers can handle it specially.
template <typename Scalar>
struct ReducerTraits<EuclideanNormReducer<Scalar>> {
  enum { IsScalarIdentity = false };
};
50 | |
// Generic reduction: delegates directly to Eigen's Tensor reduce() with the
// supplied reducer. OUT_T/IN_T are Eigen tensor (map) types and
// ReductionAxes enumerates the dimensions to reduce over.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    out.device(d) = in.reduce(reduction_axes, reducer);
  }
};
59 | |
// Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy.
//
// Performs the reduction in a wider IntermediateType: casts the input up,
// reduces there, then casts the result back down to ScalarType. This avoids
// losing precision when accumulating in a narrow type. (No comments inside
// the macro body: '//' would swallow the '\' line continuations.)
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType) \
  template <typename Device, typename OUT_T, typename IN_T, \
            typename ReductionAxes> \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
                         Reducer<ScalarType>> { \
    void operator()(const Device& d, OUT_T out, IN_T in, \
                    const ReductionAxes& reduction_axes, \
                    const Reducer<ScalarType>& reducer) { \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    ""); \
      Reducer<IntermediateType> intermediate_reducer; \
      auto in_as_intermediate = in.template cast<IntermediateType>(); \
      out.device(d) = \
          in_as_intermediate.reduce(reduction_axes, intermediate_reducer) \
              .template cast<ScalarType>(); \
    } \
  };

// bfloat16 has only an 8-bit mantissa, so accumulate sums in float.
CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
#undef CASTING_SPECIALIZATION
82 | |
// Mean reduction: implemented as a sum reduction followed by division by the
// reduction factor (the number of input elements contributing to each output
// element, i.e. in.size() / out.size()). The MeanReducer argument is unused;
// it exists only to select this specialization.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::MeanReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::MeanReducer<Scalar>& reducer) {
    // No cast is performed, so the output scalar type must match Scalar.
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "" );
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    out.device(d) = in.reduce(reduction_axes, sum_reducer) /
                    static_cast<Scalar>(in.size() / out.size());
  }
};
96 | |
// Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
//
// Mean of small integer types: summing in the narrow input type would
// overflow almost immediately, so sum and divide in a 64-bit
// IntermediateType, then cast the result back down. (No comments inside the
// macro body: '//' would swallow the '\' line continuations. The macro
// deliberately omits the trailing ';' — each invocation below supplies it.)
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType) \
  template <typename Device, typename OUT_T, typename IN_T, \
            typename ReductionAxes> \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
                         functor::MeanReducer<ScalarType>> { \
    void operator()(const Device& d, OUT_T out, IN_T in, \
                    const ReductionAxes& reduction_axes, \
                    const functor::MeanReducer<ScalarType>& reducer) { \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    ""); \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer; \
      out.device(d) = (in.template cast<IntermediateType>().reduce( \
                           reduction_axes, sum_reducer) / \
                       static_cast<IntermediateType>(in.size() / out.size())) \
                          .template cast<ScalarType>(); \
    } \
  }

// Unsigned inputs accumulate in uint64, signed inputs in int64_t.
CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
CASTING_SPECIALIZATION(uint32, uint64);
CASTING_SPECIALIZATION(int8, int64_t);
CASTING_SPECIALIZATION(int16, int64_t);
CASTING_SPECIALIZATION(int32, int64_t);
#undef CASTING_SPECIALIZATION
124 | |
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
// controlled by an attribute.
//
// Euclidean (l2) norm: sqrt of the sum of squared magnitudes. Implemented as
// a sum reduction of in * conj(in) followed by an elementwise sqrt; the
// EuclideanNormReducer argument is unused and only selects this
// specialization.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<Scalar>& reducer) {
    // No cast is performed, so the output scalar type must match Scalar.
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "" );
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    // in * conj(in) yields the squared magnitude elementwise.
    out.device(d) =
        (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
  }
};
140 | |
// bfloat16 override of the Euclidean norm: accumulates the squared sum in
// float (bfloat16 has only an 8-bit mantissa) and casts the final result
// back down, mirroring the SumReducer casting specialization above.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<bfloat16>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<bfloat16>& reducer) {
    static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "" );
    Eigen::internal::SumReducer<float> sum_reducer;
    auto in_as_float = in.template cast<float>();
    out.device(d) = (in_as_float * in_as_float.conjugate())
                        .reduce(reduction_axes, sum_reducer)
                        .sqrt()
                        .template cast<bfloat16>();
  }
};
157 | |
// Maps a reducer to its identity element. For most reducers this is simply
// whatever Reducer::initialize() returns; the MeanReducer overrides below
// replace it with NaN for floating point types.
template <typename Reducer>
struct Identity {
  static auto identity(const Reducer& r) -> decltype(r.initialize()) {
    return r.initialize();
  }
};
166 | |
// MeanReducer is a special case, since it doesn't technically have an identity.
// Thus, ideally we'd return nan. However, mean is instantiated for integer
// types as well, so we do the nan override only for floating point types.
// (No comments inside the macro body: '//' would swallow the '\' line
// continuations.)
#define FIX_MEAN_IDENTITY(T) \
  template <> \
  struct Identity<functor::MeanReducer<T>> { \
    static T identity(const functor::MeanReducer<T>&) { \
      return Eigen::NumTraits<T>::quiet_NaN(); \
    } \
  };
// Instantiate the NaN override for every supported floating point type.
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#undef FIX_MEAN_IDENTITY
181 | |
// Fills every element of `out` with the identity value of `reducer`.
// MaybeWith32BitIndexing is declared elsewhere; presumably it invokes the
// lambda with a 32-bit-indexed view of `out` when that is cheaper on the
// device — confirm against its definition.
template <typename Device, typename OUT_T, typename Reducer>
void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
  MaybeWith32BitIndexing<Device>(
      [&](auto out32) {
        out32.device(d) = out32.constant(Identity<Reducer>::identity(reducer));
      },
      out);
}
190 | |
// Per-device entry points for the reduction kernels. Only declared here;
// definitions are presumably provided in per-device implementation files —
// confirm against the build.
template <typename Device, typename Reducer>
struct ReduceFunctor {
  // Reduces `in` over `reduction_axes` into `out` using `reducer`.
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);

  // Fills `out` with the reducer's identity element (see Identity above).
  template <typename OUT_T>
  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
};
201 | |
202 | } // namespace functor |
203 | } // namespace tensorflow |
204 | |
205 | #endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ |
206 | |