1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/tensor_types.h" |
21 | |
22 | #include "tensorflow/core/kernels/aggregate_ops.h" |
23 | |
24 | typedef Eigen::ThreadPoolDevice CPUDevice; |
25 | |
26 | |
27 | namespace tensorflow { |
28 | |
// Partial specializations for CPUDevice that use the Eigen implementations
// from AddNEigenImpl.
31 | namespace functor { |
32 | template <typename T> |
33 | struct Add2Functor<CPUDevice, T> { |
34 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
35 | typename TTypes<T>::ConstFlat in1, |
36 | typename TTypes<T>::ConstFlat in2) { |
37 | Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2); |
38 | } |
39 | }; |
40 | template <typename T> |
41 | struct Add3Functor<CPUDevice, T> { |
42 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
43 | typename TTypes<T>::ConstFlat in1, |
44 | typename TTypes<T>::ConstFlat in2, |
45 | typename TTypes<T>::ConstFlat in3) { |
46 | Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3); |
47 | } |
48 | }; |
49 | template <typename T> |
50 | struct Add4Functor<CPUDevice, T> { |
51 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
52 | typename TTypes<T>::ConstFlat in1, |
53 | typename TTypes<T>::ConstFlat in2, |
54 | typename TTypes<T>::ConstFlat in3, |
55 | typename TTypes<T>::ConstFlat in4) { |
56 | Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4); |
57 | } |
58 | }; |
59 | template <typename T> |
60 | struct Add5Functor<CPUDevice, T> { |
61 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
62 | typename TTypes<T>::ConstFlat in1, |
63 | typename TTypes<T>::ConstFlat in2, |
64 | typename TTypes<T>::ConstFlat in3, |
65 | typename TTypes<T>::ConstFlat in4, |
66 | typename TTypes<T>::ConstFlat in5) { |
67 | Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); |
68 | } |
69 | }; |
70 | template <typename T> |
71 | struct Add6Functor<CPUDevice, T> { |
72 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
73 | typename TTypes<T>::ConstFlat in1, |
74 | typename TTypes<T>::ConstFlat in2, |
75 | typename TTypes<T>::ConstFlat in3, |
76 | typename TTypes<T>::ConstFlat in4, |
77 | typename TTypes<T>::ConstFlat in5, |
78 | typename TTypes<T>::ConstFlat in6) { |
79 | Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); |
80 | } |
81 | }; |
82 | template <typename T> |
83 | struct Add7Functor<CPUDevice, T> { |
84 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, |
85 | typename TTypes<T>::ConstFlat in1, |
86 | typename TTypes<T>::ConstFlat in2, |
87 | typename TTypes<T>::ConstFlat in3, |
88 | typename TTypes<T>::ConstFlat in4, |
89 | typename TTypes<T>::ConstFlat in5, |
90 | typename TTypes<T>::ConstFlat in6, |
91 | typename TTypes<T>::ConstFlat in7) { |
92 | Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, |
93 | in7); |
94 | } |
95 | }; |
96 | |
97 | template <typename T> |
98 | struct Add8Functor<CPUDevice, T> { |
99 | void operator()( |
100 | const CPUDevice& d, typename TTypes<T>::Flat out, |
101 | typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, |
102 | typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, |
103 | typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, |
104 | typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { |
105 | Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, |
106 | in7, in8); |
107 | } |
108 | }; |
109 | |
110 | template <typename T> |
111 | struct Add8pFunctor<CPUDevice, T> { |
112 | void operator()( |
113 | const CPUDevice& d, typename TTypes<T>::Flat out, |
114 | typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, |
115 | typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, |
116 | typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, |
117 | typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { |
118 | Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, |
119 | in7, in8); |
120 | } |
121 | }; |
122 | |
123 | template <typename T> |
124 | struct Add9Functor<CPUDevice, T> { |
125 | void operator()( |
126 | const CPUDevice& d, typename TTypes<T>::Flat out, |
127 | typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, |
128 | typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, |
129 | typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, |
130 | typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, |
131 | typename TTypes<T>::ConstFlat in9) { |
132 | Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, |
133 | in7, in8, in9); |
134 | } |
135 | }; |
136 | |
137 | |
138 | } // namespace functor |
139 | |
140 | } // namespace tensorflow |
141 | |
142 | #endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ |
143 | |