1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/math_ops.cc.
17#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"
18
19namespace tensorflow {
20
// Registers a single CPU kernel for the unsorted segment reduction op named
// `op_name`.
//
//   value_type       - element type, bound to the op's "T" attr.
//   segment_id_type  - segment id type, bound to the op's "Tindices" attr.
//   init_functor     - functor providing the initial value of each output
//                      segment (e.g. functor::Zero for a sum).
//   reducer          - reduction functor applied within each segment
//                      (e.g. functor::SumOp).
//
// The registered kernel class is UnsortedSegmentReductionOp, parameterised by
// functor::UnsortedSegmentFunctor specialised for CPUDevice.
#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(op_name, value_type,           \
                                            segment_id_type, init_functor, \
                                            reducer)                       \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name(op_name)                                                        \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<value_type>("T")                                 \
          .TypeConstraint<segment_id_type>("Tindices"),                    \
      UnsortedSegmentReductionOp<                                          \
          value_type, segment_id_type,                                     \
          functor::UnsortedSegmentFunctor<CPUDevice, value_type,           \
                                          segment_id_type, init_functor,   \
                                          reducer> >)
33
// Registers every CPU unsorted segment reduction provided for a real element
// type `value_type` with segment ids of `segment_id_type`:
//   UnsortedSegmentSum  - initial value functor::Zero,    reduces with SumOp
//   UnsortedSegmentMax  - initial value functor::Lowest,  reduces with MaxOp
//   UnsortedSegmentMin  - initial value functor::Highest, reduces with MinOp
//   UnsortedSegmentProd - initial value functor::One,     reduces with ProdOp
//
// NOTE(review): unlike REGISTER_COMPLEX_CPU_UNSORTED_KERNELS, this expansion
// ends in ';'. It is expanded through TF_CALL_REAL_NUMBER_TYPES, so the
// trailing ';' presumably keeps consecutive per-type expansions separated --
// confirm before removing it.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS(value_type, segment_id_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                   \
      "UnsortedSegmentSum", value_type, segment_id_type,                 \
      functor::Zero<value_type>, functor::SumOp<value_type>);            \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                   \
      "UnsortedSegmentMax", value_type, segment_id_type,                 \
      functor::Lowest<value_type>, functor::MaxOp<value_type>);          \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                   \
      "UnsortedSegmentMin", value_type, segment_id_type,                 \
      functor::Highest<value_type>, functor::MinOp<value_type>);         \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                   \
      "UnsortedSegmentProd", value_type, segment_id_type,                \
      functor::One<value_type>, functor::ProdOp<value_type>);
47
// Registers the CPU unsorted segment reductions provided for complex element
// types: only UnsortedSegmentSum (functor::Zero / SumOp) and
// UnsortedSegmentProd (functor::One / ProdOp). Max and Min are not registered
// for these types.
//
// The expansion does not end in ';' -- the call sites in this file append it.
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(value_type, segment_id_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                      \
      "UnsortedSegmentSum", value_type, segment_id_type,                    \
      functor::Zero<value_type>, functor::SumOp<value_type>);               \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                                      \
      "UnsortedSegmentProd", value_type, segment_id_type,                   \
      functor::One<value_type>, functor::ProdOp<value_type>)
55
// Per-element-type entry points used by the registrations below. Both fix the
// segment id type to int64_t, so despite the _ALL suffix only int64
// "Tindices" kernels are registered by this file.
// NOTE(review): other index types are presumably registered in a sibling
// translation unit -- confirm.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(value_type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(value_type, int64_t)

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(value_type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(value_type, int64_t)
61
// CPU registrations with int64 segment ids: all four reductions (Sum, Max,
// Min, Prod) for each type in TF_CALL_REAL_NUMBER_TYPES, and only Sum/Prod for
// the two complex types.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
65
// The CPU registration helpers are only needed above. Undefine them, in
// definition order, so they cannot leak into the rest of this file.
#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
71
72#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Registers a single GPU kernel for the unsorted segment reduction op named
// `op_name`, with element type `value_type` (attr "T") and segment id type
// `segment_id_type` (attr "Tindices"). The "num_segments" input is declared
// as host memory.
// NOTE(review): presumably so the op can read the segment count on the host
// when sizing its output -- confirm.
//
// `init_functor` provides the initial segment value. `reducer` is the
// GPU-side reduction functor (the call sites below pass functor::Sum, Max,
// Min or Prod).
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(op_name, value_type,           \
                                            segment_id_type, init_functor, \
                                            reducer)                       \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name(op_name)                                                        \
          .Device(DEVICE_GPU)                                              \
          .HostMemory("num_segments")                                      \
          .TypeConstraint<value_type>("T")                                 \
          .TypeConstraint<segment_id_type>("Tindices"),                    \
      UnsortedSegmentReductionOp<                                          \
          value_type, segment_id_type,                                     \
          functor::UnsortedSegmentFunctor<GPUDevice, value_type,           \
                                          segment_id_type, init_functor,   \
                                          reducer> >)
86
// Registers the GPU Max, Min and Prod unsorted segment reductions for
// `value_type` elements and `segment_id_type` segment ids. Sum is deliberately
// not included here. It is currently the only reduction that supports every
// input type (on CUDA it is also registered for complex types), so it has its
// own macro, REGISTER_SUM_GPU_UNSORTED_KERNELS.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(value_type, segment_id_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", value_type,  \
                                      segment_id_type,                   \
                                      functor::Lowest<value_type>,       \
                                      functor::Max);                     \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", value_type,  \
                                      segment_id_type,                   \
                                      functor::Highest<value_type>,      \
                                      functor::Min);                     \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", value_type, \
                                      segment_id_type,                   \
                                      functor::One<value_type>,          \
                                      functor::Prod);
95
// Registers the GPU UnsortedSegmentSum kernel (initial value functor::Zero,
// reduction functor::Sum) for `value_type` elements and `segment_id_type`
// segment ids. It is kept apart from the Max/Min/Prod macro because Sum is
// registered for a wider set of element types.
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(value_type, segment_id_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                  \
      "UnsortedSegmentSum", value_type, segment_id_type,                \
      functor::Zero<value_type>, functor::Sum);
99
// Per-element-type GPU entry points used by the registrations below. Both fix
// the segment id type to int64_t, so only int64 "Tindices" GPU kernels are
// registered by this file.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(value_type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(value_type, int64_t)

#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(value_type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(value_type, int64_t)
105
// GPU registrations with int64 segment ids. Max/Min/Prod and Sum are each
// registered for every type in TF_CALL_GPU_NUMBER_TYPES plus int32. Sum is
// additionally registered for the complex types, but only on CUDA builds.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
// TODO(rocm): support atomicAdd for complex numbers on ROCm
#if GOOGLE_CUDA
TF_CALL_COMPLEX_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
#endif  // GOOGLE_CUDA
114
// All GPU kernels are registered at this point. Undefine the helpers in
// reverse order of definition so they cannot leak past this section.
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS
#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
120
121#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
122
123} // namespace tensorflow
124