/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
24namespace tensorflow {
25
// Forward declaration; the full definition is pulled in by the .cc files.
class OpKernelContext;

// Runtime switches controlling determinism of the segment reduction kernels.
// Declared here only; presumably backed by environment variables in the
// implementation file -- confirm there for the exact semantics.
bool UseDeterministicSegmentReductions();
bool DisableSegmentReductionOpDeterminismExceptions();

// Type of SparseSegmentReduction operation to perform gradient of.
enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN };
33
34namespace functor {
35
36#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
37
38// Note that we define this ourselves to avoid a dependency on gpuprim.
39struct Sum {
40 template <typename T>
41 __host__ __device__ T operator()(const T& a, const T& b) const {
42 return a + b;
43 }
44};
45
46struct Prod {
47 template <typename T>
48 __host__ __device__ T operator()(const T& a, const T& b) const {
49 return a * b;
50 }
51};
52
53// Note that we don't use gpuprim::Min/Max because they use operator<, which is
54// not implemented for AlignedVector types.
55struct Min {
56 template <typename T>
57 __host__ __device__ T operator()(const T& a, const T& b) const {
58 return min(a, b);
59 }
60};
61
62struct Max {
63 template <typename T>
64 __host__ __device__ T operator()(const T& a, const T& b) const {
65 return max(a, b);
66 }
67};
68
69template <typename ReduceOp, typename T>
70struct ReduceOpIsAssociative {};
71template <typename T>
72struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
73template <typename T>
74struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
75template <typename T>
76struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
77template <typename T>
78struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};
79
typedef Eigen::GpuDevice GPUDevice;

// Functor for SegmentReductionGPUOp.
// output_rows: the number of output segments (unique segment ids in
//              'segment_ids').
// segment_ids_shape: shape of 'segment_ids' tensor.
// segment_ids: unsorted map from input to output segment ids at which to
//              perform segment sum operation.
// data_size: size of input data tensor.
// data: input data tensor.
// output: output reshaped to {output_rows, output.size/output_rows}
// is_mean: NOTE(review): presumably divides each segment result by the
//          segment size -- confirm against the .cu.cc implementation.
template <typename T, typename Index, typename InitialValueF,
          typename EmptySegmentValueF, typename ReductionF>
struct SegmentReductionFunctor {
  // Declaration only; the definition lives in the GPU implementation file.
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  const Index output_rows, const TensorShape& segment_ids_shape,
                  bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output);
  // True when ReductionF may be applied with atomics in arbitrary order
  // without changing the result (see ReduceOpIsAssociative above).
  static constexpr bool atomic_reduction_is_associative =
      ReduceOpIsAssociative<ReductionF, T>::value;
};
101
102#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
103
// Functor for the unsorted segment reduction ops. segment_ids maps each row
// of 'data' to a row of 'output'; InitialValueF and ReductionF presumably
// supply the output's fill value and the combining operation (cf. the
// "Initial value functors" below). Declaration only; definitions are
// provided per Device in the implementation files.
template <typename Device, typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output);
};
112
113// Initial value functors.
114template <typename T>
115struct Zero {
116 EIGEN_STRONG_INLINE T operator()() const { return T(0); }
117};
118
119template <typename T>
120struct One {
121 EIGEN_STRONG_INLINE T operator()() const { return T(1); }
122};
123
124template <typename T>
125struct Lowest {
126 EIGEN_STRONG_INLINE T operator()() const {
127 return Eigen::NumTraits<T>::lowest();
128 }
129};
130
131template <typename T>
132struct Highest {
133 EIGEN_STRONG_INLINE T operator()() const {
134 return Eigen::NumTraits<T>::highest();
135 }
136};
137
// Functor for the forward pass of the SparseSegment* reductions: rows of
// 'input' selected by 'indices' are combined into rows of 'output' keyed by
// 'segment_ids'. NOTE(review): is_mean/is_sqrtn presumably select the
// normalization applied to each segment, and default_value presumably fills
// output rows with no contributing input -- confirm in the implementation.
// Declaration only; defined in the implementation file.
template <typename T, typename Index, typename SegmentId>
struct SparseSegmentReductionFunctor {
  Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn,
                    T default_value, typename TTypes<T, 2>::ConstTensor input,
                    typename TTypes<Index>::ConstVec indices,
                    typename TTypes<SegmentId>::ConstVec segment_ids,
                    typename TTypes<T, 2>::Tensor output);
};
146
// Functor for the gradient of a sparse segment reduction; 'operation'
// (kSum/kMean/kSqrtN) selects which forward reduction is being
// differentiated. Declaration only; per-Device definitions live in the
// implementation files.
template <class Device, typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat);
};
156
157} // namespace functor
158} // namespace tensorflow
159
160#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
161