1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
17#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
18
19#include <numeric>
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/tensor_types.h"
23
24namespace tensorflow {
25namespace functor {
26
27// Functor definitions for Aggregate ops, must be compilable by nvcc.
28template <typename Device, typename T>
29struct Add2Functor {
30 void operator()(const Device& d, typename TTypes<T>::Flat out,
31 typename TTypes<T>::ConstFlat in1,
32 typename TTypes<T>::ConstFlat in2);
33};
34
35template <typename Device, typename T>
36struct Add2EigenImpl {
37 static void Compute(const Device& d, typename TTypes<T>::Flat out,
38 typename TTypes<T>::ConstFlat in1,
39 typename TTypes<T>::ConstFlat in2) {
40 out.device(d) = in1 + in2;
41 }
42};
43
44template <typename Device, typename T>
45struct Add3Functor {
46 void operator()(const Device& d, typename TTypes<T>::Flat out,
47 typename TTypes<T>::ConstFlat in1,
48 typename TTypes<T>::ConstFlat in2,
49 typename TTypes<T>::ConstFlat in3);
50};
51
52template <typename Device, typename T>
53struct Add3EigenImpl {
54 static void Compute(const Device& d, typename TTypes<T>::Flat out,
55 typename TTypes<T>::ConstFlat in1,
56 typename TTypes<T>::ConstFlat in2,
57 typename TTypes<T>::ConstFlat in3) {
58 out.device(d) = in1 + in2 + in3;
59 }
60};
61
62template <typename Device, typename T>
63struct Add4Functor {
64 void operator()(const Device& d, typename TTypes<T>::Flat out,
65 typename TTypes<T>::ConstFlat in1,
66 typename TTypes<T>::ConstFlat in2,
67 typename TTypes<T>::ConstFlat in3,
68 typename TTypes<T>::ConstFlat in4);
69};
70
71template <typename Device, typename T>
72struct Add4EigenImpl {
73 static void Compute(const Device& d, typename TTypes<T>::Flat out,
74 typename TTypes<T>::ConstFlat in1,
75 typename TTypes<T>::ConstFlat in2,
76 typename TTypes<T>::ConstFlat in3,
77 typename TTypes<T>::ConstFlat in4) {
78 out.device(d) = in1 + in2 + in3 + in4;
79 }
80};
81
82template <typename Device, typename T>
83struct Add5Functor {
84 void operator()(const Device& d, typename TTypes<T>::Flat out,
85 typename TTypes<T>::ConstFlat in1,
86 typename TTypes<T>::ConstFlat in2,
87 typename TTypes<T>::ConstFlat in3,
88 typename TTypes<T>::ConstFlat in4,
89 typename TTypes<T>::ConstFlat in5);
90};
91
92template <typename Device, typename T>
93struct Add5EigenImpl {
94 static void Compute(const Device& d, typename TTypes<T>::Flat out,
95 typename TTypes<T>::ConstFlat in1,
96 typename TTypes<T>::ConstFlat in2,
97 typename TTypes<T>::ConstFlat in3,
98 typename TTypes<T>::ConstFlat in4,
99 typename TTypes<T>::ConstFlat in5) {
100 out.device(d) = in1 + in2 + in3 + in4 + in5;
101 }
102};
103
104template <typename Device, typename T>
105struct Add6Functor {
106 void operator()(const Device& d, typename TTypes<T>::Flat out,
107 typename TTypes<T>::ConstFlat in1,
108 typename TTypes<T>::ConstFlat in2,
109 typename TTypes<T>::ConstFlat in3,
110 typename TTypes<T>::ConstFlat in4,
111 typename TTypes<T>::ConstFlat in5,
112 typename TTypes<T>::ConstFlat in6);
113};
114
115template <typename Device, typename T>
116struct Add6EigenImpl {
117 static void Compute(const Device& d, typename TTypes<T>::Flat out,
118 typename TTypes<T>::ConstFlat in1,
119 typename TTypes<T>::ConstFlat in2,
120 typename TTypes<T>::ConstFlat in3,
121 typename TTypes<T>::ConstFlat in4,
122 typename TTypes<T>::ConstFlat in5,
123 typename TTypes<T>::ConstFlat in6) {
124 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
125 }
126};
127
128template <typename Device, typename T>
129struct Add7Functor {
130 void operator()(const Device& d, typename TTypes<T>::Flat out,
131 typename TTypes<T>::ConstFlat in1,
132 typename TTypes<T>::ConstFlat in2,
133 typename TTypes<T>::ConstFlat in3,
134 typename TTypes<T>::ConstFlat in4,
135 typename TTypes<T>::ConstFlat in5,
136 typename TTypes<T>::ConstFlat in6,
137 typename TTypes<T>::ConstFlat in7);
138};
139
140template <typename Device, typename T>
141struct Add7EigenImpl {
142 static void Compute(const Device& d, typename TTypes<T>::Flat out,
143 typename TTypes<T>::ConstFlat in1,
144 typename TTypes<T>::ConstFlat in2,
145 typename TTypes<T>::ConstFlat in3,
146 typename TTypes<T>::ConstFlat in4,
147 typename TTypes<T>::ConstFlat in5,
148 typename TTypes<T>::ConstFlat in6,
149 typename TTypes<T>::ConstFlat in7) {
150 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
151 }
152};
153
154template <typename Device, typename T>
155struct Add8Functor {
156 void operator()(
157 const Device& d, typename TTypes<T>::Flat out,
158 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
159 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
160 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
161 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
162};
163
164template <typename Device, typename T>
165struct Add8EigenImpl {
166 static void Compute(
167 const Device& d, typename TTypes<T>::Flat out,
168 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
169 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
170 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
171 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
172 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
173 }
174};
175
176// Add8p is like Add8 except the underlying implementation should +=
177// rather than assign to the output.
178template <typename Device, typename T>
179struct Add8pFunctor {
180 void operator()(
181 const Device& d, typename TTypes<T>::Flat out,
182 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
183 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
184 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
185 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
186};
187
188template <typename Device, typename T>
189struct Add8pEigenImpl {
190 static void Compute(
191 const Device& d, typename TTypes<T>::Flat out,
192 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
193 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
194 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
195 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
196 out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
197 }
198};
199
200template <typename Device, typename T>
201struct Add9Functor {
202 void operator()(
203 const Device& d, typename TTypes<T>::Flat out,
204 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
205 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
206 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
207 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
208 typename TTypes<T>::ConstFlat in9);
209};
210
211template <typename Device, typename T>
212struct Add9EigenImpl {
213 static void Compute(
214 const Device& d, typename TTypes<T>::Flat out,
215 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
216 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
217 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
218 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
219 typename TTypes<T>::ConstFlat in9) {
220 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
221 }
222};
223} // namespace functor
224} // namespace tensorflow
225
226#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
227