/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"

namespace tensorflow {
namespace internal {
// Static routines not in the templated class to reduce code size.
Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input,
                                const Tensor& segment_ids) {
  if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) {
    return errors::InvalidArgument("input must be at least rank 1");
  }
  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }
  const int64_t num_indices = segment_ids.NumElements();
  if (num_indices != input.dim_size(0)) {
    return errors::InvalidArgument(
        "segment_ids should be the same size as dimension 0 of"
        " input.");
  }

  return OkStatus();
}
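
// Illustrative only (a hypothetical caller sketch, not part of this file's
// kernels): a sorted-segment kernel would typically invoke the validator from
// its Compute method before reducing, e.g.
//
//   const Tensor& input = context->input(0);
//   const Tensor& segment_ids = context->input(1);
//   OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(
//                               context, input, segment_ids));
//
// With input of shape [4, 3] and segment_ids of shape [4], the checks above
// pass; segment_ids of shape [3] or [4, 1] would be rejected.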

// Check routines not in the templated class to reduce code size.
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments) {
  if (!TensorShapeUtils::IsScalar(num_segments.shape())) {
    return errors::InvalidArgument(
        "num_segments should be a scalar, not shape ",
        num_segments.shape().DebugString());
  }

  if (!TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape())) {
    return errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                                   " does not start with segment_ids.shape = ",
                                   segment_ids.shape().DebugString());
  }

  return OkStatus();
}
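
// Illustrative only: TensorShapeUtils::StartsWith requires segment_ids.shape
// to be a prefix of data.shape. For data of shape [2, 3, 4], segment_ids of
// shape [2] or [2, 3] passes the check above, while [3] or [2, 4] fails.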

Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments) {
  if (has_num_segments) {
    const Tensor& num_segments_t = context->input(3);
    if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) {
      return errors::InvalidArgument(
          "num_segments should be a scalar, not shape ",
          num_segments_t.shape().DebugString());
    }
    int64_t output_rows =
        internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                     ? num_segments_t.scalar<int32>()()
                                     : num_segments_t.scalar<int64_t>()());
    if (output_rows < 0) {
      return errors::InvalidArgument("num_segments must be >= 0");
    }
  }

  if (!TensorShapeUtils::IsVector(indices.shape())) {
    return errors::InvalidArgument("indices should be a vector.");
  }

  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }

  const int64_t num_indices = indices.NumElements();
  if (num_indices != segment_ids.NumElements()) {
    return errors::InvalidArgument(
        "segment_ids and indices should have same size.");
  }

  if (input.dims() < 1) {
    return errors::InvalidArgument("input must be at least rank 1");
  }

  return OkStatus();
}
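
// Illustrative only (hypothetical values): for a sparse segment reduction
// with input of shape [10, 4], indices = [0, 2, 3, 8], and
// segment_ids = [0, 0, 1, 2], all of the checks above pass; an indices
// vector of length 3 paired with those four segment_ids would fail the
// size check.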

}  // namespace internal

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)
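
// As an illustration (hand-expanded here, not compiled), instantiating the
// macro above with ("SegmentSum", Eigen::internal::SumReducer<float>, float,
// int32, 0) yields roughly:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("SegmentSum")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<float>("T")
//           .TypeConstraint<int32>("Tindices"),
//       SegmentReductionOp<CPUDevice, float, int32,
//                          Eigen::internal::SumReducer<float>, 0>);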

#define REGISTER_REAL_CPU_KERNELS(type, index_type)                           \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>,\
                              type, index_type, 0);                           \
  REGISTER_CPU_KERNEL_SEGMENT(                                                \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0);\
  REGISTER_CPU_KERNEL_SEGMENT(                                                \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);\
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>,\
                              type, index_type, 0);                           \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>,\
                              type, index_type, 0)
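
// Note: the final macro argument is the value assigned to empty segments in
// the output; sum, mean, min, and max use 0, while prod uses 1 (the
// multiplicative identity).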

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                        \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>,\
                              type, index_type, 0);                           \
  REGISTER_CPU_KERNEL_SEGMENT(                                                \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0);\
  REGISTER_CPU_KERNEL_SEGMENT(                                                \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                            \
    name, type, index_type, initial_value_functor,                    \
    empty_segment_value_functor, reduction_kernel_functor, is_mean)   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name(name)                                                      \
          .Device(DEVICE_GPU)                                         \
          .TypeConstraint<type>("T")                                  \
          .TypeConstraint<index_type>("Tindices"),                    \
      SegmentReductionGPUOp<                                          \
          type, index_type,                                           \
          functor::SegmentReductionFunctor<                           \
              type, index_type, initial_value_functor,                \
              empty_segment_value_functor, reduction_kernel_functor>, \
          is_mean>)

#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                        \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type,          \
                                    functor::Zero<type>, functor::Zero<type>,\
                                    functor::Sum, /*is_mean=*/false);        \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type,         \
                                    functor::Zero<type>, functor::Zero<type>,\
                                    functor::Sum, /*is_mean=*/true);         \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type,         \
                                    functor::One<type>, functor::One<type>,  \
                                    functor::Prod, /*is_mean=*/false);       \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                         \
      "SegmentMin", type, index_type, functor::Highest<type>,                \
      functor::Zero<type>, functor::Min, /*is_mean=*/false);                 \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                         \
      "SegmentMax", type, index_type, functor::Lowest<type>,                 \
      functor::Zero<type>, functor::Max, /*is_mean=*/false);
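
// Note: for min and max the GPU reduction starts from functor::Highest<type>
// and functor::Lowest<type> respectively (the reduction identities), while
// empty segments in the output are filled with functor::Zero<type>, matching
// the CPU defaults above.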

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow