1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ |
18 | |
19 | // Functor definitions for ScatterND ops, must be compilable by nvcc. |
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #include <atomic> |
24 | |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | |
27 | #include "tensorflow/core/framework/bounds_check.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/kernels/fill_functor.h" |
33 | #include "tensorflow/core/kernels/scatter_nd_op.h" |
34 | #include "tensorflow/core/platform/mutex.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | #include "tensorflow/core/util/util.h" |
37 | |
38 | namespace tensorflow { |
39 | |
40 | typedef Eigen::ThreadPoolDevice CPUDevice; |
41 | |
42 | class OpKernelContext; |
43 | |
44 | // Specialization of UpdateExecutor to CPU |
45 | namespace update_executor { |
46 | |
47 | template <typename T, typename Input, typename Update, typename Output, |
48 | scatter_nd_op::UpdateOp OP> |
49 | class UpdateExecutor { |
50 | public: |
51 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input value, |
52 | Update update, Output output); |
53 | }; |
54 | |
55 | template <typename T, typename Input, typename Update, typename Output> |
56 | class UpdateExecutor<T, Input, Update, Output, |
57 | scatter_nd_op::UpdateOp::ASSIGN> { |
58 | public: |
59 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, |
60 | Update update, Output output) { |
61 | output.device(device) = update; |
62 | } |
63 | }; |
64 | |
65 | template <typename T, typename Input, typename Update, typename Output> |
66 | class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> { |
67 | public: |
68 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, |
69 | Update update, Output output) { |
70 | output.device(device) += update; |
71 | } |
72 | }; |
73 | |
74 | template <typename T, typename Input, typename Update, typename Output> |
75 | class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> { |
76 | public: |
77 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, |
78 | Update update, Output output) { |
79 | output.device(device) -= update; |
80 | } |
81 | }; |
82 | |
83 | template <typename T, typename Input, typename Update, typename Output> |
84 | class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> { |
85 | public: |
86 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, |
87 | Update update, Output output) { |
88 | output.device(device) = output.cwiseMin(update); |
89 | } |
90 | }; |
91 | |
92 | template <typename T, typename Input, typename Update, typename Output> |
93 | class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> { |
94 | public: |
95 | EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, |
96 | Update update, Output output) { |
97 | output.device(device) = output.cwiseMax(update); |
98 | } |
99 | }; |
100 | |
101 | } // namespace update_executor |
102 | |
103 | namespace functor { |
104 | |
105 | // Implementation of update functor for CPU. |
106 | template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM> |
107 | struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> { |
108 | Index operator()( |
109 | const CPUDevice& d, const Index slice_size, |
110 | const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, |
111 | typename TTypes<T, 2>::Tensor Tparams, |
112 | typename TTypes<Index, 2>::ConstTensor Tindices, |
113 | typename TTypes<T, 2>::ConstTensor Tupdates, |
114 | typename TTypes<T, 2>::Tensor Toutput) { |
115 | // error_loc is -1 if there's no out-of-bounds index, |
116 | // otherwise it is the location of an OOB index in Tindices. |
117 | Index error_loc = -1; |
118 | |
119 | const Eigen::DenseIndex batch_size = Tindices.dimension(0); |
120 | |
121 | Index batch_strides[IXDIM]; |
122 | for (int dim = IXDIM - 1; dim >= 0; --dim) { |
123 | if (dim == IXDIM - 1) { |
124 | batch_strides[dim] = 1; |
125 | } else { |
126 | batch_strides[dim] = |
127 | batch_strides[dim + 1] * output_shape_prefix[dim + 1]; |
128 | } |
129 | } |
130 | |
131 | for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { |
132 | Index i = 0; |
133 | bool out_of_bounds = false; |
134 | for (int dim = 0; dim < IXDIM; ++dim) { |
135 | const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); |
136 | out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); |
137 | i += ix_d * batch_strides[dim]; |
138 | } |
139 | if (TF_PREDICT_FALSE(out_of_bounds)) { |
140 | error_loc = loc; |
141 | break; |
142 | } else { |
143 | auto input_chip = Toutput.template chip<0>(i); |
144 | auto output_chip = input_chip; |
145 | auto update_chip = Tupdates.template chip<0>(loc); |
146 | update_executor::UpdateExecutor< |
147 | CPUDevice, decltype(input_chip), decltype(update_chip), |
148 | decltype(output_chip), OP>::Execute(d, input_chip, update_chip, |
149 | output_chip); |
150 | } |
151 | } |
152 | |
153 | return error_loc; |
154 | } |
155 | }; |
156 | |
// Explicit instantiation of ScatterNdFunctor for one (type, index, op) triple
// at index depth CPU_PROVIDED_IXDIM (defined by the including .cc file).
#define REGISTER_SCATTER_ND_FULL(T, Index, op)                               \
  template Index                                                             \
  ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
      const CPUDevice& d, const Index slice_size,                            \
      const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>              \
          output_shape_prefix,                                               \
      typename TTypes<T, 2>::Tensor Tparams,                                 \
      typename TTypes<Index, 2>::ConstTensor Tindices,                       \
      typename TTypes<T, 2>::ConstTensor Tupdates,                           \
      typename TTypes<T, 2>::Tensor Toutput)

// Instantiates one op for both supported index types (int32 and int64).
#define REGISTER_SCATTER_ND_INDEX(type, op) \
  REGISTER_SCATTER_ND_FULL(type, int32, op); \
  REGISTER_SCATTER_ND_FULL(type, int64, op)

// ASSIGN is the only op needed by scatter_nd_update kernels.
#define REGISTER_SCATTER_ND_UPDATE(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);

// Arithmetic ops (ADD/SUB) for scatter_nd_add / scatter_nd_sub kernels.
#define REGISTER_SCATTER_ND_MATH(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);

// Comparison ops (MIN/MAX) for scatter_nd_min / scatter_nd_max kernels.
#define REGISTER_SCATTER_ND_MIN_MAX(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);

// ASSIGN works for every type (including strings and bools).
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
// ADD on tstring is string concatenation, registered separately since tstring
// is not part of the numeric type lists below.
REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
// MIN/MAX need an ordering, so only real (non-complex) number types.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
// For bool, ADD/SUB act as logical accumulate ops; registered explicitly.
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);

#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
194 | } // namespace functor |
195 | |
196 | } // namespace tensorflow |
197 | |
198 | #endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ |
199 | |