1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ |
18 | |
19 | #include <type_traits> |
20 | |
21 | #include "third_party/eigen3/Eigen/Core" |
22 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
23 | #include "tensorflow/core/framework/bounds_check.h" |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/variant_op_registry.h" |
27 | #include "tensorflow/core/kernels/dense_update_functor.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | #include "tensorflow/core/util/determinism.h" |
30 | #include "tensorflow/core/util/work_sharder.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | class OpKernelContext; |
35 | typedef Eigen::ThreadPoolDevice CPUDevice; |
36 | typedef Eigen::GpuDevice GPUDevice; |
37 | |
38 | namespace scatter_op { |
39 | |
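// Update operations supported by the scatter kernels: overwrite, add,
// subtract, multiply, divide, element-wise minimum, and element-wise maximum.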
40 | enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX }; |
41 | |
42 | namespace internal { |
43 | |
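// Assign<Op> applies a single update `u` to a parameter slice `p` according
// to Op. Run() takes a tensor-valued update; RunScalar() broadcasts a scalar
// update across the slice. The unspecialized template is intentionally empty
// so that unsupported ops fail to compile.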
44 | template <scatter_op::UpdateOp Op> |
45 | struct Assign {}; |
46 | template <> |
47 | struct Assign<scatter_op::UpdateOp::ASSIGN> { |
48 | template <typename Params, typename Update> |
49 | static void Run(Params p, Update u) { |
50 | p = u; |
51 | } |
52 | template <typename Params, typename Update> |
53 | static void RunScalar(Params p, Update u) { |
54 | p.setConstant(u); |
55 | } |
56 | }; |
57 | template <> |
58 | struct Assign<scatter_op::UpdateOp::ADD> { |
59 | template <typename Params, typename Update> |
60 | static void Run(Params p, Update u) { |
61 | p += u; |
62 | } |
63 | template <typename Params, typename Update> |
64 | static void RunScalar(Params p, Update u) { |
65 | p = p + u; |
66 | } |
67 | }; |
68 | template <> |
69 | struct Assign<scatter_op::UpdateOp::SUB> { |
70 | template <typename Params, typename Update> |
71 | static void Run(Params p, Update u) { |
72 | p -= u; |
73 | } |
74 | template <typename Params, typename Update> |
75 | static void RunScalar(Params p, Update u) { |
76 | p = p + static_cast<Update>(-u); |
77 | } |
78 | }; |
79 | template <> |
80 | struct Assign<scatter_op::UpdateOp::MUL> { |
81 | template <typename Params, typename Update> |
82 | static void Run(Params p, Update u) { |
83 | p *= u; |
84 | } |
85 | template <typename Params, typename Update> |
86 | static void RunScalar(Params p, Update u) { |
87 | p = p * u; |
88 | } |
89 | }; |
90 | template <> |
91 | struct Assign<scatter_op::UpdateOp::DIV> { |
92 | template <typename Params, typename Update> |
93 | static void Run(Params p, Update u) { |
94 | p /= u; |
95 | } |
96 | template <typename Params, typename Update> |
97 | static void RunScalar(Params p, Update u) { |
98 | p = p / u; |
99 | } |
100 | }; |
101 | template <> |
102 | struct Assign<scatter_op::UpdateOp::MIN> { |
103 | // This method requires that Params and Update are tensor types. |
104 | template <typename Params, typename Update> |
105 | static void Run(Params p, Update u) { |
106 | p = p.cwiseMin(u); |
107 | } |
  // Same as above, but for the case where Update is a scalar type.
109 | template <typename Params, typename Update> |
110 | static void RunScalar(Params p, Update u) { |
111 | p = p.cwiseMin(u); |
112 | } |
113 | }; |
114 | template <> |
115 | struct Assign<scatter_op::UpdateOp::MAX> { |
116 | template <typename Params, typename Update> |
117 | static void Run(Params p, Update u) { |
118 | p = p.cwiseMax(u); |
119 | } |
120 | template <typename Params, typename Update> |
121 | static void RunScalar(Params p, Update u) { |
122 | p = p.cwiseMax(u); |
123 | } |
124 | }; |
125 | |
126 | |
127 | } // namespace internal |
128 | } // namespace scatter_op |
129 | |
130 | namespace functor { |
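// Functor used by ScatterUpdate-style ops to do the computations. Applies
// `updates` to the rows of `params` selected by `indices` according to `op`,
// and returns -1 on success or the position in `indices` of an out-of-range
// index otherwise.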
131 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
132 | struct ScatterFunctor { |
133 | Index operator()(OpKernelContext* c, const Device& d, |
134 | typename TTypes<T>::Matrix params, |
135 | typename TTypes<T>::ConstMatrix updates, |
136 | typename TTypes<Index>::ConstFlat indices); |
137 | }; |
138 | |
139 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
140 | struct ScatterFunctorBase { |
141 | Index ParallelExecute(OpKernelContext* c, const Device& d, |
142 | typename TTypes<T>::Matrix params, |
143 | typename TTypes<T>::ConstMatrix updates, |
144 | typename TTypes<Index>::ConstFlat indices) { |
145 | const Index N = static_cast<Index>(indices.size()); |
146 | const Index limit = static_cast<Index>(params.dimension(0)); |
147 | const Index kMaxLocks = 1024; |
148 | const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks; |
149 | // To reduce the number of locks and the memory usage, we divide the whole |
150 | // index space into kMaxLocks regions with each lock serializing access to |
151 | // a region. |
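    // For example, with limit = 4096 rows and kMaxLocks = 1024 this gives
    // entries_per_lock = 4, so rows 0..3 contend on lock 0, rows 4..7 on
    // lock 1, and so on.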
152 | mutex accessed[kMaxLocks]; |
153 | std::atomic<Index> bad_index(-1); |
154 | auto ParallelScatter = [&](Index start, Index end) { |
155 | for (Index i = start; i < end; ++i) { |
156 | // Grab the index and check its validity. Do this carefully, |
157 | // to avoid checking the value and grabbing it again from |
158 | // memory a second time (a security risk since it may change in |
159 | // between). |
160 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
161 | if (!FastBoundsCheck(index, limit)) { |
162 | bad_index = i; |
163 | return; |
164 | } |
165 | const Index lock_id = index / entries_per_lock; |
166 | // Copy last Ndim-1 dimensions of updates[i] to params[index] |
167 | { |
168 | mutex_lock l(accessed[lock_id]); |
169 | scatter_op::internal::Assign<op>::Run(params.template chip<0>(index), |
170 | updates.template chip<0>(i)); |
171 | } |
172 | } |
173 | }; |
174 | const float kMovingCost = 2.5f; |
175 | float shard_cost = kMovingCost * params.dimension(1); |
176 | const DeviceBase::CpuWorkerThreads& worker_threads = |
177 | *(c->device()->tensorflow_cpu_worker_threads()); |
178 | Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost, |
179 | ParallelScatter); // TODO: Come up with a good cost estimate. |
180 | return bad_index; |
181 | } |
182 | Index SerialExecute(OpKernelContext* c, const Device& d, |
183 | typename TTypes<T>::Matrix params, |
184 | typename TTypes<T>::ConstMatrix updates, |
185 | typename TTypes<Index>::ConstFlat indices) { |
186 | const Index N = static_cast<Index>(indices.size()); |
187 | const Index limit = static_cast<Index>(params.dimension(0)); |
188 | for (Index i = 0; i < N; ++i) { |
189 | // Grab the index and check its validity. Do this carefully, |
190 | // to avoid checking the value and grabbing it again from |
191 | // memory a second time (a security risk since it may change in |
192 | // between). |
193 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
194 | if (!FastBoundsCheck(index, limit)) return i; |
195 | // Copy last Ndim-1 dimensions of updates[i] to params[index] |
196 | scatter_op::internal::Assign<op>::Run(params.template chip<0>(index), |
197 | updates.template chip<0>(i)); |
198 | } |
199 | return -1; |
200 | } |
201 | |
202 | Index operator()(OpKernelContext* c, const Device& d, |
203 | typename TTypes<T>::Matrix params, |
204 | typename TTypes<T>::ConstMatrix updates, |
205 | typename TTypes<Index>::ConstFlat indices) { |
206 | #ifdef PLATFORM_GOOGLE |
207 | // The parallel version is significantly slower internally. Only call the |
208 | // serial version for now. |
209 | // TODO(penporn): Avoid locking in parallelization (sort beforehand). |
210 | return SerialExecute(c, d, params, updates, indices); |
211 | #else |
212 | // indices and params sizes were validated in DoCompute(). |
213 | const Index N = static_cast<Index>(indices.size()); |
214 | const Index limit = static_cast<Index>(params.dimension(0)); |
215 | const Index min_n_threshold = 1024; |
216 | const Index ser_par_ratio = 10000; |
    // To parallelize the updates, duplicate indices need to be handled
    // correctly: multiple updates to the same row have to be serialized,
    // which can lead to lock contention that nullifies the benefit of
    // parallelization, and with repeated indices the order in which updates
    // are applied (and hence the result, for floating-point accumulation)
    // becomes nondeterministic. Assuming a uniform random distribution of
    // the indices, we use a rough heuristic to decide whether to execute the
    // updates serially or in parallel. If 'N' is small, the overhead of
    // parallel execution outweighs its benefit, so we also check the value
    // of N. Serial execution is also forced when op determinism is required.
224 | const bool execute_serial = N < min_n_threshold || |
225 | (N / limit) > ser_par_ratio || |
226 | OpDeterminismRequired(); |
227 | if (execute_serial) |
228 | return SerialExecute(c, d, params, updates, indices); |
229 | else |
230 | return ParallelExecute(c, d, params, updates, indices); |
231 | #endif // PLATFORM_GOOGLE |
232 | } |
233 | }; |
234 | |
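// Scatter-assign specialization for Variant values. Variant elements hold
// arbitrary C++ objects, so rows are copied element by element with
// operator= rather than with Eigen expressions or memmove.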
235 | template <typename Device, typename Index> |
236 | struct ScatterFunctorVariantAssignBase { |
237 | Index operator()(OpKernelContext* c, const Device& d, |
238 | typename TTypes<Variant>::Matrix params, |
239 | typename TTypes<Variant>::ConstMatrix updates, |
240 | typename TTypes<Index>::ConstFlat indices) { |
241 | // indices and params sizes were validated in DoCompute(). |
242 | const Index N = static_cast<Index>(indices.size()); |
243 | const Index limit = static_cast<Index>(params.dimension(0)); |
244 | const Index cols = static_cast<Index>(params.dimension(1)); |
245 | DCHECK_EQ(N, updates.dimension(0)); |
246 | DCHECK_EQ(cols, updates.dimension(1)); |
247 | for (Index i = 0; i < N; i++) { |
248 | // Grab the index and check its validity. Do this carefully, |
249 | // to avoid checking the value and grabbing it again from |
250 | // memory a second time (a security risk since it may change in between). |
251 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
252 | if (!FastBoundsCheck(index, limit)) return i; |
253 | // Copy last Ndim-1 dimensions of updates[i] to params[index] |
      for (Index j = 0; j < cols; ++j) {
255 | const Variant& to_scatter = updates(i, j); |
256 | params(index, j) = to_scatter; |
257 | } |
258 | } |
259 | return -1; |
260 | } |
261 | }; |
262 | |
263 | template <typename Index> |
264 | struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN> |
265 | : ScatterFunctorVariantAssignBase<CPUDevice, Index> {}; |
266 | |
267 | template <typename Index> |
268 | struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN> |
269 | : ScatterFunctorVariantAssignBase<GPUDevice, Index> {}; |
270 | |
271 | |
272 | template <typename T, typename Index> |
273 | struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> { |
274 | Index operator()(OpKernelContext* c, const CPUDevice& d, |
275 | typename TTypes<T>::Matrix params, |
276 | typename TTypes<T>::ConstMatrix updates, |
277 | typename TTypes<Index>::ConstFlat indices) { |
278 | // indices and params sizes were validated in DoCompute(). |
279 | const Index N = static_cast<Index>(indices.size()); |
280 | const Index limit = static_cast<Index>(params.dimension(0)); |
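    // Fast path: for every element type other than tstring a row can be
    // copied as raw bytes, so use memmove directly. tstring rows must be
    // assigned element by element, which the else branch below handles via
    // Eigen chips.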
281 | if (!std::is_same<T, tstring>::value) { |
282 | for (Index i = 0; i < N; i++) { |
283 | // Grab the index and check its validity. Do this carefully, |
284 | // to avoid checking the value and grabbing it again from |
285 | // memory a second time (a security risk since it may change in |
286 | // between). |
287 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
288 | if (!FastBoundsCheck(index, limit)) return i; |
289 | memmove(params.data() + index * params.dimension(1), |
290 | updates.data() + i * updates.dimension(1), |
291 | updates.dimension(1) * sizeof(T)); |
292 | } |
293 | } else { |
294 | for (Index i = 0; i < N; i++) { |
295 | // Grab the index and check its validity. Do this carefully, |
296 | // to avoid checking the value and grabbing it again from |
297 | // memory a second time (a security risk since it may change in |
298 | // between). |
299 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
300 | if (!FastBoundsCheck(index, limit)) return i; |
301 | // Copy last Ndim-1 dimensions of updates[i] to params[index] |
302 | scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run( |
303 | params.template chip<0>(index), updates.template chip<0>(i)); |
304 | } |
305 | } |
306 | return -1; |
307 | } |
308 | }; |
309 | |
310 | template <typename T, typename Index, scatter_op::UpdateOp op> |
311 | struct ScatterFunctor<CPUDevice, T, Index, op> |
312 | : ScatterFunctorBase<CPUDevice, T, Index, op> {}; |
313 | |
314 | |
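// Functor used by scatter ops when the update is a single scalar that is
// broadcast to every selected row of `params`. Follows the same return
// convention as ScatterFunctor: -1 on success, otherwise the position in
// `indices` of the first out-of-range index.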
315 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
316 | struct ScatterScalarFunctor { |
317 | Index operator()(OpKernelContext* c, const Device& d, |
318 | typename TTypes<T>::Matrix params, |
319 | const typename TTypes<T>::ConstScalar update, |
320 | typename TTypes<Index>::ConstFlat indices); |
321 | }; |
322 | |
323 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
324 | struct ScatterScalarFunctorBase { |
325 | Index operator()(OpKernelContext* c, const Device& d, |
326 | typename TTypes<T>::Matrix params, |
327 | const typename TTypes<T>::ConstScalar update, |
328 | typename TTypes<Index>::ConstFlat indices) { |
329 | // indices and params sizes were validated in DoCompute(). |
330 | const Index N = static_cast<Index>(indices.size()); |
331 | const Index limit = static_cast<Index>(params.dimension(0)); |
332 | for (Index i = 0; i < N; i++) { |
333 | // Grab the index and check its validity. Do this carefully, |
334 | // to avoid checking the value and grabbing it again from |
335 | // memory a second time (a security risk since it may change in between). |
336 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
337 | if (!FastBoundsCheck(index, limit)) return i; |
338 | // Broadcast update to params[index] |
339 | scatter_op::internal::Assign<op>::RunScalar( |
340 | params.template chip<0>(index), update()); |
341 | } |
342 | return -1; |
343 | } |
344 | }; |
345 | |
346 | template <typename Device, typename Index> |
347 | struct ScatterScalarFunctorVariantAssignBase { |
348 | Index operator()(OpKernelContext* c, const Device& d, |
349 | typename TTypes<Variant>::Matrix params, |
350 | const typename TTypes<Variant>::ConstScalar update, |
351 | typename TTypes<Index>::ConstFlat indices) { |
352 | // indices and params sizes were validated in DoCompute(). |
353 | const Index N = static_cast<Index>(indices.size()); |
354 | const Index limit = static_cast<Index>(params.dimension(0)); |
355 | const Index cols = static_cast<Index>(params.dimension(1)); |
356 | const Variant& to_scatter = update(); |
357 | for (Index i = 0; i < N; i++) { |
358 | // Grab the index and check its validity. Do this carefully, |
359 | // to avoid checking the value and grabbing it again from |
360 | // memory a second time (a security risk since it may change in between). |
361 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
362 | if (!FastBoundsCheck(index, limit)) return i; |
363 | // Broadcast update to params[index] |
364 | for (Index j = 0; j < cols; ++j) { |
365 | params(index, j) = to_scatter; |
366 | } |
367 | } |
368 | return -1; |
369 | } |
370 | }; |
371 | |
372 | template <typename Index> |
373 | struct ScatterScalarFunctor<CPUDevice, Variant, Index, |
374 | scatter_op::UpdateOp::ASSIGN> |
375 | : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {}; |
376 | template <typename Index> |
377 | struct ScatterScalarFunctor<GPUDevice, Variant, Index, |
378 | scatter_op::UpdateOp::ASSIGN> |
379 | : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {}; |
380 | |
381 | |
382 | template <typename T, typename Index> |
383 | struct ScatterScalarFunctorBase<CPUDevice, T, Index, |
384 | scatter_op::UpdateOp::ASSIGN> { |
385 | Index operator()(OpKernelContext* c, const CPUDevice& d, |
386 | typename TTypes<T>::Matrix params, |
387 | const typename TTypes<T>::ConstScalar update, |
388 | typename TTypes<Index>::ConstFlat indices) { |
389 | // indices and params sizes were validated in DoCompute(). |
390 | const Index N = static_cast<Index>(indices.size()); |
391 | const Index limit = static_cast<Index>(params.dimension(0)); |
392 | for (Index i = 0; i < N; i++) { |
393 | // Grab the index and check its validity. Do this carefully, |
394 | // to avoid checking the value and grabbing it again from |
395 | // memory a second time (a security risk since it may change in between). |
396 | const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); |
397 | if (!FastBoundsCheck(index, limit)) return i; |
398 | // Broadcast update to params[index] |
399 | scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar( |
400 | params.template chip<0>(index), update()); |
401 | } |
402 | return -1; |
403 | } |
404 | }; |
405 | |
406 | template <typename T, typename Index, scatter_op::UpdateOp op> |
407 | struct ScatterScalarFunctor<CPUDevice, T, Index, op> |
408 | : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {}; |
409 | |
410 | |
411 | } // namespace functor |
412 | } // namespace tensorflow |
413 | |
414 | #endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ |
415 | |