1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | #include "tensorflow/core/kernels/segment_reduction_ops_impl.h" |
18 | |
19 | namespace tensorflow { |
20 | |
// Registers a single CPU kernel for the op called `op_name`.
//
//   op_name        - op name string handed to Name().
//   value_t        - type bound to the "T" type constraint.
//   index_t        - type bound to the "Tindices" type constraint.
//   init_functor   - functor type forwarded to UnsortedSegmentFunctor as the
//                    initial-value functor.
//   reduce_functor - functor type forwarded to UnsortedSegmentFunctor as the
//                    reduction functor.
//
// The kernel class is UnsortedSegmentReductionOp instantiated with the
// CPUDevice flavour of functor::UnsortedSegmentFunctor.
// NOTE(review): token spacing inside the kernel class argument (including the
// "> >") is deliberately left as it was; REGISTER_KERNEL_BUILDER may stringify
// that argument - confirm before reformatting it.
#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(op_name, value_t, index_t, \
                                            init_functor, reduce_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(op_name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<value_t>("T") \
          .TypeConstraint<index_t>("Tindices"), \
      UnsortedSegmentReductionOp< \
          value_t, index_t, \
          functor::UnsortedSegmentFunctor<CPUDevice, value_t, index_t, \
                                          init_functor, reduce_functor> >)
33 | |
// Registers the four CPU unsorted-segment kernels for one (value_t, index_t)
// pair. Each op is paired with the functor that supplies its initial value
// and the functor that performs the reduction:
//   UnsortedSegmentSum  : functor::Zero    / functor::SumOp
//   UnsortedSegmentMax  : functor::Lowest  / functor::MaxOp
//   UnsortedSegmentMin  : functor::Highest / functor::MinOp
//   UnsortedSegmentProd : functor::One     / functor::ProdOp
// NOTE(review): the expansion itself ends with ';'. Presumably this is needed
// when the macro is expanded back-to-back through a TF_CALL_* helper - confirm
// before removing it.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS(value_t, index_t) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", value_t, \
                                      index_t, functor::Zero<value_t>, \
                                      functor::SumOp<value_t>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", value_t, \
                                      index_t, functor::Lowest<value_t>, \
                                      functor::MaxOp<value_t>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", value_t, \
                                      index_t, functor::Highest<value_t>, \
                                      functor::MinOp<value_t>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", value_t, \
                                      index_t, functor::One<value_t>, \
                                      functor::ProdOp<value_t>);
47 | |
// Registers the CPU unsorted-segment kernels used for complex value types.
// Only UnsortedSegmentSum (Zero / SumOp) and UnsortedSegmentProd (One /
// ProdOp) are emitted here; no Max or Min kernel is registered by this macro.
// The expansion does not end with ';', so the call site supplies it.
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(value_t, index_t) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", value_t, \
                                      index_t, functor::Zero<value_t>, \
                                      functor::SumOp<value_t>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", value_t, \
                                      index_t, functor::One<value_t>, \
                                      functor::ProdOp<value_t>)
55 | |
// Single-argument adapter: expands REGISTER_REAL_CPU_UNSORTED_KERNELS for
// `value_t` with the index type fixed to int32, so it can be handed to
// helpers that invoke a macro with one type argument.
#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(value_t) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(value_t, int32)
58 | |
// Single-argument adapter: expands REGISTER_COMPLEX_CPU_UNSORTED_KERNELS for
// `value_t` with the index type fixed to int32.
#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(value_t) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(value_t, int32)
61 | |
62 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL); |
63 | REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64); |
64 | REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); |
65 | |
66 | #undef REGISTER_REAL_CPU_UNSORTED_KERNELS |
67 | #undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT |
68 | #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS |
69 | #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL |
70 | #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL |
71 | |
72 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
// Registers a single GPU kernel for the op called `op_name`.
//
//   op_name        - op name string handed to Name().
//   value_t        - type bound to the "T" type constraint.
//   index_t        - type bound to the "Tindices" type constraint.
//   init_functor   - functor type forwarded to UnsortedSegmentFunctor as the
//                    initial-value functor.
//   reduce_functor - functor type forwarded to UnsortedSegmentFunctor as the
//                    reduction kernel functor.
//
// Compared with the CPU registration, the builder additionally marks the
// "num_segments" input with HostMemory(), and the functor is instantiated for
// GPUDevice.
// NOTE(review): token spacing inside the kernel class argument (including the
// "> >") is deliberately left as it was; REGISTER_KERNEL_BUILDER may stringify
// that argument - confirm before reformatting it.
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(op_name, value_t, index_t, \
                                            init_functor, reduce_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(op_name) \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<value_t>("T") \
          .TypeConstraint<index_t>("Tindices"), \
      UnsortedSegmentReductionOp< \
          value_t, index_t, \
          functor::UnsortedSegmentFunctor<GPUDevice, value_t, index_t, \
                                          init_functor, reduce_functor> >)
86 | |
// Registers the GPU UnsortedSegmentMax / Min / Prod kernels for one
// (value_t, index_t) pair:
//   UnsortedSegmentMax  : functor::Lowest  / functor::Max
//   UnsortedSegmentMin  : functor::Highest / functor::Min
//   UnsortedSegmentProd : functor::One     / functor::Prod
// UnsortedSegmentSum is intentionally not part of this macro: sum is
// currently the only op that supports all input types, so it is registered
// through its own macro.
// NOTE(review): the expansion itself ends with ';'. Presumably this is needed
// when the macro is expanded back-to-back through a TF_CALL_* helper - confirm
// before removing it.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(value_t, index_t) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", value_t, \
                                      index_t, functor::Lowest<value_t>, \
                                      functor::Max); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", value_t, \
                                      index_t, functor::Highest<value_t>, \
                                      functor::Min); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", value_t, \
                                      index_t, functor::One<value_t>, \
                                      functor::Prod);
95 | |
// Registers the GPU UnsortedSegmentSum kernel for one (value_t, index_t)
// pair, using functor::Zero as the initial value and functor::Sum as the
// reduction. The expansion itself ends with ';'.
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(value_t, index_t) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", value_t, \
                                      index_t, functor::Zero<value_t>, \
                                      functor::Sum);
99 | |
// Single-argument adapter: expands REGISTER_REAL_GPU_UNSORTED_KERNELS for
// `value_t` with the index type fixed to int32, so it can be handed to
// helpers that invoke a macro with one type argument.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(value_t) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(value_t, int32)
102 | |
// Single-argument adapter: expands REGISTER_SUM_GPU_UNSORTED_KERNELS for
// `value_t` with the index type fixed to int32.
#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(value_t) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(value_t, int32)
105 | |
106 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); |
107 | TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); |
108 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); |
109 | TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); |
110 | // TODO(rocm): support atomicAdd for complex numbers on ROCm |
111 | #if GOOGLE_CUDA |
112 | TF_CALL_COMPLEX_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); |
113 | #endif |
114 | |
115 | #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT |
116 | #undef REGISTER_REAL_GPU_UNSORTED_KERNELS |
117 | #undef REGISTER_SUM_GPU_UNSORTED_KERNELS |
118 | #undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL |
119 | #undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL |
120 | |
121 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
122 | |
123 | } // namespace tensorflow |
124 | |