/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"

namespace tensorflow {

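// The helper macros below expand a per-type REGISTER_*_SPARSE_KERNELS macro
// over every supported combination of index type (Tidx) and segment-id type
// (Tsegmentids), i.e. int32 and int64_t for each.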
#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
  REGISTER_CPU_SPARSE_KERNELS(type, index_type, int32) \
  REGISTER_CPU_SPARSE_KERNELS(type, index_type, int64_t)
#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type) \
  REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
  REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64_t)

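// Each forward op is registered in two variants: the plain op, which derives
// the output size from the largest segment id, and the *WithNumSegments op,
// which takes the number of segments as an explicit input.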
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSum") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSumOp<CPUDevice, type, index_type, \
                                  segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type, index_type, \
                                                 segment_ids_type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

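// Mean and SqrtN reductions are registered only for floating-point element
// types (TF_CALL_FLOAT_TYPES), whereas Sum above covers all real number types.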
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMean") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionMeanOp<CPUDevice, type, index_type, \
                                   segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type, index_type, \
                                                  segment_ids_type>);
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtN") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSqrtNOp<CPUDevice, type, index_type, \
                                    segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSqrtNWithNumSegmentsOp< \
          CPUDevice, type, index_type, segment_ids_type>);
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

// TODO(benbarsdell): These kernels are disabled on Windows as a workaround for
// a CI build error: "formal parameter with requested alignment of 128 won't be
// aligned". The root cause is suspected to be an aligned type (AlignedVector)
// being passed to a function by value, possibly inside the CUB library
// somewhere, but I have not yet been able to reproduce it in isolation outside
// of the GitHub CI.
#if GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)

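// The GPU registrations mirror the CPU ones above. The scalar num_segments
// input of the *WithNumSegments variants is pinned to host memory so it can
// be read on the host when sizing the output.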
#define REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
  REGISTER_GPU_SPARSE_KERNELS(type, index_type, int32) \
  REGISTER_GPU_SPARSE_KERNELS(type, index_type, int64_t)
#define REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type) \
  REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
  REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64_t)

#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSum") \
          .Device(DEVICE_GPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSumOp<GPUDevice, type, index_type, \
                                  segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumWithNumSegments") \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSumWithNumSegmentsOp<GPUDevice, type, index_type, \
                                                 segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMean") \
          .Device(DEVICE_GPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionMeanOp<GPUDevice, type, index_type, \
                                   segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanWithNumSegments") \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionMeanWithNumSegmentsOp<GPUDevice, type, index_type, \
                                                  segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtN") \
          .Device(DEVICE_GPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSqrtNOp<GPUDevice, type, index_type, \
                                    segment_ids_type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNWithNumSegments") \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentReductionSqrtNWithNumSegmentsOp< \
          GPUDevice, type, index_type, segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#endif  // GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)

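// Gradient ops for the sparse segment reductions above. On CPU they are
// registered for floating-point element types (TF_CALL_FLOAT_TYPES).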
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumGrad") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentSumGradOp<CPUDevice, type, index_type, segment_ids_type>);
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanGrad") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentMeanGradOp<CPUDevice, type, index_type, segment_ids_type>);
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNGrad") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentSqrtNGradOp<CPUDevice, type, index_type, \
                               segment_ids_type>);
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS

#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE

// TODO(benbarsdell): See comment above.
#if GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)

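// The GPU gradient kernels pin the scalar output_dim0 input to host memory,
// mirroring the handling of num_segments above.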
#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumGrad") \
          .Device(DEVICE_GPU) \
          .HostMemory("output_dim0") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentSumGradOp<GPUDevice, type, index_type, segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanGrad") \
          .Device(DEVICE_GPU) \
          .HostMemory("output_dim0") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentMeanGradOp<GPUDevice, type, index_type, segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNGrad") \
          .Device(DEVICE_GPU) \
          .HostMemory("output_dim0") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tidx") \
          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
      SparseSegmentSqrtNGradOp<GPUDevice, type, index_type, \
                               segment_ids_type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_GPU_SPARSE_KERNELS

#undef REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
#undef REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE

#endif  // GOOGLE_CUDA && !defined(PLATFORM_WINDOWS)

}  // namespace tensorflow