1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/math_ops.cc.
17#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"
18
19namespace tensorflow {
20
21#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
22 default_value) \
23 REGISTER_KERNEL_BUILDER( \
24 Name(name) \
25 .Device(DEVICE_CPU) \
26 .TypeConstraint<type>("T") \
27 .TypeConstraint<index_type>("Tindices"), \
28 SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)
29
30#define REGISTER_REAL_CPU_KERNELS(type, index_type) \
31 REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
32 type, index_type, 0); \
33 REGISTER_CPU_KERNEL_SEGMENT( \
34 "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
35 REGISTER_CPU_KERNEL_SEGMENT( \
36 "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
37 REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
38 type, index_type, 0); \
39 REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
40 type, index_type, 0)
41
42#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \
43 REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
44 type, index_type, 0); \
45 REGISTER_CPU_KERNEL_SEGMENT( \
46 "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
47 REGISTER_CPU_KERNEL_SEGMENT( \
48 "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);
49
50#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
51 REGISTER_REAL_CPU_KERNELS(type, int64_t)
52
53#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
54 REGISTER_COMPLEX_CPU_KERNELS(type, int64_t)
55
56TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
57REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
58REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
59#undef REGISTER_CPU_KERNEL_SEGMENT
60#undef REGISTER_REAL_CPU_KERNELS
61#undef REGISTER_COMPLEX_CPU_KERNELS
62#undef REGISTER_REAL_CPU_KERNELS_ALL
63#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
64
65#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
66#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
67 name, type, index_type, initial_value_functor, \
68 empty_segment_value_functor, reduction_kernel_functor, is_mean) \
69 REGISTER_KERNEL_BUILDER( \
70 Name(name) \
71 .Device(DEVICE_GPU) \
72 .TypeConstraint<type>("T") \
73 .TypeConstraint<index_type>("Tindices"), \
74 SegmentReductionGPUOp< \
75 type, index_type, \
76 functor::SegmentReductionFunctor< \
77 type, index_type, initial_value_functor, \
78 empty_segment_value_functor, reduction_kernel_functor>, \
79 is_mean>)
80
81#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
82 REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type, \
83 functor::Zero<type>, functor::Zero<type>, \
84 functor::Sum, /*is_mean=*/false); \
85 REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type, \
86 functor::Zero<type>, functor::Zero<type>, \
87 functor::Sum, /*is_mean=*/true); \
88 REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type, \
89 functor::One<type>, functor::One<type>, \
90 functor::Prod, /*is_mean=*/false); \
91 REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
92 "SegmentMin", type, index_type, functor::Highest<type>, \
93 functor::Zero<type>, functor::Min, /*is_mean=*/false); \
94 REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
95 "SegmentMax", type, index_type, functor::Lowest<type>, \
96 functor::Zero<type>, functor::Max, /*is_mean=*/false);
97
98#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
99 REGISTER_GPU_SORTED_KERNELS(type, int64_t);
100
101TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
102#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
103#undef REGISTER_GPU_SORTED_KERNELS
104#undef REGISTER_GPU_SORTED_KERNELS_ALL
105#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
106
107} // namespace tensorflow
108