1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/state_ops.cc. |
17 | #define EIGEN_USE_THREADS |
18 | |
19 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
20 | #define EIGEN_USE_GPU |
21 | #include "tensorflow/core/platform/stream_executor.h" |
22 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
23 | |
24 | #include "tensorflow/core/framework/bounds_check.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/tensor_shape.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/kernels/dense_update_functor.h" |
31 | #include "tensorflow/core/kernels/fill_functor.h" |
32 | #include "tensorflow/core/kernels/inplace_ops_functor.h" |
33 | #include "tensorflow/core/kernels/scatter_nd_op.h" |
34 | #include "tensorflow/core/kernels/scatter_nd_util.h" |
35 | #include "tensorflow/core/kernels/training_op_helpers.h" |
36 | #include "tensorflow/core/kernels/variable_ops.h" |
37 | #include "tensorflow/core/lib/strings/str_util.h" |
38 | #include "tensorflow/core/platform/mutex.h" |
39 | #include "tensorflow/core/platform/types.h" |
40 | #include "tensorflow/core/util/determinism.h" |
41 | #include "tensorflow/core/util/util.h" |
42 | |
43 | namespace tensorflow { |
44 | |
45 | typedef Eigen::ThreadPoolDevice CPUDevice; |
46 | typedef Eigen::GpuDevice GPUDevice; |
47 | |
// Returns true if the three tensors have a valid number of elements.
// If shape_input has 0 elements, then indices and updates must also have
// exactly 0 elements; otherwise we should error. Likewise, if indices has 0
// elements, then updates must also have 0 elements.
52 | bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices, |
53 | int64_t num_updates) { |
  if (num_indices == 0 && num_updates == 0) {
    return true;  // Covers both cases, regardless of whether num_inputs == 0.
  }
  // Otherwise, all three tensors must be non-empty.
  return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
59 | } |
60 | |
61 | template <typename Device, typename T, typename Index> |
62 | class ScatterNdOp : public OpKernel { |
63 | public: |
64 | explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) { |
65 | const DataType dt = DataTypeToEnum<T>::v(); |
66 | const DataType index_t = DataTypeToEnum<Index>::v(); |
67 | OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt})); |
68 | } |
69 | |
70 | void Compute(OpKernelContext* c) override { |
71 | const Tensor& indices = c->input(0); |
72 | const Tensor& updates = c->input(1); |
73 | const Tensor& shape_input = c->input(2); |
74 | |
75 | OP_REQUIRES(c, indices.shape().dims() >= 1, |
76 | errors::InvalidArgument( |
77 | "Indices shape must have rank at least one. Found:" , |
78 | indices.shape().DebugString())); |
79 | OP_REQUIRES(c, updates.shape().dims() >= 1, |
80 | errors::InvalidArgument( |
81 | "Updates shape must have rank at least one. Found:" , |
82 | updates.shape().DebugString())); |
83 | |
84 | auto vec = shape_input.flat<Index>(); |
85 | TensorShape shape; |
86 | OP_REQUIRES_OK(c, |
87 | TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape)); |
88 | |
89 | OP_REQUIRES(c, |
90 | ValidEmptyOutputShape(shape_input.NumElements(), |
91 | indices.shape().num_elements(), |
92 | updates.shape().num_elements()), |
93 | errors::InvalidArgument( |
94 | "Indices and updates specified for empty output shape" )); |
95 | |
96 | const int64_t outer_dims = indices.shape().dims() - 1; |
97 | |
98 | for (int i = 0; i < outer_dims; ++i) { |
99 | OP_REQUIRES( |
100 | c, indices.shape().dim_size(i) == updates.shape().dim_size(i), |
101 | errors::InvalidArgument( |
102 | "Dimensions [0," , outer_dims, |
103 | ") of indices[shape=" , indices.shape().DebugString(), |
104 | "] must match dimensions [0," , outer_dims, |
105 | ") of updates[shape=" , updates.shape().DebugString(), "]" )); |
106 | } |
107 | |
108 | const int64_t ix = indices.shape().dim_size(outer_dims); |
109 | OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix, |
110 | errors::InvalidArgument( |
111 | "Dimensions [" , ix, "," , shape.dims(), ") of input[shape=" , |
112 | shape.DebugString(), "] must match dimensions [" , |
113 | outer_dims, "," , updates.shape().dims(), |
114 | ") of updates[shape=" , updates.shape().DebugString(), "]" )); |
115 | |
116 | for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { |
117 | OP_REQUIRES( |
118 | c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), |
119 | errors::InvalidArgument("Dimensions [" , ix, "," , shape.dims(), |
120 | ") of input[shape=" , shape.DebugString(), |
121 | "] must match dimensions [" , outer_dims, "," , |
122 | updates.shape().dims(), ") of updates[shape=" , |
123 | updates.shape().DebugString(), "]" )); |
124 | } |
125 | OP_REQUIRES(c, shape_input.dims() == 1, |
126 | errors::InvalidArgument("Shape must be a vector" )); |
127 | |
128 | Tensor out; |
129 | OP_REQUIRES_OK( |
130 | c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>( |
131 | c, indices, updates, shape, &out, true /*allocate*/)); |
132 | c->set_output(0, out); |
133 | } |
134 | }; |
135 | |
136 | template <typename Device, typename T, typename Index, |
137 | scatter_nd_op::UpdateOp op> |
138 | class TensorScatterOp : public OpKernel { |
139 | public: |
140 | explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) { |
141 | const DataType dt = DataTypeToEnum<T>::v(); |
142 | const DataType index_t = DataTypeToEnum<Index>::v(); |
143 | OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt})); |
144 | } |
145 | |
146 | void Compute(OpKernelContext* c) override { |
147 | const Tensor& input = c->input(0); |
148 | const Tensor& indices = c->input(1); |
149 | const Tensor& updates = c->input(2); |
150 | |
151 | OP_REQUIRES(c, indices.shape().dims() >= 1, |
152 | errors::InvalidArgument( |
153 | "Indices shape must have rank at least one. Found:" , |
154 | indices.shape().DebugString())); |
155 | OP_REQUIRES(c, updates.shape().dims() >= 1, |
156 | errors::InvalidArgument( |
157 | "Updates shape must have rank at least one. Found:" , |
158 | updates.shape().DebugString())); |
159 | |
160 | TensorShape shape = input.shape(); |
161 | |
162 | OP_REQUIRES(c, |
163 | ValidEmptyOutputShape(shape.num_elements(), |
164 | indices.shape().num_elements(), |
165 | updates.shape().num_elements()), |
166 | errors::InvalidArgument( |
167 | "Indices and updates specified for empty output shape" )); |
168 | |
169 | const int64_t outer_dims = indices.shape().dims() - 1; |
170 | |
171 | for (int i = 0; i < outer_dims; ++i) { |
172 | OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i), |
173 | errors::InvalidArgument( |
174 | "Outer dimensions of indices and update must match. " |
175 | "Indices shape: " , |
176 | indices.shape().DebugString(), |
177 | ", updates shape:" , updates.shape().DebugString())); |
178 | } |
179 | |
180 | const int64_t ix = indices.shape().dim_size(outer_dims); |
181 | OP_REQUIRES( |
182 | c, updates.shape().dims() - outer_dims == shape.dims() - ix, |
183 | errors::InvalidArgument("Inner dimensions of output shape must match " |
184 | "inner dimensions of updates shape. Output: " , |
185 | shape.DebugString(), |
186 | " updates: " , updates.shape().DebugString())); |
187 | for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { |
188 | OP_REQUIRES( |
189 | c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), |
190 | errors::InvalidArgument( |
191 | "The inner " , shape.dims() - ix, |
192 | " dimensions of output.shape=" , shape.DebugString(), |
193 | " must match the inner " , updates.shape().dims() - outer_dims, |
194 | " dimensions of updates.shape=" , updates.shape().DebugString())); |
195 | } |
196 | |
    // Try to run the scatter in place: forward the input buffer to the output
    // when the runtime can prove the input is not referenced elsewhere. The
    // requested memory type must match where this kernel keeps its tensors
    // (host memory for the CPU device).
    AllocatorAttributes alloc_attr;
    MemoryType memory_type = DEVICE_MEMORY;
    if (std::is_same<Device, CPUDevice>::value) {
      alloc_attr.set_on_host(true);
      memory_type = HOST_MEMORY;
    }
205 | std::unique_ptr<Tensor> forwarded_input = |
206 | c->forward_input(0, 0, input.dtype(), shape, memory_type, alloc_attr); |
207 | |
208 | if (forwarded_input == nullptr) { |
209 | // We were not able to forward the input, so we deep copy the tensor and |
210 | // set the output. |
211 | Tensor* out; |
212 | OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out)); |
213 | |
214 | OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(), |
215 | input, out)); |
216 | OP_REQUIRES_OK(c, |
217 | functor::DoScatterNd<Device, T, Index, op>( |
218 | c, indices, updates, shape, out, false /*allocate*/)); |
219 | } else { |
220 | // Output forwarded, so simply perform the scatter. |
221 | OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>( |
222 | c, indices, updates, shape, forwarded_input.get(), |
223 | false /*allocate*/)); |
224 | |
225 | c->set_output(0, *forwarded_input); |
226 | } |
227 | } |
228 | }; |
229 | |
230 | template <typename Device, typename T, typename Index, |
231 | scatter_nd_op::UpdateOp op> |
232 | class ScatterNdUpdateOp : public OpKernel { |
233 | public: |
234 | explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) { |
235 | const DataType dt = DataTypeToEnum<T>::v(); |
236 | const DataType dt_ref = DataTypeToEnum<T>::ref(); |
237 | const DataType index_t = DataTypeToEnum<Index>::v(); |
238 | dtype_ = c->input_type(0); |
    // If we are updating a resource, we always use the exclusive lock.
    // For ref types, we lock based on the use_locking attribute.
    // Otherwise, we don't mutate the input tensor (we copy-on-write if needed).
242 | if (c->input_type(0) == DT_RESOURCE) { |
243 | // TODO(apassos): what to validate here? |
244 | } else if (IsRefType(c->input_type(0))) { |
245 | OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref})); |
246 | OP_REQUIRES_OK(c, c->GetAttr("use_locking" , &use_exclusive_lock_)); |
247 | } else { |
248 | OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt})); |
249 | use_exclusive_lock_ = false; |
250 | } |
251 | } |
252 | |
253 | void Compute(OpKernelContext* c) override { |
254 | if (dtype_ == DT_RESOURCE) { |
255 | core::RefCountPtr<Var> v; |
256 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
257 | OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); |
258 | mutex_lock m(*v->mu()); |
259 | DoCompute(c); |
260 | } else if (use_exclusive_lock_) { |
261 | // If we're here, it means the input type is a ref. |
262 | DCHECK(IsRefType(c->input_dtype(0))); |
263 | // Hold mutex while we apply updates |
264 | mutex_lock l(*c->input_ref_mutex(0)); |
265 | DoCompute(c); |
266 | } else { |
267 | DoCompute(c); |
268 | } |
269 | } |
270 | |
271 | private: |
  DataType dtype_;
  // Only meaningful for ref-type inputs; resource variables always lock the
  // variable's mutex in Compute(), and dense inputs never lock.
  bool use_exclusive_lock_ = false;
274 | |
275 | void DoCompute(OpKernelContext* c) { |
276 | const Tensor& indices = c->input(1); |
277 | const Tensor& updates = c->input(2); |
278 | Tensor params; |
279 | TensorShape params_shape; |
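
    // Three input flavors are handled below: resource variables (DT_RESOURCE),
    // legacy ref tensors, and plain dense tensors. A dense input is forwarded
    // to the output when possible; otherwise it is deep-copied first
    // (copy-on-write).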
280 | |
281 | if (dtype_ == DT_RESOURCE) { |
282 | core::RefCountPtr<Var> v; |
283 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
284 | Tensor* t = v->tensor(); |
285 | params = *t; |
286 | params_shape = params.shape(); |
287 | } else if (IsRefType(c->input_dtype(0))) { |
288 | params = c->mutable_input(0, use_exclusive_lock_); |
289 | params_shape = params.shape(); |
290 | c->forward_ref_input_to_ref_output(0, 0); |
291 | OP_REQUIRES(c, params.IsInitialized(), |
292 | errors::FailedPrecondition("Null ref for params" )); |
293 | } else { |
294 | Tensor* params_ptr; |
295 | params_shape = c->input(0).shape(); |
296 | if (!c->forward_input_to_output_with_shape(0, 0, params_shape, |
297 | ¶ms_ptr)) { |
298 | // We weren't able to forward the input to output, so just |
299 | // allocate a new output tensor and copy the values over. |
300 | OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, ¶ms_ptr)); |
301 | params = *params_ptr; |
302 | functor::DenseUpdate<Device, T, ASSIGN> copy; |
303 | const Tensor& input_copy = c->input(0); |
304 | copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>()); |
305 | } else { |
306 | params = *params_ptr; |
307 | } |
308 | } |
309 | |
310 | OP_REQUIRES_OK( |
311 | c, functor::DoScatterNd<Device, T, Index, op>( |
312 | c, indices, updates, params_shape, ¶ms, false /*allocate*/)); |
313 | } |
314 | }; |
315 | |
316 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
317 | |
318 | #define REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(type) \ |
319 | template Status functor::DoScatterNd<GPUDevice, type, int64, \ |
320 | scatter_nd_op::UpdateOp::ASSIGN>( \ |
321 | OpKernelContext*, Tensor const&, Tensor const&, TensorShape const&, \ |
322 | Tensor*, bool); |
323 | |
324 | // Explicitly instantiate DoScatterNd for template arguments which are used |
325 | // by the CSRSparseMatrixToDense op. |
326 | REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(float) |
327 | REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(double) |
328 | REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(complex64) |
329 | REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU(complex128) |
330 | |
331 | #undef REGISTER_SCATTER_ND_ASSIGN_FUNCTION_GPU |
332 | |
333 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
334 | |
335 | #define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \ |
336 | REGISTER_KERNEL_BUILDER(Name(name) \ |
337 | .Device(DEVICE_##dev) \ |
338 | .TypeConstraint<type>("T") \ |
339 | .TypeConstraint<index_type>("Tindices") \ |
340 | .HostMemory("shape"), \ |
341 | ScatterNdOp<dev##Device, type, index_type>) |
342 | |
343 | #define REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(index_type, name) \ |
344 | REGISTER_KERNEL_BUILDER(Name(name) \ |
345 | .Device(DEVICE_DEFAULT) \ |
346 | .TypeConstraint<int32>("T") \ |
347 | .TypeConstraint<index_type>("Tindices") \ |
348 | .HostMemory("indices") \ |
349 | .HostMemory("updates") \ |
350 | .HostMemory("shape") \ |
351 | .HostMemory("output"), \ |
352 | ScatterNdOp<CPUDevice, int32, index_type>) |
353 | |
354 | #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \ |
355 | op) \ |
356 | REGISTER_KERNEL_BUILDER( \ |
357 | Name(name) \ |
358 | .Device(DEVICE_##dev) \ |
359 | .TypeConstraint<type>("T") \ |
360 | .TypeConstraint<index_type>("Tindices"), \ |
361 | ScatterNdUpdateOp<dev##Device, type, index_type, op>) |
362 | |
363 | #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, name, \ |
364 | op) \ |
365 | REGISTER_KERNEL_BUILDER(Name(name) \ |
366 | .Device(DEVICE_DEFAULT) \ |
367 | .TypeConstraint<int32>("T") \ |
368 | .TypeConstraint<index_type>("Tindices") \ |
369 | .HostMemory("ref") \ |
370 | .HostMemory("indices") \ |
371 | .HostMemory("updates") \ |
372 | .HostMemory("output_ref"), \ |
373 | ScatterNdUpdateOp<CPUDevice, int32, index_type, op>) |
374 | |
375 | #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU( \ |
376 | index_type, name, op) \ |
377 | REGISTER_KERNEL_BUILDER(Name(name) \ |
378 | .Device(DEVICE_DEFAULT) \ |
379 | .TypeConstraint<int32>("T") \ |
380 | .TypeConstraint<index_type>("Tindices") \ |
381 | .HostMemory("input") \ |
382 | .HostMemory("indices") \ |
383 | .HostMemory("updates") \ |
384 | .HostMemory("output"), \ |
385 | ScatterNdUpdateOp<CPUDevice, int32, index_type, op>) |
386 | |
387 | #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \ |
388 | dev, name, op) \ |
389 | REGISTER_KERNEL_BUILDER( \ |
390 | Name(name) \ |
391 | .Device(DEVICE_##dev) \ |
392 | .TypeConstraint<type>("T") \ |
393 | .TypeConstraint<index_type>("Tindices") \ |
394 | .HostMemory("ref"), \ |
395 | ScatterNdUpdateOp<dev##Device, type, index_type, op>) |
396 | |
397 | #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, \ |
398 | name, op) \ |
399 | REGISTER_KERNEL_BUILDER(Name(name) \ |
400 | .Device(DEVICE_DEFAULT) \ |
401 | .TypeConstraint<int32>("T") \ |
402 | .TypeConstraint<index_type>("Tindices") \ |
403 | .HostMemory("ref") \ |
404 | .HostMemory("indices") \ |
405 | .HostMemory("updates"), \ |
406 | ScatterNdUpdateOp<CPUDevice, int32, index_type, op>) |
407 | |
408 | #define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \ |
409 | REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \ |
410 | REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64_t, dev, name) |
411 | |
412 | #define REGISTER_SCATTER_ND_KERNEL_INT32_GPU(name) \ |
413 | REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int32, name); \ |
414 | REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int64_t, name) |
415 | |
416 | #define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \ |
417 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \ |
418 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op) |
419 | |
420 | #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \ |
421 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \ |
422 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op) |
423 | |
424 | #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU(name, op) \ |
425 | REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, \ |
426 | op); \ |
427 | REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, \ |
428 | name, op) |
429 | |
430 | #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \ |
431 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \ |
432 | op); \ |
433 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op) |
434 | |
435 | #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \ |
436 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \ |
437 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op) |
438 | |
439 | #define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \ |
440 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ |
441 | scatter_nd_op::UpdateOp::ADD); \ |
442 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \ |
443 | scatter_nd_op::UpdateOp::ADD); \ |
444 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \ |
445 | scatter_nd_op::UpdateOp::SUB); \ |
446 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ |
447 | type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \ |
448 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ |
449 | type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB); |
450 | |
451 | #define REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU() \ |
452 | REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU( \ |
453 | "ScatterNdNonAliasingAdd", scatter_nd_op::UpdateOp::ADD); \ |
454 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdAdd", \ |
455 | scatter_nd_op::UpdateOp::ADD); \ |
456 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdSub", \ |
457 | scatter_nd_op::UpdateOp::SUB); \ |
458 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
459 | "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \ |
460 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
461 | "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB); |
462 | |
463 | #define REGISTER_SCATTER_ND(type, dev) \ |
464 | REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd"); |
465 | |
466 | #define REGISTER_SCATTER_ND_INT32_GPU() \ |
467 | REGISTER_SCATTER_ND_KERNEL_INT32_GPU("ScatterNd"); |
468 | |
469 | #define REGISTER_SCATTER_ND_UPDATE(type, dev) \ |
470 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \ |
471 | scatter_nd_op::UpdateOp::ASSIGN); \ |
472 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ |
473 | type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); |
474 | |
475 | #define REGISTER_SCATTER_ND_UPDATE_INT32_GPU() \ |
476 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
477 | "ScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); \ |
478 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
479 | "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); |
480 | |
481 | #define REGISTER_SCATTER_ND_MIN_MAX(type, dev) \ |
482 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax", \ |
483 | scatter_nd_op::UpdateOp::MAX); \ |
484 | REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin", \ |
485 | scatter_nd_op::UpdateOp::MIN); \ |
486 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ |
487 | type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \ |
488 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \ |
489 | type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX); |
490 | |
491 | #define REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU() \ |
492 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMax", \ |
493 | scatter_nd_op::UpdateOp::MAX); \ |
494 | REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMin", \ |
495 | scatter_nd_op::UpdateOp::MIN); \ |
496 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
497 | "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \ |
498 | REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \ |
499 | "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX); |
500 | |
501 | // Registers CPU kernels. |
502 | #define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \ |
503 | REGISTER_SCATTER_ND_ADD_SUB(type, CPU); |
504 | |
505 | #define REGISTER_SCATTER_ND_UPDATE_CPU(type) \ |
506 | REGISTER_SCATTER_ND_UPDATE(type, CPU); |
507 | |
508 | #define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \ |
509 | REGISTER_SCATTER_ND_MIN_MAX(type, CPU); |
510 | |
511 | #define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU); |
512 | #define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU); |
513 | |
514 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU); |
515 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU); |
516 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU); |
517 | TF_CALL_tstring(REGISTER_SCATTER_ND_CPU); |
518 | TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU); |
519 | TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU); |
520 | TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU); |
521 | TF_CALL_bool(REGISTER_SCATTER_ND_CPU); |
522 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU); |
523 | |
524 | #define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \ |
525 | dev) \ |
526 | REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \ |
527 | .Device(DEVICE_##dev) \ |
528 | .TypeConstraint<type>("T") \ |
529 | .TypeConstraint<index_type>("Tindices"), \ |
530 | TensorScatterOp<dev##Device, type, index_type, \ |
531 | scatter_nd_op::UpdateOp::ASSIGN>) |
532 | |
533 | #define REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(index_type) \ |
534 | REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \ |
535 | .Device(DEVICE_DEFAULT) \ |
536 | .TypeConstraint<int32>("T") \ |
537 | .TypeConstraint<index_type>("Tindices") \ |
538 | .HostMemory("tensor") \ |
539 | .HostMemory("indices") \ |
540 | .HostMemory("updates") \ |
541 | .HostMemory("output"), \ |
542 | TensorScatterOp<CPUDevice, int32, index_type, \ |
543 | scatter_nd_op::UpdateOp::ASSIGN>) |
544 | |
545 | #define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \ |
546 | REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \ |
547 | .Device(DEVICE_##dev) \ |
548 | .TypeConstraint<type>("T") \ |
549 | .TypeConstraint<index_type>("Tindices"), \ |
550 | TensorScatterOp<dev##Device, type, index_type, \ |
551 | scatter_nd_op::UpdateOp::ADD>) |
552 | |
553 | #define REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(index_type) \ |
554 | REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \ |
555 | .Device(DEVICE_DEFAULT) \ |
556 | .TypeConstraint<int32>("T") \ |
557 | .TypeConstraint<index_type>("Tindices") \ |
558 | .HostMemory("tensor") \ |
559 | .HostMemory("indices") \ |
560 | .HostMemory("updates") \ |
561 | .HostMemory("output"), \ |
562 | TensorScatterOp<CPUDevice, int32, index_type, \ |
563 | scatter_nd_op::UpdateOp::ADD>) |
564 | |
565 | #define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \ |
566 | REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \ |
567 | .Device(DEVICE_##dev) \ |
568 | .TypeConstraint<type>("T") \ |
569 | .TypeConstraint<index_type>("Tindices"), \ |
570 | TensorScatterOp<dev##Device, type, index_type, \ |
571 | scatter_nd_op::UpdateOp::SUB>) |
572 | |
573 | #define REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(index_type) \ |
574 | REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \ |
575 | .Device(DEVICE_DEFAULT) \ |
576 | .TypeConstraint<int32>("T") \ |
577 | .TypeConstraint<index_type>("Tindices") \ |
578 | .HostMemory("tensor") \ |
579 | .HostMemory("indices") \ |
580 | .HostMemory("updates") \ |
581 | .HostMemory("output"), \ |
582 | TensorScatterOp<CPUDevice, int32, index_type, \ |
583 | scatter_nd_op::UpdateOp::SUB>) |
584 | |
585 | #define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \ |
586 | REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \ |
587 | .Device(DEVICE_##dev) \ |
588 | .TypeConstraint<type>("T") \ |
589 | .TypeConstraint<index_type>("Tindices"), \ |
590 | TensorScatterOp<dev##Device, type, index_type, \ |
591 | scatter_nd_op::UpdateOp::MIN>) |
592 | |
593 | #define REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(index_type) \ |
594 | REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \ |
595 | .Device(DEVICE_DEFAULT) \ |
596 | .TypeConstraint<int32>("T") \ |
597 | .TypeConstraint<index_type>("Tindices") \ |
598 | .HostMemory("tensor") \ |
599 | .HostMemory("indices") \ |
600 | .HostMemory("updates") \ |
601 | .HostMemory("output"), \ |
602 | TensorScatterOp<CPUDevice, int32, index_type, \ |
603 | scatter_nd_op::UpdateOp::MIN>) |
604 | |
605 | #define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \ |
606 | REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \ |
607 | .Device(DEVICE_##dev) \ |
608 | .TypeConstraint<type>("T") \ |
609 | .TypeConstraint<index_type>("Tindices"), \ |
610 | TensorScatterOp<dev##Device, type, index_type, \ |
611 | scatter_nd_op::UpdateOp::MAX>) |
612 | |
613 | #define REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(index_type) \ |
614 | REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \ |
615 | .Device(DEVICE_DEFAULT) \ |
616 | .TypeConstraint<int32>("T") \ |
617 | .TypeConstraint<index_type>("Tindices") \ |
618 | .HostMemory("tensor") \ |
619 | .HostMemory("indices") \ |
620 | .HostMemory("updates") \ |
621 | .HostMemory("output"), \ |
622 | TensorScatterOp<CPUDevice, int32, index_type, \ |
623 | scatter_nd_op::UpdateOp::MAX>) |
624 | |
625 | #define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type) \ |
626 | REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \ |
627 | REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, CPU); |
628 | |
629 | #define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type) \ |
630 | REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \ |
631 | REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, CPU); |
632 | |
633 | #define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type) \ |
634 | REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \ |
635 | REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, CPU); |
636 | |
637 | #define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type) \ |
638 | REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \ |
639 | REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, CPU); |
640 | |
641 | #define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type) \ |
642 | REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \ |
643 | REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, CPU); |
644 | |
645 | #define REGISTER_SCATTER_ND_TENSOR_CPU(type) \ |
646 | REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \ |
647 | REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type); \ |
648 | REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type); |
649 | |
650 | // Register TensorScatterUpdate/Add/Sub for all number types. |
651 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU); |
652 | // Register min/max operations only for Real number types |
653 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU); |
654 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU); |
655 | // Register only TensorScatterUpdate for string/bool types as well. |
656 | TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU); |
657 | TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU); |
658 | |
659 | #undef REGISTER_SCATTER_ND_TENSOR_CPU |
660 | |
661 | // Registers GPU kernels. |
662 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
663 | |
664 | #define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \ |
665 | REGISTER_SCATTER_ND_ADD_SUB(type, GPU); |
666 | |
667 | #define REGISTER_SCATTER_ND_UPDATE_GPU(type) \ |
668 | REGISTER_SCATTER_ND_UPDATE(type, GPU); |
669 | |
670 | #define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \ |
671 | REGISTER_SCATTER_ND_MIN_MAX(type, GPU); |
672 | |
673 | #define REGISTER_SCATTER_ND_ALL_GPU(type) \ |
674 | REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \ |
675 | REGISTER_SCATTER_ND_UPDATE_GPU(type); \ |
676 | REGISTER_SCATTER_ND_GPU(type); |
677 | |
678 | #define REGISTER_SCATTER_ND_ALL_INT32_GPU() \ |
679 | REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU(); \ |
680 | REGISTER_SCATTER_ND_UPDATE_INT32_GPU(); \ |
681 | REGISTER_SCATTER_ND_INT32_GPU(); |
682 | |
683 | REGISTER_SCATTER_ND_ALL_INT32_GPU(); |
684 | REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU(); |
685 | |
686 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ND_ALL_GPU); |
687 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ND_MIN_MAX_GPU); |
688 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU); |
689 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU); |
690 | TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_ALL_GPU); |
691 | |
692 | #undef REGISTER_SCATTER_ND_ALL_GPU |
693 | |
694 | #define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type) \ |
695 | REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \ |
696 | REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, GPU); |
697 | |
698 | #define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type) \ |
699 | REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \ |
700 | REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, GPU); |
701 | |
702 | #define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type) \ |
703 | REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \ |
704 | REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, GPU); |
705 | |
706 | #define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type) \ |
707 | REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \ |
708 | REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, GPU); |
709 | |
710 | #define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type) \ |
711 | REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \ |
712 | REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, GPU); |
713 | |
714 | #define REGISTER_SCATTER_ND_TENSOR_GPU(type) \ |
715 | REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type); \ |
716 | REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \ |
717 | REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type); |
718 | |
719 | #define REGISTER_SCATTER_ND_TENSOR_INT32_GPU() \ |
720 | REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int32); \ |
721 | REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int64_t); \ |
722 | REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int32); \ |
723 | REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int64_t); \ |
724 | REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int32); \ |
725 | REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int64_t); |
726 | |
727 | #define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \ |
728 | REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type); \ |
729 | REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type); |
730 | |
731 | #define REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU() \ |
732 | REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int32); \ |
733 | REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int64_t); \ |
734 | REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int32); \ |
735 | REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int64_t); |
736 | |
737 | REGISTER_SCATTER_ND_TENSOR_INT32_GPU(); |
738 | REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU(); |
739 | |
740 | TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU); |
741 | TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX); |
742 | TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU); |
743 | TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX); |
744 | TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU); |
745 | |
747 | #undef REGISTER_SCATTER_ND_ADD_SUB |
748 | #undef REGISTER_SCATTER_ND_ADD_SUB_CPU |
749 | #undef REGISTER_SCATTER_ND_ADD_SUB_GPU |
750 | #undef REGISTER_SCATTER_ND_MIN_MAX |
751 | #undef REGISTER_SCATTER_ND_MIN_MAX_CPU |
752 | #undef REGISTER_SCATTER_ND_MIN_MAX_GPU |
753 | #undef REGISTER_SCATTER_ND_UPDATE |
754 | #undef REGISTER_SCATTER_ND_UPDATE_CPU |
755 | #undef REGISTER_SCATTER_ND_UPDATE_GPU |
756 | #undef REGISTER_SCATTER_ND_KERNEL |
757 | #undef REGISTER_SCATTER_ND_KERNEL_INDEX |
760 | #undef REGISTER_SCATTER_ND_TENSOR_GPU |
761 | #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE |
762 | #undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE |
763 | #undef REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE |
764 | #undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE |
765 | #undef REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE |
766 | #undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE |
767 | #undef REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE |
768 | #undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE |
769 | #undef REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE |
770 | #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU |
771 | #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE |
772 | #undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU |
773 | #undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU |
774 | #undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU |
775 | #undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU |
777 | #undef REGISTER_SCATTER_ND_TENSOR_INT32_GPU |
778 | #undef REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU |
779 | #undef REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU |
780 | #undef REGISTER_SCATTER_ND_ALL_INT32_GPU |
781 | #undef REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU |
782 | #undef REGISTER_SCATTER_ND_INT32_GPU |
783 | #undef REGISTER_SCATTER_ND_UPDATE_INT32_GPU |
784 | #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU |
785 | #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU |
786 | #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU |
787 | #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU |
788 | #undef REGISTER_SCATTER_ND_KERNEL_INT32_GPU |
789 | #undef REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU |
790 | |
791 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
792 | |
793 | namespace functor { |
794 | |
795 | template <typename Index> |
796 | Status PrepareAndValidateInputs(const TensorShape& params_shape, |
797 | const Tensor& indices, const Tensor& updates, |
798 | int64_t* slice_dim, Index* num_updates, |
799 | Index* slice_size) { |
800 | const TensorShape& indices_shape(indices.shape()); |
801 | const TensorShape& updates_shape(updates.shape()); |
802 | |
803 | if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) { |
804 | return errors::InvalidArgument("Output must be at least 1-D, " , |
805 | "got shape: " , params_shape.DebugString()); |
806 | } |
807 | |
808 | if (!ValidEmptyOutputShape(params_shape.num_elements(), |
809 | indices_shape.num_elements(), |
810 | updates_shape.num_elements())) { |
811 | return errors::InvalidArgument( |
812 | "Indices and updates specified for empty output. indices shape: " , |
813 | indices.shape().DebugString()); |
814 | } |
815 | |
816 | if (updates.dim_size(0) != indices.dim_size(0)) { |
817 | return errors::InvalidArgument( |
818 | "Dimensions [0,1) of indices[shape=" , indices_shape.DebugString(), |
819 | "] = " , indices.dim_size(0), " must match dimensions [0,1) of updates[" , |
820 | "shape=" , updates_shape.DebugString(), "] = " , updates.dim_size(0)); |
821 | } |
822 | TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(), |
823 | updates.shape())); |
824 | |
825 | // Check that we have enough index space |
826 | const int64_t N_big = indices.NumElements(); |
827 | if (N_big > std::numeric_limits<Index>::max()) { |
828 | return errors::InvalidArgument("indices has too many elements for " , |
829 | DataTypeString(DataTypeToEnum<Index>::v()), |
830 | " indexing: " , N_big, " > " , |
831 | std::numeric_limits<Index>::max()); |
832 | } |
833 | if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) { |
834 | return errors::InvalidArgument("params_shape[0] too large for " , |
835 | DataTypeString(DataTypeToEnum<Index>::v()), |
836 | " indexing: " , params_shape.dim_size(0), |
837 | " > " , std::numeric_limits<Index>::max()); |
838 | } |
839 | |
  // Calculate the dimensionality of each index row: the size of the last
  // dimension of indices (or 1 if indices is a vector).
841 | *slice_dim = (indices_shape.dims() > 1) |
842 | ? indices_shape.dim_size(indices_shape.dims() - 1) |
843 | : 1; |
844 | |
845 | // Calculate the number of elements that make up each slice of our updated |
846 | // tensor. This allows us to work with flattened tensors and copy over whole |
847 | // slices at a time. |
848 | Index total_nd = params_shape.dims(); |
849 | |
850 | int64_t slice_size_big = 1; |
851 | for (int64_t i = *slice_dim; i < total_nd; ++i) { |
852 | slice_size_big *= params_shape.dim_size(i); |
853 | } |
854 | |
855 | if (slice_size_big > std::numeric_limits<Index>::max()) { |
856 | return errors::InvalidArgument( |
857 | "slice size is too large for indexing: " , slice_size_big, " > " , |
858 | std::numeric_limits<Index>::max()); |
859 | } |
860 | |
861 | *slice_size = static_cast<Index>(slice_size_big); |
862 | |
863 | const int64_t safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim; |
864 | *num_updates = indices_shape.num_elements() / safe_slice_dim; |
865 | |
866 | return OkStatus(); |
867 | } |
868 | |
869 | template <typename Device, typename Index> |
870 | class IndexFlattener { |
871 | public: |
872 | inline typename TTypes<Index, 2>::ConstTensor operator()( |
873 | OpKernelContext*, const Tensor& indices) { |
874 | return indices.flat_inner_dims<Index>(); |
875 | } |
876 | }; |
877 | |
878 | namespace { |
879 | |
880 | template <typename Device, typename T, typename Index, |
881 | scatter_nd_op::UpdateOp Op> |
882 | Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, |
883 | const Tensor& updates, const TensorShape& shape, |
884 | Tensor* out, bool allocate) { |
885 | int64_t slice_dim; |
886 | Index num_updates; |
887 | Index slice_size; |
888 | TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>( |
889 | shape, indices, updates, &slice_dim, &num_updates, &slice_size)); |
890 | |
891 | IndexFlattener<Device, Index> index_flattener; |
892 | auto indices_flat = index_flattener(c, indices); |
893 | auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size}); |
894 | |
895 | if (allocate) { |
896 | AllocatorAttributes alloc_attr; |
897 | if (std::is_same<Device, CPUDevice>::value) { |
898 | alloc_attr.set_on_host(true); |
899 | } |
900 | TF_RETURN_IF_ERROR( |
901 | c->allocate_temp(DataTypeToEnum<T>::value, shape, out, alloc_attr)); |
902 | } else { |
903 | CHECK_NOTNULL(out); |
904 | } |
905 | |
906 | if (shape.num_elements() == 0) { |
907 | return OkStatus(); |
908 | } |
909 | |
910 | if (allocate) { |
911 | // Brand new tensor, zero it out. |
912 | functor::SetZeroFunctor<Device, T> fill; |
913 | fill(c->eigen_device<Device>(), out->flat<T>()); |
914 | } |
915 | auto output_matrix = |
916 | out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size}); |
917 | |
918 | Index bad_i = -1; |
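
  // Dispatch on slice_dim so each case instantiates a ScatterNdFunctor with a
  // compile-time index rank (IXDIM). The functor returns the row of the first
  // out-of-range index, or -1 on success; bad_i is checked below.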
919 | |
920 | if (shape.num_elements() > 0) { |
921 | switch (slice_dim) { |
922 | #define PARAMS_CASE(IXDIM) \ |
923 | case IXDIM: { \ |
924 | typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \ |
925 | for (int i = 0; i < IXDIM; ++i) { \ |
926 | output_shape_prefix[i] = shape.dim_size(i); \ |
927 | } \ |
928 | functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \ |
929 | bad_i = \ |
930 | functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \ |
931 | output_matrix, indices_flat, updates_flat, output_matrix); \ |
932 | } break |
933 | // TODO(simister): Re-enable this once binary size is under control. |
934 | // PARAMS_CASE(0); |
935 | PARAMS_CASE(1); |
936 | PARAMS_CASE(2); |
937 | PARAMS_CASE(3); |
938 | PARAMS_CASE(4); |
939 | PARAMS_CASE(5); |
940 | PARAMS_CASE(6); |
941 | PARAMS_CASE(7); |
942 | #undef PARAMS_CASE |
943 | default: |
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported. Requested rank: ",
            slice_dim);
948 | } |
949 | } |
950 | if (bad_i >= 0) { |
951 | auto slice_shape = indices.shape(); |
952 | slice_shape.RemoveLastDims(1); |
953 | return errors::InvalidArgument( |
954 | "indices" , SliceDebugString(slice_shape, bad_i), " = [" , |
955 | absl::StrJoin( |
956 | gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", " ), |
957 | "] does not index into shape " , shape.DebugString()); |
958 | } |
959 | return OkStatus(); |
960 | } |
961 | |
962 | template <typename T, typename Index, scatter_nd_op::UpdateOp Op> |
963 | Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, |
964 | const Tensor& updates, const TensorShape& shape, |
965 | Tensor* out, bool allocate); |
966 | |
967 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
968 | |
969 | // Copies inputs to the CPU, runs DoScatterNd on the CPU, then copies output |
970 | // back to GPU. This is useful because the CPU implementation is deterministic |
971 | // and the GPU implementation is not. Tensor inputs to this function must be on |
972 | // the GPU. |
973 | template <typename T, typename Index, scatter_nd_op::UpdateOp Op> |
974 | Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, |
975 | const Tensor& updates, const TensorShape& shape, |
976 | Tensor* out, bool allocate) { |
977 | AllocatorAttributes alloc_attr; |
978 | alloc_attr.set_on_host(true); |
979 | alloc_attr.set_gpu_compatible(true); |
980 | auto stream = c->op_device_context()->stream(); |
981 | |
982 | // Copy 'indices' to host. |
983 | Tensor host_indices; |
984 | TF_RETURN_IF_ERROR(c->allocate_temp(indices.dtype(), indices.shape(), |
985 | &host_indices, alloc_attr)); |
986 | se::DeviceMemoryBase indices_ptr( |
987 | const_cast<Tensor&>(indices).flat<Index>().data(), |
988 | indices.flat<Index>().size() * sizeof(Index)); |
989 | stream->ThenMemcpy(host_indices.flat<Index>().data(), indices_ptr, |
990 | indices.NumElements() * sizeof(Index)); |
  if (!stream->ok()) {
    return errors::Internal("Failed to copy indices to host");
  }
994 | |
995 | // Copy 'updates' to host. |
996 | Tensor host_updates; |
997 | TF_RETURN_IF_ERROR(c->allocate_temp(updates.dtype(), updates.shape(), |
998 | &host_updates, alloc_attr)); |
999 | se::DeviceMemoryBase updates_ptr( |
1000 | const_cast<Tensor&>(updates).flat<T>().data(), |
1001 | updates.flat<T>().size() * sizeof(T)); |
1002 | stream->ThenMemcpy(host_updates.flat<T>().data(), updates_ptr, |
1003 | updates.NumElements() * sizeof(T)); |
  if (!stream->ok()) {
    return errors::Internal("Failed to copy updates to host");
  }
1007 | |
1008 | // Create 'out' on host, copying from device if 'allocate' is false. |
1009 | Tensor host_out; |
1010 | TF_RETURN_IF_ERROR( |
1011 | c->allocate_temp(updates.dtype(), shape, &host_out, alloc_attr)); |
1012 | if (allocate) { |
1013 | TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out)); |
1014 | functor::SetZeroFunctor<CPUDevice, T> fill; |
1015 | fill(c->eigen_device<CPUDevice>(), host_out.flat<T>()); |
1016 | } else { |
1017 | CHECK_NOTNULL(out); // Crash OK |
1018 | se::DeviceMemoryBase out_ptr(out->flat<T>().data(), |
1019 | out->flat<T>().size() * sizeof(T)); |
1020 | stream->ThenMemcpy(host_out.flat<T>().data(), out_ptr, |
1021 | host_out.NumElements() * sizeof(T)); |
    if (!stream->ok()) {
      return errors::Internal("Failed to copy output to host");
    }
1025 | } |
1026 | |
1027 | TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); |
1028 | TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>( |
1029 | c, host_indices, host_updates, shape, &host_out, /*allocate=*/false)); |
1030 | |
1031 | // Copy 'host_out' to device. |
1032 | se::DeviceMemoryBase out_ptr(out->flat<T>().data(), |
1033 | out->flat<T>().size() * sizeof(T)); |
1034 | stream->ThenMemcpy(&out_ptr, host_out.flat<T>().data(), |
1035 | host_out.NumElements() * sizeof(T)); |
  if (!stream->ok()) {
    return errors::Internal("Failed to copy output to device");
  }
1039 | // Block host, since 'host_out' cannot be destructed until the copy is done. |
1040 | TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); |
1041 | return OkStatus(); |
1042 | } |
1043 | |
1044 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1045 | |
1046 | } // namespace |
1047 | |
1048 | template <typename Device, typename T, typename Index, |
1049 | scatter_nd_op::UpdateOp Op> |
1050 | Status DoScatterNd(OpKernelContext* c, const Tensor& indices, |
1051 | const Tensor& updates, const TensorShape& shape, Tensor* out, |
1052 | bool allocate) { |
1053 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1054 | if (std::is_same<Device, GPUDevice>::value && |
1055 | tensorflow::OpDeterminismRequired()) { |
1056 | return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out, |
1057 | allocate); |
1058 | } |
1059 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1060 | |
1061 | // Run on the CPU for integer types, since the GPU implementation uses |
1062 | // atomics, which are not supported for all integer types. |
1063 | if constexpr (std::is_same<Device, GPUDevice>::value && |
1064 | std::is_integral<T>::value) { |
1065 | return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out, |
1066 | allocate); |
1067 | } else { |
1068 | return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape, |
1069 | out, allocate); |
1070 | } |
1071 | } |
1072 | } // namespace functor |
1073 | |
1074 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1075 | // Forward declarations of the functor specializations for GPU. |
1076 | namespace functor { |
1077 | #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ |
1078 | template <> \ |
1079 | Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \ |
1080 | const GPUDevice& d, const Index slice_size, \ |
1081 | const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \ |
1082 | typename TTypes<T, 2>::Tensor Tparams, \ |
1083 | typename TTypes<Index, 2>::ConstTensor Tindices, \ |
1084 | typename TTypes<T, 2>::ConstTensor Tupdates, \ |
1085 | typename TTypes<T, 2>::Tensor Toutput); \ |
1086 | extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>; |
1087 | |
1088 | #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \ |
1089 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \ |
1090 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \ |
1091 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \ |
1092 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \ |
1093 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \ |
1094 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \ |
1095 | DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7); |
1096 | |
1097 | #define DECLARE_GPU_SPECS_INDEX(T, Index) \ |
1098 | DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \ |
1099 | DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \ |
1100 | DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB) |
1101 | |
1102 | #define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index) \ |
1103 | DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \ |
1104 | DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX) |
1105 | |
1106 | #define DECLARE_GPU_SPECS(T) \ |
1107 | DECLARE_GPU_SPECS_INDEX(T, int32); \ |
1108 | DECLARE_GPU_SPECS_INDEX(T, int64_t) |
1109 | |
1110 | #define DECLARE_GPU_SPECS_MIN_MAX(T) \ |
1111 | DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \ |
1112 | DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64_t) |
1113 | |
1114 | TF_CALL_int32(DECLARE_GPU_SPECS); |
1115 | TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX); |
1116 | TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); |
1117 | TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX); |
1118 | TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS); |
1119 | |
1120 | #undef DECLARE_GPU_SPECS_MIN_MAX |
1121 | #undef DECLARE_GPU_SPECS |
1122 | #undef DECLARE_GPU_SPECS_INDEX_MIN_MAX |
1123 | #undef DECLARE_GPU_SPECS_INDEX |
1124 | #undef DECLARE_GPU_SPECS_INDEX_OP |
1125 | |
1126 | } // namespace functor |
1127 | |
1128 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1129 | |
1130 | } // namespace tensorflow |
1131 | |