/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/kernels/scatter_nd_util.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Returns true if the three tensors have a valid number of elements.
// If shape_input has 0 elements, then indices and updates must also have
// exactly 0 elements; otherwise we should error. Likewise, if indices has
// 0 elements, then updates must also have 0 elements; otherwise we error.
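// For example (illustrative values): ValidEmptyOutputShape(0, 0, 0) and
// ValidEmptyOutputShape(8, 0, 0) return true, since no updates are requested,
// while ValidEmptyOutputShape(0, 2, 2) returns false because updates were
// specified for an empty output.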
bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices,
                           int64_t num_updates) {
  if (num_indices == 0 && num_updates == 0) {
    return true;  // covers both cases, regardless of whether num_inputs is 0
  }
  // Otherwise all three tensors must be non-empty.
  return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
}

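// ScatterNdOp implements the stateless ScatterNd op: it allocates a
// zero-initialized output of the requested shape and scatter-adds `updates`
// at the positions given by `indices`.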
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& updates = c->input(1);
    const Tensor& shape_input = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    auto vec = shape_input.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(c,
                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape_input.NumElements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

    const int64_t outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(
          c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
          errors::InvalidArgument(
              "Dimensions [0,", outer_dims,
              ") of indices[shape=", indices.shape().DebugString(),
              "] must match dimensions [0,", outer_dims,
              ") of updates[shape=", updates.shape().DebugString(), "]"));
    }

    const int64_t ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix,
                errors::InvalidArgument(
                    "Dimensions [", ix, ",", shape.dims(), ") of input[shape=",
                    shape.DebugString(), "] must match dimensions [",
                    outer_dims, ",", updates.shape().dims(),
                    ") of updates[shape=", updates.shape().DebugString(), "]"));

    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(),
                                  ") of input[shape=", shape.DebugString(),
                                  "] must match dimensions [", outer_dims, ",",
                                  updates.shape().dims(), ") of updates[shape=",
                                  updates.shape().DebugString(), "]"));
    }
    OP_REQUIRES(c, shape_input.dims() == 1,
                errors::InvalidArgument("Shape must be a vector"));

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
               c, indices, updates, shape, &out, true /*allocate*/));
    c->set_output(0, out);
  }
};

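// TensorScatterOp implements the TensorScatter{Update,Add,Sub,Min,Max} ops:
// the scatter is applied to a copy of the dense input tensor (or directly to
// the forwarded input buffer when the runtime allows buffer reuse), rather
// than to a variable.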
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class TensorScatterOp : public OpKernel {
 public:
  explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(0);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    TensorShape shape = input.shape();

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape.num_elements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

    const int64_t outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
                  errors::InvalidArgument(
                      "Outer dimensions of indices and update must match. "
                      "Indices shape: ",
                      indices.shape().DebugString(),
                      ", updates shape:", updates.shape().DebugString()));
    }

    const int64_t ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(
        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
        errors::InvalidArgument("Inner dimensions of output shape must match "
                                "inner dimensions of updates shape. Output: ",
                                shape.DebugString(),
                                " updates: ", updates.shape().DebugString()));
    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument(
              "The inner ", shape.dims() - ix,
              " dimensions of output.shape=", shape.DebugString(),
              " must match the inner ", updates.shape().dims() - outer_dims,
              " dimensions of updates.shape=", updates.shape().DebugString()));
    }

    AllocatorAttributes alloc_attr;
    MemoryType memory_type = DEVICE_MEMORY;
    if (std::is_same<Device, CPUDevice>::value) {
      alloc_attr.set_on_host(true);
      memory_type = HOST_MEMORY;
    } else {
      memory_type = DEVICE_MEMORY;
    }
    std::unique_ptr<Tensor> forwarded_input =
        c->forward_input(0, 0, input.dtype(), shape, memory_type, alloc_attr);

    if (forwarded_input == nullptr) {
      // We were not able to forward the input, so we deep copy the tensor and
      // set the output.
      Tensor* out;
      OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out));

      OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
                                                    input, out));
      OP_REQUIRES_OK(c,
                     functor::DoScatterNd<Device, T, Index, op>(
                         c, indices, updates, shape, out, false /*allocate*/));
    } else {
      // Output forwarded, so simply perform the scatter.
      OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
                            c, indices, updates, shape, forwarded_input.get(),
                            false /*allocate*/));

      c->set_output(0, *forwarded_input);
    }
  }
};

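// ScatterNdUpdateOp implements ScatterNd{Update,Add,Sub,Min,Max},
// ScatterNdNonAliasingAdd, and the ResourceScatterNd* variants: updates are
// applied to a ref- or resource-held variable (or to a possibly forwarded
// dense input for the non-aliasing form), optionally under an exclusive lock.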
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    dtype_ = c->input_type(0);
    // If we are updating a resource, we always use the exclusive lock.
    // For ref types, we lock based on the use_locking parameter.
    // Otherwise, we don't mutate the input tensor (we copy-on-write if needed).
    if (c->input_type(0) == DT_RESOURCE) {
      // TODO(apassos): what to validate here?
    } else if (IsRefType(c->input_type(0))) {
      OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else {
      OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    if (dtype_ == DT_RESOURCE) {
      core::RefCountPtr<Var> v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
      mutex_lock m(*v->mu());
      DoCompute(c);
    } else if (use_exclusive_lock_) {
      // If we're here, it means the input type is a ref.
      DCHECK(IsRefType(c->input_dtype(0)));
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  DataType dtype_;
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    Tensor params;
    TensorShape params_shape;

    if (dtype_ == DT_RESOURCE) {
      core::RefCountPtr<Var> v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      Tensor* t = v->tensor();
      params = *t;
      params_shape = params.shape();
    } else if (IsRefType(c->input_dtype(0))) {
      params = c->mutable_input(0, use_exclusive_lock_);
      params_shape = params.shape();
      c->forward_ref_input_to_ref_output(0, 0);
      OP_REQUIRES(c, params.IsInitialized(),
                  errors::FailedPrecondition("Null ref for params"));
    } else {
      Tensor* params_ptr;
      params_shape = c->input(0).shape();
      if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
                                                 &params_ptr)) {
        // We weren't able to forward the input to output, so just
        // allocate a new output tensor and copy the values over.
        OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
        params = *params_ptr;
        functor::DenseUpdate<Device, T, ASSIGN> copy;
        const Tensor& input_copy = c->input(0);
        copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
      } else {
        params = *params_ptr;
      }
    }

    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, op>(
               c, indices, updates, params_shape, &params, false /*allocate*/));
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(type) \
  template Status functor::DoScatterNd<GPUDevice, type, int64, \
                                        scatter_nd_op::UpdateOp::ASSIGN>( \
      OpKernelContext*, Tensor const&, Tensor const&, TensorShape const&, \
      Tensor*, bool);

// Explicitly instantiate DoScatterNd for template arguments which are used
// by the CSRSparseMatrixToDense op.
REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(float)
REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(double)
REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(complex64)
REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(complex128)

#undef REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

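// Kernel registrations. The *_INT32_GPU macros below register int32 kernels
// on DEVICE_DEFAULT with every tensor argument pinned to host memory, so that
// int32 scatter ops placed on an accelerator are executed by the CPU
// implementation (TensorFlow conventionally keeps int32 tensors in host
// memory on GPU devices).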
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
  REGISTER_KERNEL_BUILDER(Name(name) \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("shape"), \
                          ScatterNdOp<dev##Device, type, index_type>)

#define REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(index_type, name) \
  REGISTER_KERNEL_BUILDER(Name(name) \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("shape") \
                              .HostMemory("output"), \
                          ScatterNdOp<CPUDevice, int32, index_type>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
                                                op) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_##dev) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, name, \
                                                          op) \
  REGISTER_KERNEL_BUILDER(Name(name) \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("ref") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output_ref"), \
                          ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)

#define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU( \
    index_type, name, op) \
  REGISTER_KERNEL_BUILDER(Name(name) \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("input") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
                                                         dev, name, op) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_##dev) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices") \
          .HostMemory("ref"), \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, \
                                                                   name, op) \
  REGISTER_KERNEL_BUILDER(Name(name) \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("ref") \
                              .HostMemory("indices") \
                              .HostMemory("updates"), \
                          ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)

#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64_t, dev, name)

#define REGISTER_SCATTER_ND_KERNEL_INT32_GPU(name) \
  REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int32, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int64_t, name)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)

#define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU(name, op) \
  REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, \
                                                                 op); \
  REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, \
                                                                 name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
                                                   op); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)

#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
                                    scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
                                    scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
                                    scatter_nd_op::UpdateOp::SUB); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
      type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
      type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU() \
  REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU( \
      "ScatterNdNonAliasingAdd", scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdAdd", \
                                              scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdSub", \
                                              scatter_nd_op::UpdateOp::SUB); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND(type, dev) \
  REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");

#define REGISTER_SCATTER_ND_INT32_GPU() \
  REGISTER_SCATTER_ND_KERNEL_INT32_GPU("ScatterNd");

#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
                                    scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
      type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);

#define REGISTER_SCATTER_ND_UPDATE_INT32_GPU() \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);

#define REGISTER_SCATTER_ND_MIN_MAX(type, dev) \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax", \
                                    scatter_nd_op::UpdateOp::MAX); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin", \
                                    scatter_nd_op::UpdateOp::MIN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
      type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
      type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);

#define REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU() \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMax", \
                                              scatter_nd_op::UpdateOp::MAX); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMin", \
                                              scatter_nd_op::UpdateOp::MIN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
      "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, CPU);

#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, CPU);

#define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \
  REGISTER_SCATTER_ND_MIN_MAX(type, CPU);

#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_tstring(REGISTER_SCATTER_ND_CPU);
TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU);

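// Registrations for the TensorScatter* kernels, which scatter into a copy of
// a dense input tensor rather than into a variable.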
#define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
                                                          dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          TensorScatterOp<dev##Device, type, index_type, \
                                          scatter_nd_op::UpdateOp::ASSIGN>)

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(index_type) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("tensor") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          TensorScatterOp<CPUDevice, int32, index_type, \
                                          scatter_nd_op::UpdateOp::ASSIGN>)

#define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          TensorScatterOp<dev##Device, type, index_type, \
                                          scatter_nd_op::UpdateOp::ADD>)

#define REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(index_type) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("tensor") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          TensorScatterOp<CPUDevice, int32, index_type, \
                                          scatter_nd_op::UpdateOp::ADD>)

#define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          TensorScatterOp<dev##Device, type, index_type, \
                                          scatter_nd_op::UpdateOp::SUB>)

#define REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(index_type) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("tensor") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          TensorScatterOp<CPUDevice, int32, index_type, \
                                          scatter_nd_op::UpdateOp::SUB>)

#define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          TensorScatterOp<dev##Device, type, index_type, \
                                          scatter_nd_op::UpdateOp::MIN>)

#define REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(index_type) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("tensor") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          TensorScatterOp<CPUDevice, int32, index_type, \
                                          scatter_nd_op::UpdateOp::MIN>)

#define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
                              .Device(DEVICE_##dev) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          TensorScatterOp<dev##Device, type, index_type, \
                                          scatter_nd_op::UpdateOp::MAX>)

#define REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(index_type) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
                              .Device(DEVICE_DEFAULT) \
                              .TypeConstraint<int32>("T") \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("tensor") \
                              .HostMemory("indices") \
                              .HostMemory("updates") \
                              .HostMemory("output"), \
                          TensorScatterOp<CPUDevice, int32, index_type, \
                                          scatter_nd_op::UpdateOp::MAX>)

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, CPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, CPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, CPU);

#define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, CPU);

#define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, CPU);

#define REGISTER_SCATTER_ND_TENSOR_CPU(type) \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
  REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type); \
  REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type);

// Register TensorScatterUpdate/Add/Sub for all number types.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
// Register min/max operations only for Real number types
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU);
// Register only TensorScatterUpdate for string/bool types as well.
TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);

#undef REGISTER_SCATTER_ND_TENSOR_CPU

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, GPU);

#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, GPU);

#define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
  REGISTER_SCATTER_ND_MIN_MAX(type, GPU);

#define REGISTER_SCATTER_ND_ALL_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \
  REGISTER_SCATTER_ND_UPDATE_GPU(type); \
  REGISTER_SCATTER_ND_GPU(type);

#define REGISTER_SCATTER_ND_ALL_INT32_GPU() \
  REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU(); \
  REGISTER_SCATTER_ND_UPDATE_INT32_GPU(); \
  REGISTER_SCATTER_ND_INT32_GPU();

REGISTER_SCATTER_ND_ALL_INT32_GPU();
REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU();

TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ND_MIN_MAX_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_ALL_GPU);

#undef REGISTER_SCATTER_ND_ALL_GPU

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, GPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, GPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, GPU);

#define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, GPU);

#define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, GPU);

#define REGISTER_SCATTER_ND_TENSOR_GPU(type) \
  REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
  REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);

#define REGISTER_SCATTER_ND_TENSOR_INT32_GPU() \
  REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int32); \
  REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int64_t); \
  REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int32); \
  REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int64_t); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int32); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int64_t);

#define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \
  REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type); \
  REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type);

#define REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU() \
  REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int32); \
  REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int64_t); \
  REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int32); \
  REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int64_t);

REGISTER_SCATTER_ND_TENSOR_INT32_GPU();
REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU();

TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU);

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_MIN_MAX_CPU
#undef REGISTER_SCATTER_ND_MIN_MAX_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
#undef REGISTER_SCATTER_ND_TENSOR_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_CPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
#undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
#undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU
#undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU
#undef REGISTER_SCATTER_ND_TENSOR_INT32_GPU
#undef REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU
#undef REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU
#undef REGISTER_SCATTER_ND_ALL_INT32_GPU
#undef REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU
#undef REGISTER_SCATTER_ND_INT32_GPU
#undef REGISTER_SCATTER_ND_UPDATE_INT32_GPU
#undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
#undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
#undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
#undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
#undef REGISTER_SCATTER_ND_KERNEL_INT32_GPU
#undef REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

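// Validates the shapes of params/indices/updates and computes the quantities
// used by the flattened scatter: `slice_dim` (the index depth, i.e. the size
// of the last dimension of `indices`), `num_updates` (the number of index
// tuples), and `slice_size` (the number of elements written per update).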
template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
                                const Tensor& indices, const Tensor& updates,
                                int64_t* slice_dim, Index* num_updates,
                                Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
    return errors::InvalidArgument("Output must be at least 1-D, ",
                                   "got shape: ", params_shape.DebugString());
  }

  if (!ValidEmptyOutputShape(params_shape.num_elements(),
                             indices_shape.num_elements(),
                             updates_shape.num_elements())) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output. indices shape: ",
        indices.shape().DebugString());
  }

  if (updates.dim_size(0) != indices.dim_size(0)) {
    return errors::InvalidArgument(
        "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(),
        "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
        "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
  }
  TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(),
                                                  updates.shape()));

  // Check that we have enough index space.
  const int64_t N_big = indices.NumElements();
  if (N_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("indices has too many elements for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", N_big, " > ",
                                   std::numeric_limits<Index>::max());
  }
  if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params_shape[0] too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params_shape.dim_size(0),
                                   " > ", std::numeric_limits<Index>::max());
  }

  // Calculate the index depth: the size of the last dimension of indices
  // (or 1 if indices is a vector).
  *slice_dim = (indices_shape.dims() > 1)
                   ? indices_shape.dim_size(indices_shape.dims() - 1)
                   : 1;

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
  Index total_nd = params_shape.dims();

  int64_t slice_size_big = 1;
  for (int64_t i = *slice_dim; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  *slice_size = static_cast<Index>(slice_size_big);

  const int64_t safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
  *num_updates = indices_shape.num_elements() / safe_slice_dim;

  return OkStatus();
}

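// Flattens `indices` by collapsing all but its last dimension, producing the
// rank-2 view consumed by the scatter functors.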
template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

namespace {

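// Shared implementation behind DoScatterNd: validates the inputs, optionally
// allocates and zero-fills the output, and dispatches to the ScatterNdFunctor
// specialization matching the index depth (indices.shape[-1]).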
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
                       const Tensor& updates, const TensorShape& shape,
                       Tensor* out, bool allocate) {
  int64_t slice_dim;
  Index num_updates;
  Index slice_size;
  TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
      shape, indices, updates, &slice_dim, &num_updates, &slice_size));

  IndexFlattener<Device, Index> index_flattener;
  auto indices_flat = index_flattener(c, indices);
  auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

  if (allocate) {
    AllocatorAttributes alloc_attr;
    if (std::is_same<Device, CPUDevice>::value) {
      alloc_attr.set_on_host(true);
    }
    TF_RETURN_IF_ERROR(
        c->allocate_temp(DataTypeToEnum<T>::value, shape, out, alloc_attr));
  } else {
    CHECK_NOTNULL(out);
  }

  if (shape.num_elements() == 0) {
    return OkStatus();
  }

  if (allocate) {
    // Brand new tensor, zero it out.
    functor::SetZeroFunctor<Device, T> fill;
    fill(c->eigen_device<Device>(), out->flat<T>());
  }
  auto output_matrix =
      out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});

  Index bad_i = -1;

  if (shape.num_elements() > 0) {
    switch (slice_dim) {
#define PARAMS_CASE(IXDIM)                                                  \
  case IXDIM: {                                                             \
    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
    for (int i = 0; i < IXDIM; ++i) {                                       \
      output_shape_prefix[i] = shape.dim_size(i);                           \
    }                                                                       \
    functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
    bad_i =                                                                 \
        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
                output_matrix, indices_flat, updates_flat, output_matrix);  \
  } break
      // TODO(simister): Re-enable this once binary size is under control.
      // PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported. Requested rank: ",
            slice_dim);
    }
  }
  if (bad_i >= 0) {
    auto slice_shape = indices.shape();
    slice_shape.RemoveLastDims(1);
    return errors::InvalidArgument(
        "indices", SliceDebugString(slice_shape, bad_i), " = [",
        absl::StrJoin(
            gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
        "] does not index into shape ", shape.DebugString());
  }
  return OkStatus();
}

template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
                        const Tensor& updates, const TensorShape& shape,
                        Tensor* out, bool allocate);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Copies inputs to the CPU, runs DoScatterNd on the CPU, then copies output
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
                        const Tensor& updates, const TensorShape& shape,
                        Tensor* out, bool allocate) {
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  auto stream = c->op_device_context()->stream();

  // Copy 'indices' to host.
  Tensor host_indices;
  TF_RETURN_IF_ERROR(c->allocate_temp(indices.dtype(), indices.shape(),
                                      &host_indices, alloc_attr));
  se::DeviceMemoryBase indices_ptr(
      const_cast<Tensor&>(indices).flat<Index>().data(),
      indices.flat<Index>().size() * sizeof(Index));
  stream->ThenMemcpy(host_indices.flat<Index>().data(), indices_ptr,
                     indices.NumElements() * sizeof(Index));
  if (!stream) {
    return errors::Internal("Failed to copy indices to host");
  }

  // Copy 'updates' to host.
  Tensor host_updates;
  TF_RETURN_IF_ERROR(c->allocate_temp(updates.dtype(), updates.shape(),
                                      &host_updates, alloc_attr));
  se::DeviceMemoryBase updates_ptr(
      const_cast<Tensor&>(updates).flat<T>().data(),
      updates.flat<T>().size() * sizeof(T));
  stream->ThenMemcpy(host_updates.flat<T>().data(), updates_ptr,
                     updates.NumElements() * sizeof(T));
  if (!stream) {
    return errors::Internal("Failed to copy updates to host");
  }

  // Create 'out' on host, copying from device if 'allocate' is false.
  Tensor host_out;
  TF_RETURN_IF_ERROR(
      c->allocate_temp(updates.dtype(), shape, &host_out, alloc_attr));
  if (allocate) {
    TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
    functor::SetZeroFunctor<CPUDevice, T> fill;
    fill(c->eigen_device<CPUDevice>(), host_out.flat<T>());
  } else {
    CHECK_NOTNULL(out);  // Crash OK
    se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
                                 out->flat<T>().size() * sizeof(T));
    stream->ThenMemcpy(host_out.flat<T>().data(), out_ptr,
                       host_out.NumElements() * sizeof(T));
    if (!stream) {
      return errors::Internal("Failed to copy output to host");
    }
  }

  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
      c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));

  // Copy 'host_out' to device.
  se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
                               out->flat<T>().size() * sizeof(T));
  stream->ThenMemcpy(&out_ptr, host_out.flat<T>().data(),
                     host_out.NumElements() * sizeof(T));
  if (!stream) {
    return errors::Internal("Failed to copy output to device");
  }
  // Block host, since 'host_out' cannot be destructed until the copy is done.
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return OkStatus();
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace

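// Entry point used by the kernels above. On GPU builds it falls back to the
// CPU path when op determinism is requested, or when T is an integral type
// (the GPU implementation relies on atomics that are not supported for all
// integer types); otherwise it dispatches DoScatterNdImpl on the requested
// device.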
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape, Tensor* out,
                   bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  if (std::is_same<Device, GPUDevice>::value &&
      tensorflow::OpDeterminismRequired()) {
    return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
                                          allocate);
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  // Run on the CPU for integer types, since the GPU implementation uses
  // atomics, which are not supported for all integer types.
  if constexpr (std::is_same<Device, GPUDevice>::value &&
                std::is_integral<T>::value) {
    return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
                                          allocate);
  } else {
    return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
                                                 out, allocate);
  }
}
}  // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)             \
  template <>                                                             \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(     \
      const GPUDevice& d, const Index slice_size,                         \
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,   \
      typename TTypes<T, 2>::Tensor Tparams,                              \
      typename TTypes<Index, 2>::ConstTensor Tindices,                    \
      typename TTypes<T, 2>::ConstTensor Tupdates,                        \
      typename TTypes<T, 2>::Tensor Toutput);                             \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index) \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index) \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX)

#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64_t)

#define DECLARE_GPU_SPECS_MIN_MAX(T) \
  DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \
  DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64_t)

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS_MIN_MAX
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX_MIN_MAX
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow