/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <atomic>
#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace scatter_op {

enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};
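
// Illustrative sketch (not part of this header's API): Assign<op> is normally
// applied to Eigen "chip" expressions, i.e. one row of the params matrix and
// the matching row of the updates matrix. Assuming `params` is a
// TTypes<float>::Matrix view, `updates` a TTypes<float>::ConstMatrix view, and
// `u` a plain float:
//
//   // Adds updates row `i` into params row `index` in place.
//   Assign<scatter_op::UpdateOp::ADD>::Run(params.chip<0>(index),
//                                          updates.chip<0>(i));
//
//   // Broadcasts the scalar value `u` across params row `index`.
//   Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(params.chip<0>(index), u);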

}  // namespace internal
}  // namespace scatter_op

namespace functor {
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};
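
// Illustrative usage sketch (hedged; the real kernel wiring lives in the
// scatter op implementations, not in this header). A kernel would typically
// flatten params to a matrix, instantiate the functor for its device, element
// type, and update op, and treat a non-negative return value as the position
// of an out-of-range index. The names params_matrix, updates_matrix, and
// indices_flat below are placeholders:
//
//   functor::ScatterFunctor<CPUDevice, float, int32, scatter_op::UpdateOp::ADD>
//       scatter;
//   const int32 bad_i = scatter(context, context->eigen_device<CPUDevice>(),
//                               params_matrix, updates_matrix, indices_flat);
//   OP_REQUIRES(context, bad_i < 0,
//               errors::InvalidArgument("indices[", bad_i, "] is out of range"));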

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index ParallelExecute(OpKernelContext* c, const Device& d,
                        typename TTypes<T>::Matrix params,
                        typename TTypes<T>::ConstMatrix updates,
                        typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index kMaxLocks = 1024;
    const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
    // To reduce the number of locks and the memory usage, we divide the whole
    // index space into kMaxLocks regions, with each lock serializing access to
    // a region.
    mutex accessed[kMaxLocks];
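    // Worked example (illustrative numbers, not a requirement of the code):
    // with limit = 5000 rows and kMaxLocks = 1024, entries_per_lock =
    // (5000 + 1023) / 1024 = 5, so rows 0..4 share lock 0, rows 5..9 share
    // lock 1, and the largest possible lock_id is 4999 / 5 = 999 < kMaxLocks.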
    std::atomic<Index> bad_index(-1);
    auto ParallelScatter = [&](Index start, Index end) {
      for (Index i = start; i < end; ++i) {
        // Grab the index and check its validity. Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) {
          bad_index = i;
          return;
        }
        const Index lock_id = index / entries_per_lock;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        {
          mutex_lock l(accessed[lock_id]);
          scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                                updates.template chip<0>(i));
        }
      }
    };
    const float kMovingCost = 2.5f;
    float shard_cost = kMovingCost * params.dimension(1);
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
          ParallelScatter);  // TODO: Come up with a good cost estimate.
    return bad_index;
  }
  Index SerialExecute(OpKernelContext* c, const Device& d,
                      typename TTypes<T>::Matrix params,
                      typename TTypes<T>::ConstMatrix updates,
                      typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; ++i) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in
      // between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }

  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
    // The parallel version is significantly slower internally. Only call the
    // serial version for now.
    // TODO(penporn): Avoid locking in parallelization (sort beforehand).
    return SerialExecute(c, d, params, updates, indices);
#else
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index min_n_threshold = 1024;
    const Index ser_par_ratio = 10000;
    // To parallelize the updates, duplicate entries need to be handled
    // correctly: multiple updates to the same index have to be serialized.
    // This can lead to lock contention, which may nullify the benefits of
    // parallelization. Assuming a uniform random distribution of the indices,
    // we use a rough heuristic to decide whether to execute the updates
    // serially or in parallel. Also, when N is small, the overhead of parallel
    // execution outweighs its benefits, so we check the value of N as well.
    const bool execute_serial = N < min_n_threshold ||
                                (N / limit) > ser_par_ratio ||
                                OpDeterminismRequired();
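    // Example (illustrative numbers only): N = 1,000,000 updates into a table
    // with limit = 50 rows gives N / limit = 20,000 > ser_par_ratio, so each
    // row is hit ~20,000 times on average and lock contention would likely
    // erase any speedup; we run serially. Conversely, N = 100,000 updates
    // spread over limit = 1,000,000 rows (ratio 0) run in parallel, provided
    // determinism is not required.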
    if (execute_serial)
      return SerialExecute(c, d, params, updates, indices);
    else
      return ParallelExecute(c, d, params, updates, indices);
#endif  // PLATFORM_GOOGLE
  }
};

template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (int j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    if (!std::is_same<T, tstring>::value) {
      // Non-string element types can be copied row by row with memmove;
      // tstring needs element-wise assignment (the else branch below).
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity. Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity. Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};
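
// Illustrative usage sketch (hedged; mirrors the ScatterFunctor example above
// but broadcasts a single scalar update across every selected row). The names
// params_matrix, update_scalar, and indices_flat are placeholders:
//
//   functor::ScatterScalarFunctor<CPUDevice, float, int32,
//                                 scatter_op::UpdateOp::MIN>
//       scatter_scalar;
//   const int32 bad_i = scatter_scalar(context,
//                                      context->eigen_device<CPUDevice>(),
//                                      params_matrix, update_scalar,
//                                      indices_flat);
//   // As with ScatterFunctor, a non-negative result is the offending index.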

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_