1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
17#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
18
19// Functor definitions for ScatterND ops, must be compilable by nvcc.
20
21#define EIGEN_USE_THREADS
22
23#include <atomic>
24
25#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26
27#include "tensorflow/core/framework/bounds_check.h"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/register_types.h"
30#include "tensorflow/core/framework/tensor.h"
31#include "tensorflow/core/framework/tensor_shape.h"
32#include "tensorflow/core/kernels/fill_functor.h"
33#include "tensorflow/core/kernels/scatter_nd_op.h"
34#include "tensorflow/core/platform/mutex.h"
35#include "tensorflow/core/platform/types.h"
36#include "tensorflow/core/util/util.h"
37
38namespace tensorflow {
39
40typedef Eigen::ThreadPoolDevice CPUDevice;
41
42class OpKernelContext;
43
// UpdateExecutor applies one scatter update operation (selected by
// scatter_nd_op::UpdateOp) to a single output slice on the given device.
45namespace update_executor {
46
// Primary template: applies the update operation OP to one output slice.
//
// Only the declaration of Execute is given here; the partial specializations
// for ASSIGN, ADD, SUB, MIN and MAX supply the definitions.
//
// Template parameters:
//   T      - device type; Execute receives it by const reference.
//   Input  - type of the `value` argument.
//   Update - type of the update expression.
//   Output - type of the destination expression.
//   OP     - the scatter_nd_op::UpdateOp to perform.
template <typename T, typename Input, typename Update, typename Output,
          scatter_nd_op::UpdateOp OP>
class UpdateExecutor {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input value,
                                          Update update, Output output);
};
54
55template <typename T, typename Input, typename Update, typename Output>
56class UpdateExecutor<T, Input, Update, Output,
57 scatter_nd_op::UpdateOp::ASSIGN> {
58 public:
59 EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
60 Update update, Output output) {
61 output.device(device) = update;
62 }
63};
64
65template <typename T, typename Input, typename Update, typename Output>
66class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
67 public:
68 EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
69 Update update, Output output) {
70 output.device(device) += update;
71 }
72};
73
74template <typename T, typename Input, typename Update, typename Output>
75class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
76 public:
77 EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
78 Update update, Output output) {
79 output.device(device) -= update;
80 }
81};
82
83template <typename T, typename Input, typename Update, typename Output>
84class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> {
85 public:
86 EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
87 Update update, Output output) {
88 output.device(device) = output.cwiseMin(update);
89 }
90};
91
92template <typename T, typename Input, typename Update, typename Output>
93class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
94 public:
95 EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
96 Update update, Output output) {
97 output.device(device) = output.cwiseMax(update);
98 }
99};
100
101} // namespace update_executor
102
103namespace functor {
104
105// Implementation of update functor for CPU.
106template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
107struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
108 Index operator()(
109 const CPUDevice& d, const Index slice_size,
110 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
111 typename TTypes<T, 2>::Tensor Tparams,
112 typename TTypes<Index, 2>::ConstTensor Tindices,
113 typename TTypes<T, 2>::ConstTensor Tupdates,
114 typename TTypes<T, 2>::Tensor Toutput) {
115 // error_loc is -1 if there's no out-of-bounds index,
116 // otherwise it is the location of an OOB index in Tindices.
117 Index error_loc = -1;
118
119 const Eigen::DenseIndex batch_size = Tindices.dimension(0);
120
121 Index batch_strides[IXDIM];
122 for (int dim = IXDIM - 1; dim >= 0; --dim) {
123 if (dim == IXDIM - 1) {
124 batch_strides[dim] = 1;
125 } else {
126 batch_strides[dim] =
127 batch_strides[dim + 1] * output_shape_prefix[dim + 1];
128 }
129 }
130
131 for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
132 Index i = 0;
133 bool out_of_bounds = false;
134 for (int dim = 0; dim < IXDIM; ++dim) {
135 const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
136 out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
137 i += ix_d * batch_strides[dim];
138 }
139 if (TF_PREDICT_FALSE(out_of_bounds)) {
140 error_loc = loc;
141 break;
142 } else {
143 auto input_chip = Toutput.template chip<0>(i);
144 auto output_chip = input_chip;
145 auto update_chip = Tupdates.template chip<0>(loc);
146 update_executor::UpdateExecutor<
147 CPUDevice, decltype(input_chip), decltype(update_chip),
148 decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
149 output_chip);
150 }
151 }
152
153 return error_loc;
154 }
155};
156
157#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
158 template Index \
159 ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
160 const CPUDevice& d, const Index slice_size, \
161 const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
162 output_shape_prefix, \
163 typename TTypes<T, 2>::Tensor Tparams, \
164 typename TTypes<Index, 2>::ConstTensor Tindices, \
165 typename TTypes<T, 2>::ConstTensor Tupdates, \
166 typename TTypes<T, 2>::Tensor Toutput)
167
168#define REGISTER_SCATTER_ND_INDEX(type, op) \
169 REGISTER_SCATTER_ND_FULL(type, int32, op); \
170 REGISTER_SCATTER_ND_FULL(type, int64, op)
171
172#define REGISTER_SCATTER_ND_UPDATE(type) \
173 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);
174
175#define REGISTER_SCATTER_ND_MATH(type) \
176 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
177 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);
178
179#define REGISTER_SCATTER_ND_MIN_MAX(type) \
180 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
181 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);
182
183TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
184REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
185TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
186TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
187TF_CALL_bool(REGISTER_SCATTER_ND_MATH);
188
189#undef REGISTER_SCATTER_ND_MATH
190#undef REGISTER_SCATTER_ND_MIN_MAX
191#undef REGISTER_SCATTER_ND_UPDATE
192#undef REGISTER_SCATTER_ND_INDEX
193#undef REGISTER_SCATTER_ND_FULL
194} // namespace functor
195
196} // namespace tensorflow
197
198#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
199