1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/framework/tensor_types.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | class OpKernelContext; |
27 | |
// Returns true if deterministic (though possibly slower) GPU kernels should
// be used for segment reductions.
bool UseDeterministicSegmentReductions();
// Returns true if segment reduction ops should silently run
// nondeterministically instead of raising a determinism exception.
bool DisableSegmentReductionOpDeterminismExceptions();
30 | |
// The type of SparseSegmentReduction operation whose gradient is being
// computed.
32 | enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN }; |
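// For a segment containing n rows, the corresponding gradient op scales each
// incoming gradient row by 1 (kSum), 1/n (kMean), or 1/sqrt(n) (kSqrtN),
// mirroring the normalization that the forward op applied.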
33 | |
34 | namespace functor { |
35 | |
36 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
37 | |
38 | // Note that we define this ourselves to avoid a dependency on gpuprim. |
39 | struct Sum { |
40 | template <typename T> |
41 | __host__ __device__ T operator()(const T& a, const T& b) const { |
42 | return a + b; |
43 | } |
44 | }; |
45 | |
46 | struct Prod { |
47 | template <typename T> |
48 | __host__ __device__ T operator()(const T& a, const T& b) const { |
49 | return a * b; |
50 | } |
51 | }; |
52 | |
// Note that we don't use gpuprim::Min/Max because they use operator<, which
// is not implemented for AlignedVector types. The unqualified min/max calls
// below resolve to the CUDA/ROCm built-ins or, via argument-dependent
// lookup, to overloads for types such as AlignedVector.
55 | struct Min { |
56 | template <typename T> |
57 | __host__ __device__ T operator()(const T& a, const T& b) const { |
58 | return min(a, b); |
59 | } |
60 | }; |
61 | |
62 | struct Max { |
63 | template <typename T> |
64 | __host__ __device__ T operator()(const T& a, const T& b) const { |
65 | return max(a, b); |
66 | } |
67 | }; |
68 | |
69 | template <typename ReduceOp, typename T> |
70 | struct ReduceOpIsAssociative {}; |
71 | template <typename T> |
72 | struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {}; |
73 | template <typename T> |
74 | struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {}; |
75 | template <typename T> |
76 | struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {}; |
77 | template <typename T> |
78 | struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {}; |
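
// Floating-point addition and multiplication are not associative (rounding
// makes the result depend on evaluation order), so only integral Sum/Prod
// reductions are marked associative, while Min and Max are associative for
// all element types. A minimal compile-time sketch of the trait:
static_assert(ReduceOpIsAssociative<Sum, int>::value,
              "integral Sum is associative");
static_assert(!ReduceOpIsAssociative<Sum, float>::value,
              "floating-point Sum is order-dependent");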
79 | |
80 | typedef Eigen::GpuDevice GPUDevice; |
// Functor for SegmentReductionGPUOp.
// output_rows: the number of output segments (unique segment ids in
//              'segment_ids').
// segment_ids_shape: shape of the 'segment_ids' tensor.
// is_mean: if true, divide each output segment by the number of input rows
//          that were reduced into it (used by SegmentMean).
// segment_ids: maps each input row to the output segment id into which it is
//              reduced.
// data_size: size of the input data tensor.
// data: input data tensor.
// output: the output tensor, reshaped to
//         {output_rows, output.size/output_rows}.
90 | template <typename T, typename Index, typename InitialValueF, |
91 | typename EmptySegmentValueF, typename ReductionF> |
92 | struct SegmentReductionFunctor { |
93 | void operator()(OpKernelContext* ctx, const GPUDevice& d, |
94 | const Index output_rows, const TensorShape& segment_ids_shape, |
95 | bool is_mean, typename TTypes<Index>::ConstFlat segment_ids, |
96 | const Index data_size, const T* data, |
97 | typename TTypes<T, 2>::Tensor output); |
98 | static constexpr bool atomic_reduction_is_associative = |
99 | ReduceOpIsAssociative<ReductionF, T>::value; |
100 | }; |
101 | |
102 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
103 | |
104 | template <typename Device, typename T, typename Index, typename InitialValueF, |
105 | typename ReductionF> |
106 | struct UnsortedSegmentFunctor { |
107 | void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape, |
108 | typename TTypes<Index>::ConstFlat segment_ids, |
109 | typename TTypes<T, 2>::ConstTensor data, |
110 | typename TTypes<T, 2>::Tensor output); |
111 | }; |
112 | |
// Initial value functors. Each returns the identity element of its paired
// reduction: Zero for sum, One for product, Lowest for max, and Highest for
// min.
114 | template <typename T> |
115 | struct Zero { |
116 | EIGEN_STRONG_INLINE T operator()() const { return T(0); } |
117 | }; |
118 | |
119 | template <typename T> |
120 | struct One { |
121 | EIGEN_STRONG_INLINE T operator()() const { return T(1); } |
122 | }; |
123 | |
124 | template <typename T> |
125 | struct Lowest { |
126 | EIGEN_STRONG_INLINE T operator()() const { |
127 | return Eigen::NumTraits<T>::lowest(); |
128 | } |
129 | }; |
130 | |
131 | template <typename T> |
132 | struct Highest { |
133 | EIGEN_STRONG_INLINE T operator()() const { |
134 | return Eigen::NumTraits<T>::highest(); |
135 | } |
136 | }; |
137 | |
138 | template <typename T, typename Index, typename SegmentId> |
139 | struct SparseSegmentReductionFunctor { |
140 | Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn, |
141 | T default_value, typename TTypes<T, 2>::ConstTensor input, |
142 | typename TTypes<Index>::ConstVec indices, |
143 | typename TTypes<SegmentId>::ConstVec segment_ids, |
144 | typename TTypes<T, 2>::Tensor output); |
145 | }; |
146 | |
template <typename Device, typename T, typename Index, typename SegmentId>
148 | struct SparseSegmentGradFunctor { |
149 | void operator()(OpKernelContext* context, |
150 | SparseSegmentReductionOperation operation, |
151 | typename TTypes<T>::ConstMatrix input_flat, |
152 | typename TTypes<Index>::ConstVec indices_vec, |
153 | typename TTypes<SegmentId>::ConstVec segment_vec, |
154 | typename TTypes<T>::Matrix output_flat); |
155 | }; |
156 | |
157 | } // namespace functor |
158 | } // namespace tensorflow |
159 | |
160 | #endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ |
161 | |