1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
17#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
18
19// Functor definitions for Reduction ops, must be compilable by nvcc.
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/tensor_types.h"
24
25namespace tensorflow {
26namespace functor {
27
// Compile-time metadata about a reducer type. IsScalarIdentity is true when
// the reducer's identity element is simply the scalar returned by
// Reducer::initialize(); reducers for which that shortcut is invalid (see the
// EuclideanNormReducer specialization below) override it to false.
template <typename Reducer>
struct ReducerTraits {
  enum { IsScalarIdentity = true };
};
32
// Placeholder reducer type used only for template dispatch: mean reduction is
// actually performed with SumReducer plus an on-the-fly division by the
// reduction factor (see the ReduceEigenImpl specialization below), so this
// class only needs to provide the identity element.
template <typename Scalar>
struct MeanReducer {
  // Identity element of the (sum-based) mean reduction.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
39
// Placeholder reducer type used only for template dispatch: the l2-norm is
// computed by a ReduceEigenImpl specialization below (sum of squares followed
// by sqrt), so this class only needs to provide the identity element.
template <typename Scalar>
struct EuclideanNormReducer {
  // Identity element (the norm of an empty set of values).
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
45
// The Euclidean norm's identity is not usable as a plain scalar identity:
// the reduction is sum-of-squares followed by sqrt, so filling with
// initialize() and reducing would not be equivalent. Mark it accordingly.
template <typename Scalar>
struct ReducerTraits<EuclideanNormReducer<Scalar>> {
  enum { IsScalarIdentity = false };
};
50
// Generic reduction path: build the Eigen reduction expression for `in` over
// `reduction_axes` with `reducer` and evaluate it into `out` on device `d`.
// Specializations below replace this for reducers needing special handling.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    // Lazy Eigen expression; evaluation happens on assignment below.
    auto reduced_expr = in.reduce(reduction_axes, reducer);
    out.device(d) = reduced_expr;
  }
};
59
60// Specialization for BF16 Reducer to fix accuracy.
61// TODO: All BF16 reducers should have specializations to fix accuracy.
62#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType) \
63 template <typename Device, typename OUT_T, typename IN_T, \
64 typename ReductionAxes> \
65 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
66 Reducer<ScalarType>> { \
67 void operator()(const Device& d, OUT_T out, IN_T in, \
68 const ReductionAxes& reduction_axes, \
69 const Reducer<ScalarType>& reducer) { \
70 static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
71 ""); \
72 Reducer<IntermediateType> intermediate_reducer; \
73 auto in_as_intermediate = in.template cast<IntermediateType>(); \
74 out.device(d) = \
75 in_as_intermediate.reduce(reduction_axes, intermediate_reducer) \
76 .template cast<ScalarType>(); \
77 } \
78 };
79
80CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
81#undef CASTING_SPECIALIZATION
82
83template <typename Device, typename OUT_T, typename IN_T,
84 typename ReductionAxes, typename Scalar>
85struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
86 functor::MeanReducer<Scalar>> {
87 void operator()(const Device& d, OUT_T out, IN_T in,
88 const ReductionAxes& reduction_axes,
89 const functor::MeanReducer<Scalar>& reducer) {
90 static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
91 Eigen::internal::SumReducer<Scalar> sum_reducer;
92 out.device(d) = in.reduce(reduction_axes, sum_reducer) /
93 static_cast<Scalar>(in.size() / out.size());
94 }
95};
96
// Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
//
// Integer mean: the running sum is accumulated in the wider IntermediateType
// so it cannot overflow the narrow ScalarType; the quotient is cast back
// down to ScalarType. Note the expansion omits the trailing ';' — each call
// site below supplies it.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)                \
  template <typename Device, typename OUT_T, typename IN_T,                 \
            typename ReductionAxes>                                         \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                \
                         functor::MeanReducer<ScalarType>> {                \
    void operator()(const Device& d, OUT_T out, IN_T in,                    \
                    const ReductionAxes& reduction_axes,                    \
                    const functor::MeanReducer<ScalarType>& reducer) {      \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value,\
                    "");                                                    \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer;            \
      out.device(d) = (in.template cast<IntermediateType>().reduce(         \
                           reduction_axes, sum_reducer) /                   \
                       static_cast<IntermediateType>(in.size() / out.size()))\
                          .template cast<ScalarType>();                     \
    }                                                                       \
  }

// Unsigned types accumulate in uint64; signed types accumulate in int64_t.
CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
CASTING_SPECIALIZATION(uint32, uint64);
CASTING_SPECIALIZATION(int8, int64_t);
CASTING_SPECIALIZATION(int16, int64_t);
CASTING_SPECIALIZATION(int32, int64_t);
#undef CASTING_SPECIALIZATION
124
125// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
126// controlled by an attribute.
127template <typename Device, typename OUT_T, typename IN_T,
128 typename ReductionAxes, typename Scalar>
129struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
130 functor::EuclideanNormReducer<Scalar>> {
131 void operator()(const Device& d, OUT_T out, IN_T in,
132 const ReductionAxes& reduction_axes,
133 const functor::EuclideanNormReducer<Scalar>& reducer) {
134 static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
135 Eigen::internal::SumReducer<Scalar> sum_reducer;
136 out.device(d) =
137 (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
138 }
139};
140
141template <typename Device, typename OUT_T, typename IN_T,
142 typename ReductionAxes>
143struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
144 functor::EuclideanNormReducer<bfloat16>> {
145 void operator()(const Device& d, OUT_T out, IN_T in,
146 const ReductionAxes& reduction_axes,
147 const functor::EuclideanNormReducer<bfloat16>& reducer) {
148 static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
149 Eigen::internal::SumReducer<float> sum_reducer;
150 auto in_as_float = in.template cast<float>();
151 out.device(d) = (in_as_float * in_as_float.conjugate())
152 .reduce(reduction_axes, sum_reducer)
153 .sqrt()
154 .template cast<bfloat16>();
155 }
156};
157
// For most reducers the identity element is whatever Reducer::initialize()
// returns; specializations (e.g. for MeanReducer below) override this.
template <typename Reducer>
struct Identity {
  static auto identity(const Reducer& r) -> decltype(r.initialize()) {
    return r.initialize();
  }
};
166
// MeanReducer is a special case, since it doesn't technically have an identity.
// Thus, ideally we'd return nan. However, mean is instantiated for integer
// types as well, so we do the nan override only for floating point types.
// Each expansion defines an Identity<MeanReducer<T>> specialization whose
// identity is quiet NaN instead of MeanReducer::initialize()'s zero.
#define FIX_MEAN_IDENTITY(T)                            \
  template <>                                           \
  struct Identity<functor::MeanReducer<T>> {            \
    static T identity(const functor::MeanReducer<T>&) { \
      return Eigen::NumTraits<T>::quiet_NaN();          \
    }                                                   \
  };
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#undef FIX_MEAN_IDENTITY
181
182template <typename Device, typename OUT_T, typename Reducer>
183void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
184 MaybeWith32BitIndexing<Device>(
185 [&](auto out32) {
186 out32.device(d) = out32.constant(Identity<Reducer>::identity(reducer));
187 },
188 out);
189}
190
// Device-parameterized entry points for reduction kernels. Declaration only;
// the definitions are expected to be provided per device elsewhere (not
// visible in this header).
template <typename Device, typename Reducer>
struct ReduceFunctor {
  // Reduces `in` along `reduction_axes` with `reducer`, writing the result
  // into `out`.
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);

  // Fills `out` with the reducer's identity element.
  template <typename OUT_T>
  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
};
201
202} // namespace functor
203} // namespace tensorflow
204
205#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
206