1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | #include "tensorflow/core/kernels/segment_reduction_ops_impl.h" |
18 | |
19 | namespace tensorflow { |
20 | namespace internal { |
21 | // Static routines not in the templated class to reduce code size |
22 | Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input, |
23 | const Tensor& segment_ids) { |
24 | if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) { |
25 | return errors::InvalidArgument("input must be at least rank 1" ); |
26 | } |
27 | if (!TensorShapeUtils::IsVector(segment_ids.shape())) { |
28 | return errors::InvalidArgument("segment_ids should be a vector." ); |
29 | } |
30 | const int64_t num_indices = segment_ids.NumElements(); |
31 | if (num_indices != input.dim_size(0)) { |
32 | return errors::InvalidArgument( |
33 | "segment_ids should be the same size as dimension 0 of" |
34 | " input." ); |
35 | } |
36 | |
37 | return OkStatus(); |
38 | } |
39 | |
40 | // check routines not in the templated class to reduce code size |
41 | Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, |
42 | OpKernelContext* context, |
43 | const Tensor& data, |
44 | const Tensor& segment_ids, |
45 | const Tensor& num_segments) { |
46 | if (!TensorShapeUtils::IsScalar(num_segments.shape())) { |
47 | return errors::InvalidArgument( |
48 | "num_segments should be a scalar, not shape " , |
49 | num_segments.shape().DebugString()); |
50 | } |
51 | |
52 | if (!TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape())) { |
53 | return errors::InvalidArgument("data.shape = " , data.shape().DebugString(), |
54 | " does not start with segment_ids.shape = " , |
55 | segment_ids.shape().DebugString()); |
56 | } |
57 | |
58 | return OkStatus(); |
59 | } |
60 | |
61 | Status ValidateSparseSegmentReduction(OpKernelContext* context, |
62 | const Tensor& input, |
63 | const Tensor& indices, |
64 | const Tensor& segment_ids, |
65 | bool has_num_segments) { |
66 | if (has_num_segments) { |
67 | const Tensor& num_segments_t = context->input(3); |
68 | if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) { |
69 | return errors::InvalidArgument( |
70 | "num_segments should be a scalar, not shape " , |
71 | num_segments_t.shape().DebugString()); |
72 | } |
73 | int64_t output_rows = |
74 | internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32 |
75 | ? num_segments_t.scalar<int32>()() |
76 | : num_segments_t.scalar<int64_t>()()); |
77 | if (output_rows < 0) { |
78 | return errors::InvalidArgument("segment ids must be >= 0" ); |
79 | } |
80 | } |
81 | |
82 | if (!TensorShapeUtils::IsVector(indices.shape())) { |
83 | return errors::InvalidArgument("indices should be a vector." ); |
84 | } |
85 | |
86 | if (!TensorShapeUtils::IsVector(segment_ids.shape())) { |
87 | return errors::InvalidArgument("segment_ids should be a vector." ); |
88 | } |
89 | |
90 | const int64_t num_indices = indices.NumElements(); |
91 | if (num_indices != segment_ids.NumElements()) { |
92 | return errors::InvalidArgument( |
93 | "segment_ids and indices should have same size." ); |
94 | } |
95 | |
96 | if (input.dims() < 1) { |
97 | return errors::InvalidArgument("Shape must be at least rank 1" ); |
98 | } |
99 | |
100 | return OkStatus(); |
101 | } |
102 | |
103 | } // namespace internal |
104 | |
// Registers the CPU kernel `name` for one (element type, index type) pair.
// `functor` is the Eigen reducer performing the reduction; `default_value`
// is forwarded to SegmentReductionOp (presumably the value emitted for
// empty segments — see segment_reduction_ops_impl.h to confirm).
#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

// Sorted-segment Sum/Mean/Prod/Min/Max for real number types. default_value
// is 0 for Sum/Mean/Min/Max and 1 for Prod.
#define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

// Complex types get Sum/Mean/Prod only — Min/Max require an ordering, which
// complex numbers do not have.
#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

// Only int32 segment indices are registered on CPU here.
#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
148 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Registers the GPU kernel `name` for one (element type, index type) pair.
// The three functors are forwarded to functor::SegmentReductionFunctor;
// `is_mean` is passed to SegmentReductionGPUOp and distinguishes
// SegmentMean (registered below as a Sum reduction with is_mean=true)
// from SegmentSum.
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                            \
    name, type, index_type, initial_value_functor,                    \
    empty_segment_value_functor, reduction_kernel_functor, is_mean)   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name(name)                                                      \
          .Device(DEVICE_GPU)                                         \
          .TypeConstraint<type>("T")                                  \
          .TypeConstraint<index_type>("Tindices"),                    \
      SegmentReductionGPUOp<                                          \
          type, index_type,                                           \
          functor::SegmentReductionFunctor<                           \
              type, index_type, initial_value_functor,                \
              empty_segment_value_functor, reduction_kernel_functor>, \
          is_mean>)

// GPU Sum/Mean/Prod/Min/Max. Initial values: Zero for Sum/Mean, One for
// Prod, Highest for Min, Lowest for Max. The empty-segment value is Zero
// for every op except Prod (One).
#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                         \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type,           \
                                    functor::Zero<type>, functor::Zero<type>, \
                                    functor::Sum, /*is_mean=*/false);         \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type,          \
                                    functor::Zero<type>, functor::Zero<type>, \
                                    functor::Sum, /*is_mean=*/true);          \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type,          \
                                    functor::One<type>, functor::One<type>,   \
                                    functor::Prod, /*is_mean=*/false);        \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
      "SegmentMin", type, index_type, functor::Highest<type>,                 \
      functor::Zero<type>, functor::Min, /*is_mean=*/false);                  \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
      "SegmentMax", type, index_type, functor::Lowest<type>,                  \
      functor::Zero<type>, functor::Max, /*is_mean=*/false);

// Only int32 segment indices are registered on GPU here.
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
190 | |
191 | } // namespace tensorflow |
192 | |