/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is only made if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that, as long as you only do sparse read operations, no write will
//   see a reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy, even if the variable's
//   reference count is >1.
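//
// As a rough, illustrative sketch (not part of this kernel implementation),
// the dense read and write paths described above amount to something like
// the following, assuming a Var-like object exposing mu() and tensor():
//
//   Tensor ReadSketch(Var* var) {
//     tf_shared_lock l(*var->mu());
//     // Shallow Tensor copy: shares and ref-counts the TensorBuffer.
//     return *var->tensor();
//   }
//
//   void WriteSketch(Var* var, const Tensor& new_value) {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // An outstanding read still holds the buffer; deep-copy before
//       // mutating so the reader keeps seeing the old value.
//       // ... allocate a fresh Tensor, copy the old contents, swap it in ...
//     }
//     // Now it is safe to mutate *var->tensor() in place.
//   }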
47
48#define EIGEN_USE_THREADS
49
50#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51#define EIGEN_USE_GPU
52#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
53#include "tensorflow/core/platform/stream_executor.h"
54#endif
55
56#include <memory>
57#include <type_traits>
58#include <vector>
59
60#include "absl/strings/str_join.h"
61#include "tensorflow/core/common_runtime/device.h"
62#include "tensorflow/core/framework/bounds_check.h"
63#include "tensorflow/core/framework/op_kernel.h"
64#include "tensorflow/core/framework/register_types.h"
65#include "tensorflow/core/framework/resource_mgr.h"
66#include "tensorflow/core/framework/tensor_shape.h"
67#include "tensorflow/core/framework/tensor_types.h"
68#include "tensorflow/core/framework/variant_op_registry.h"
69#include "tensorflow/core/kernels/dense_update_functor.h"
70#include "tensorflow/core/kernels/gather_functor.h"
71#include "tensorflow/core/kernels/gather_nd_op.h"
72#include "tensorflow/core/kernels/resource_variable_ops.h"
73#include "tensorflow/core/kernels/resource_variable_util.h"
74#include "tensorflow/core/kernels/scatter_functor.h"
75#include "tensorflow/core/kernels/training_op_helpers.h"
76#include "tensorflow/core/kernels/variable_ops.h"
77#include "tensorflow/core/lib/core/errors.h"
78#include "tensorflow/core/lib/core/refcount.h"
79#include "tensorflow/core/platform/casts.h"
80#include "tensorflow/core/platform/mem.h"
81#include "tensorflow/core/platform/mutex.h"
82#include "tensorflow/core/platform/types.h"
83#include "tensorflow/core/util/determinism.h"
84#include "tensorflow/core/util/util.h"
85
86namespace tensorflow {
87
88REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
89 ResourceHandlesOp<Var>);
90
91ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
92 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
93}
94
95namespace {
96
97Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
98 Tensor* output;
99 Notification n;
100 Status status;
101 AllocatorAttributes attr;
102 if (t->dtype() == DT_VARIANT) {
103 attr.set_on_host(true);
104 }
105 TF_RETURN_IF_ERROR(
106 ctx->allocate_output(output_idx, t->shape(), &output, attr));
107 if (t->dtype() == DT_VARIANT) {
108 output->flat<Variant>() = t->flat<Variant>();
109 } else if (ctx->op_device_context() != nullptr) {
110 // TODO(apassos): remove the down_cast by just returning Device* from
111 // OpKernelContext
112 Device* device = down_cast<Device*>(ctx->device());
113 ctx->op_device_context()->CopyTensorInSameDevice(
114 t, device, output, [&n, &status](const Status& s) {
115 status = s;
116 n.Notify();
117 });
118 n.WaitForNotification();
119 return status;
120 } else {
121 switch (t->dtype()) {
122#define HANDLER(type) \
123 case DataTypeToEnum<type>::value: \
124 output->flat<type>() = t->flat<type>(); \
125 break;
126 TF_CALL_ALL_TYPES(HANDLER);
127#undef HANDLER
128 default:
129 return errors::Internal("Unsupported dtype", t->dtype());
130 }
131 }
132 return OkStatus();
133}
134
135} // namespace
136
137void ReadVariableOp::Compute(OpKernelContext* ctx) {
138 core::RefCountPtr<Var> variable;
139 const ResourceHandle& handle = HandleFromInput(ctx, 0);
140 const auto status = LookupResource(ctx, handle, &variable);
141 OP_REQUIRES(ctx, status.ok(),
142 errors::FailedPrecondition(
143 "Could not find variable ", handle.name(), ". ",
144 "This could mean that the variable has been deleted. ",
145 "In TF1, it can also mean the variable is uninitialized. ",
146 "Debug info: container=", handle.container(),
147 ", status error message=", status.error_message()));
148
149 tf_shared_lock ml(*variable->mu());
150 // We're acquiring a reference to the underlying buffer while
151 // holding a shared lock to guarantee ordering of reads and
152 // writes when in copy-on-write mode.
153 const Tensor* t = variable->tensor();
154 if (!variable->copy_on_read_mode.load()) {
155 OP_REQUIRES(
156 ctx, dtype_ == t->dtype(),
157 errors::InvalidArgument(
158 "Trying to read variable with wrong dtype. Expected ",
159 DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
160 ctx->set_output(0, *t);
161 } else {
162 OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
163 }
164}
165
166ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
167 int n;
168 OP_REQUIRES_OK(c, c->GetAttr("N", &n));
169 OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
170 OP_REQUIRES(c, n == dtypes_.size(),
171 errors::InvalidArgument(
172 "Mismatched number of arguments to ReadVariablesOp (", n,
173 " vs. ", dtypes_.size(), ")"));
174}
175
176void ReadVariablesOp::Compute(OpKernelContext* ctx) {
177 std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
178 std::vector<const ResourceHandle*> handles(dtypes_.size());
179 for (size_t i = 0; i < dtypes_.size(); ++i) {
180 handles[i] = &HandleFromInput(ctx, i);
181 }
182
183 OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
184
185 std::vector<string> uninitialized_vars;
186 for (int64_t i = 0; i < variables.size(); i++) {
187 if (variables[i] == nullptr) {
188 uninitialized_vars.push_back(handles[i]->name());
189 }
190 }
191
192 OP_REQUIRES(ctx, uninitialized_vars.empty(),
193 errors::FailedPrecondition(
194 "In ReadVariablesOp the following variables were "
195 "found uninitialized: ",
196 absl::StrJoin(uninitialized_vars, ", ")));
197
198 for (size_t i = 0; i < dtypes_.size(); ++i) {
199 // We're acquiring a reference to the underlying buffer while
200 // holding a shared lock to guarantee ordering of reads and
201 // writes.
202 tf_shared_lock ml(*variables[i]->mu());
203 OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
204 errors::InvalidArgument(
205 "Trying to read variable ", handles[i]->name(),
206 " from Container: ", handles[i]->container(),
207 " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
208 " got ", DataTypeString(variables[i]->tensor()->dtype())));
209 if (variables[i]->copy_on_read_mode.load()) {
210 OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
211 } else {
212 const Tensor& t = *variables[i]->tensor();
213 ctx->set_output(i, t);
214 }
215 }
216}
217
218REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
219 ReadVariableOp);
220REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
221 ReadVariablesOp);
222
223REGISTER_KERNEL_BUILDER(
224 Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
225 ReadVariableOp);
226REGISTER_KERNEL_BUILDER(
227 Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
228 ReadVariablesOp);
229
230VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
231 OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
232 OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
233
234 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
235 OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
236
237 is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
238
239 // Use const_tensor_ if the variable is non-anonymous.
240 if (!is_anonymous_) {
241 AllocatorAttributes attr;
242 attr.set_on_host(true);
243 OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
244 &const_tensor_, attr));
245 const_tensor_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
246 context, container_, name_,
247 std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
248 }
249}
250
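// For an anonymous handle, Compute creates a fresh Var, wraps it in a
// reference-counted handle (also published, unowned, to the ResourceMgr), and
// returns that handle. For a named handle it simply returns the const_tensor_
// built in the constructor.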
251void VarHandleOp::Compute(OpKernelContext* ctx) {
252 if (is_anonymous_) {
253 Var* resource = new Var(dtype_and_shape_.dtype);
254 ResourceMgr* mgr = ctx->resource_manager();
255 ResourceHandle handle = ResourceHandle::MakeRefCountingHandle<Var>(
256 resource, ctx->device()->name(),
257 std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
258 ctx->stack_trace());
259 // TODO(b/203901837): See if we can abolish all code paths that lookup
260 // anonymous variables and then stop publishing them to the manager.
261 OP_REQUIRES_OK(ctx, mgr->CreateUnowned<Var>(handle.container(),
262 handle.name(), resource));
263
264 AllocatorAttributes attr;
265 attr.set_on_host(true);
266 Tensor tensor;
267 OP_REQUIRES_OK(
268 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
269
270 tensor.scalar<ResourceHandle>()() = std::move(handle);
271
272 ctx->set_output(0, tensor);
273 } else {
274 ctx->set_output(0, const_tensor_);
275 }
276}
277
278REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);
279
280#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
281#define REGISTER_GPU_KERNELS(type) \
282 namespace functor { \
283 template <> \
284 void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
285 const GPUDevice& d, typename TTypes<type>::Flat lhs, \
286 typename TTypes<type>::ConstFlat rhs); \
287 extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
288 }
289
290TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
291TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
292TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS);
293TF_CALL_variant(REGISTER_GPU_KERNELS);
294#undef REGISTER_GPU_KERNELS
295
296#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
297
298#define REGISTER_DEFAULT_KERNELS(type) \
299 REGISTER_KERNEL_BUILDER(Name("VarHandleOp") \
300 .Device(DEVICE_DEFAULT) \
301 .HostMemory("resource") \
302 .TypeConstraint<type>("dtype"), \
303 VarHandleOp)
304TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
305TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_DEFAULT_KERNELS);
306TF_CALL_bfloat16(REGISTER_DEFAULT_KERNELS);
307TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
308#undef REGISTER_DEFAULT_KERNELS
309
310REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
311 .Device(DEVICE_DEFAULT)
312 .HostMemory("resources")
313 .TypeConstraint("dtypes",
314 {DT_INT64, DT_COMPLEX64,
315 DT_COMPLEX128, DT_HALF, DT_FLOAT,
316 DT_DOUBLE, DT_BOOL, DT_VARIANT}),
317 ResourceHandlesOp<Var>);
318
319REGISTER_KERNEL_BUILDER(
320 Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
321 VariableShapeOp<int32>);
322REGISTER_KERNEL_BUILDER(Name("VariableShape")
323 .Device(DEVICE_CPU)
324 .TypeConstraint<int64_t>("out_type"),
325 VariableShapeOp<int64_t>);
326
327REGISTER_KERNEL_BUILDER(Name("VariableShape")
328 .Device(DEVICE_DEFAULT)
329 .TypeConstraint<int32>("out_type")
330 .HostMemory("output")
331 .HostMemory("input"),
332 VariableShapeOp<int32>);
333REGISTER_KERNEL_BUILDER(Name("VariableShape")
334 .Device(DEVICE_DEFAULT)
335 .TypeConstraint<int64_t>("out_type")
336 .HostMemory("output")
337 .HostMemory("input"),
338 VariableShapeOp<int64_t>);
339
340DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
341 : OpKernel(ctx) {
342 OP_REQUIRES_OK(ctx,
343 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
344}
345
346void DestroyResourceOp::Compute(OpKernelContext* ctx) {
347 const ResourceHandle& p = HandleFromInput(ctx, 0);
348 Status status = DeleteResource(ctx, p);
349 if (ignore_lookup_error_ && errors::IsNotFound(status)) {
350 return;
351 }
352 OP_REQUIRES_OK(ctx, status);
353}
354
355REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
356 DestroyResourceOp);
357REGISTER_KERNEL_BUILDER(
358 Name("DestroyResourceOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
359 DestroyResourceOp);
360
361void DisableCopyOnReadOp::Compute(OpKernelContext* ctx) {
362 core::RefCountPtr<Var> variable;
363 const ResourceHandle& handle = HandleFromInput(ctx, 0);
364 const auto status = LookupResource(ctx, handle, &variable);
365 OP_REQUIRES(ctx, status.ok(),
366 errors::FailedPrecondition(
367 "Could not find variable ", handle.name(), ". ",
368 "This could mean that the variable has been deleted. ",
369 "In TF1, it can also mean the variable is uninitialized. ",
370 "Debug info: container=", handle.container(),
371 ", status error message=", status.error_message()));
  // If the variable is currently in copy-on-read mode (in which case its
  // buffer's refcount is 1), obtain an exclusive lock and switch it back to
  // copy-on-write.
  if (variable->copy_on_read_mode.load()) {
    mutex_lock ml(*variable->mu());
    variable->copy_on_read_mode.store(false);
  }
378}
379
380REGISTER_KERNEL_BUILDER(Name("DisableCopyOnRead").Device(DEVICE_CPU),
381 DisableCopyOnReadOp);
382REGISTER_KERNEL_BUILDER(
383 Name("DisableCopyOnRead").Device(DEVICE_DEFAULT).HostMemory("resource"),
384 DisableCopyOnReadOp);
385
386template <typename Device, typename T>
387class AssignVariableOp : public OpKernel {
388 public:
389 explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
390 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
391 if (!c->GetAttr("_grappler_relax_allocator_constraints",
392 &relax_constraints_)
393 .ok()) {
394 relax_constraints_ = false;
395 }
396 if (c->HasAttr("validate_shape")) {
397 OP_REQUIRES_OK(c, c->GetAttr("validate_shape", &validate_shape_));
398 }
399 }
400
401 void Compute(OpKernelContext* context) override {
402 OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
403 errors::InvalidArgument(
404 "Variable and value dtypes don't match; respectively, ",
405 DataTypeString(dtype_), " and ",
406 DataTypeString(context->input(1).dtype())));
407 core::RefCountPtr<Var> variable;
408 const Tensor& value = context->input(1);
    // Note: every resource-variable-manipulating op assumes copy-on-write
    // semantics, and creates a copy of the variable's Tensor if its refcount
    // is bigger than 1 when we try to modify it. This means we never need to
    // copy the original tensor for AssignVariableOp: even if there are other
    // live users of it, we know none of them can modify it, so aliasing the
    // input is always safe. (Even in esoteric cases where the same tensor is
    // used to initialize multiple variables, or the tensor is a constant,
    // this holds, since any future write will trigger a copy.)
417 OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
418 context, HandleFromInput(context, 0), &variable,
419 [this, &value](Var** ptr) {
420 *ptr = new Var(dtype_);
421 *(*ptr)->tensor() = value;
422 (*ptr)->is_initialized = true;
423 return OkStatus();
424 }));
425 mutex_lock ml(*variable->mu());
    // The (variable->tensor()->dtype() == DT_INVALID &&
    // !variable->is_initialized) check below allows an XLA-specific situation
    // in which AssignVariableOp can run before the variable has been
    // initialized, so its tensor still has an invalid dtype. When using
    // TF-XLA, this scenario can occur when execution takes the 'fallback'
    // path (which essentially invokes TensorFlow ops via partitioned_call).
433 OP_REQUIRES(context,
434 (variable->tensor()->dtype() == DT_INVALID &&
435 !variable->is_initialized) ||
436 variable->tensor()->dtype() == dtype_,
437 errors::InvalidArgument(
438 "Trying to assign variable with wrong dtype. Expected ",
439 DataTypeString(variable->tensor()->dtype()), " got ",
440 DataTypeString(dtype_)));
441 if (validate_shape_) {
442 OP_REQUIRES(
443 context,
444 (!variable->is_initialized ||
445 variable->tensor()->shape().IsSameSize(value.shape())),
446 errors::InvalidArgument(
447 "Trying to assign to variable with tensor with wrong shape."
448 " Expected ",
449 variable->tensor()->shape().DebugString(), " got ",
450 value.shape().DebugString()));
451 }
452 if (variable->copy_on_read_mode.load()) {
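      // In copy-on-read mode the variable must keep exclusive ownership of
      // its buffer (sparse ops update it in place), so allocate a fresh
      // buffer and deep-copy the value instead of aliasing the input tensor.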
453 AllocatorAttributes attr;
454 attr.set_gpu_compatible(true);
455 attr.set_nic_compatible(true);
456 OP_REQUIRES_OK(context,
457 context->allocate_temp(value.dtype(), value.shape(),
458 variable->tensor(), attr));
459 functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
460 copy_functor(context->eigen_device<Device>(),
461 variable->tensor()->flat<T>(), value.flat<T>());
462 } else {
463 *variable->tensor() = value;
464 }
465 variable->is_initialized = true;
466 }
467
468 private:
469 DataType dtype_;
470 bool relax_constraints_;
471 bool validate_shape_ = false;
472};
473
474template <typename Device>
475class AssignVariableOp<Device, Variant> : public OpKernel {
476 public:
477 explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
478 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
479 OP_REQUIRES(c, dtype_ == DT_VARIANT,
480 errors::Internal("Variant kernel called with dtype: ",
481 DataTypeString(dtype_)));
482 }
483
484 void Compute(OpKernelContext* context) override {
485 const Tensor& value = context->input(1);
486 core::RefCountPtr<Var> variable;
487 OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
488 context, HandleFromInput(context, 0), &variable,
489 [](Var** ptr) {
490 // Created on host.
491 *ptr = new Var(DT_VARIANT);
492 return OkStatus();
493 }));
494
495 // For purposes of forwarding DT_VARIANT, we want the least
496 // restrictive attr; we already know the input is on host.
497 AllocatorAttributes attr;
498
    // Copying is unnecessary if we are the last user of the value
    // tensor; in that case we can just adopt the input tensor's buffer
    // instead. Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in the memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU. This
    // helps to elide one or more copies.
507 std::unique_ptr<Tensor> input_alias = context->forward_input(
508 1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
509 value.shape(),
510 DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
511 attr);
512
513 mutex_lock ml(*variable->mu());
514 OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
515 errors::InvalidArgument(
516 "Trying to assign variable with wrong dtype. Expected ",
517 DataTypeString(variable->tensor()->dtype()), " got ",
518 DataTypeString(DT_VARIANT)));
519 variable->is_initialized = true;
520 *variable->tensor() = Tensor(DT_VARIANT, value.shape());
521
522 if (input_alias) {
523 *variable->tensor() = *input_alias;
524 return;
525 }
526
527 // Need to copy, but maybe we can re-use variable's buffer?
528 if (!variable->tensor()->RefCountIsOne() ||
529 !variable->tensor()->shape().IsSameSize(value.shape())) {
530 // Allocation of DT_VARIANT is always on host.
531 attr.set_on_host(true);
532 OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
533 variable->tensor(), attr));
534 }
535
536 const auto elements_in = value.flat<Variant>();
537 auto elements_out = variable->tensor()->flat<Variant>();
538 for (int64_t i = 0; i < elements_in.size(); ++i) {
539 elements_out(i) = elements_in(i);
540 }
541 }
542
543 private:
544 DataType dtype_;
545};
546
547#define REGISTER_KERNELS(type) \
548 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
549 .Device(DEVICE_CPU) \
550 .TypeConstraint<type>("dtype"), \
551 AssignVariableOp<Eigen::ThreadPoolDevice, type>);
552
553TF_CALL_ALL_TYPES(REGISTER_KERNELS);
554TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
555#undef REGISTER_KERNELS
556
557#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
558#define REGISTER_GPU_KERNELS(type) \
559 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
560 .Device(DEVICE_GPU) \
561 .TypeConstraint<type>("dtype") \
562 .HostMemory("resource"), \
563 AssignVariableOp<GPUDevice, type>);
564
565TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
566TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS);
567TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
568#undef REGISTER_GPU_KERNELS
569#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
570
571#define REGISTER_KERNELS(type) \
572 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
573 .Device(DEVICE_DEFAULT) \
574 .TypeConstraint<type>("dtype") \
575 .HostMemory("resource"), \
576 AssignVariableOp<CPUDevice, type>);
577
578TF_CALL_ALL_TYPES(REGISTER_KERNELS);
579TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
580#undef REGISTER_KERNELS
581
582template <typename Device, typename T, DenseUpdateType Op>
583class AssignUpdateVariableOp : public OpKernel {
584 public:
585 explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
586
587 void Compute(OpKernelContext* context) override {
588 core::RefCountPtr<Var> variable;
589 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
590 &variable));
591
592 const Tensor& value = context->input(1);
593 // TODO(apassos): We could possibly avoid the copy done by
594 // PrepareToUpdateVariable() for commutative operations like Op ==
595 // ADD if value's refcount was 1.
596 mutex_lock ml(*variable->mu());
597 Tensor* var_tensor = variable->tensor();
598 OP_REQUIRES_OK(context, ValidateAssignUpdateVariableOpShapes(
599 var_tensor->shape(), value.shape()));
600 OP_REQUIRES_OK(
601 context, PrepareToUpdateVariable<Device, T>(
602 context, var_tensor, variable->copy_on_read_mode.load()));
603 functor::DenseUpdate<Device, T, Op> update_functor;
604 update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
605 value.flat<T>());
606 }
607};
608
609#define REGISTER_KERNELS(type) \
610 REGISTER_KERNEL_BUILDER( \
611 Name("AssignAddVariableOp") \
612 .Device(DEVICE_CPU) \
613 .TypeConstraint<type>("dtype"), \
614 AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
615 REGISTER_KERNEL_BUILDER( \
616 Name("AssignSubVariableOp") \
617 .Device(DEVICE_CPU) \
618 .TypeConstraint<type>("dtype"), \
619 AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);
620
621TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
622#undef REGISTER_KERNELS
623
624#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
625#define REGISTER_GPU_KERNELS(type) \
626 REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
627 .Device(DEVICE_GPU) \
628 .HostMemory("resource") \
629 .TypeConstraint<type>("dtype"), \
630 AssignUpdateVariableOp<GPUDevice, type, ADD>); \
631 REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp") \
632 .Device(DEVICE_GPU) \
633 .HostMemory("resource") \
634 .TypeConstraint<type>("dtype"), \
635 AssignUpdateVariableOp<GPUDevice, type, SUB>);
636
637TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
638TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS);
639#undef REGISTER_GPU_KERNELS
640#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
641
642class VarIsInitializedOp : public OpKernel {
643 public:
644 explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
645
646 void Compute(OpKernelContext* context) override {
647 Tensor* output = nullptr;
648 OP_REQUIRES_OK(context,
649 context->allocate_output(0, TensorShape({}), &output));
650 auto output_tensor = output->tensor<bool, 0>();
651 core::RefCountPtr<Var> variable;
652 Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
653 if (!s.ok()) {
654 output_tensor() = false;
655 return;
656 }
657 mutex_lock ml(*variable->mu());
658 output_tensor() = variable->is_initialized;
659 }
660};
661
662REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
663 VarIsInitializedOp);
664
665REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
666 .Device(DEVICE_DEFAULT)
667 .HostMemory("resource")
668 .HostMemory("is_initialized"),
669 VarIsInitializedOp);
670
671template <typename Device, typename T, typename Index>
672class ResourceGatherOp : public OpKernel {
673 public:
674 explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
675 OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
676 }
677
678 void Compute(OpKernelContext* c) override {
679 core::RefCountPtr<Var> v;
680 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
681 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
682 // NOTE: We hold the lock for the whole gather operation instead
683 // of increasing the reference count of v->tensor() to avoid a
684 // situation where a write to the same variable will see a
685 // reference count greater than one and make a copy of the
686 // (potentially very large) tensor buffer.
687 tf_shared_lock ml(*v->mu());
688 const Tensor& params = *v->tensor();
689 const Tensor& indices = c->input(1);
690 OP_REQUIRES(
691 c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
692 errors::InvalidArgument("params must be at least 1 dimensional"));
693 OP_REQUIRES(
694 c, params.shape().dims() >= batch_dims_,
695 errors::InvalidArgument("params must have at least ", batch_dims_,
696 " (batch_dims) dimensions but it has shape ",
697 params.shape().DebugString()));
698
699 // Check that we have enough index space
700 const int64_t N = indices.NumElements();
701 OP_REQUIRES(
702 c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
703 errors::InvalidArgument("params.shape[0] too large for ",
704 DataTypeString(DataTypeToEnum<Index>::v()),
705 " indexing: ", params.dim_size(0), " > ",
706 std::numeric_limits<Index>::max()));
707
708 // The result shape is params.shape[:batch_dims] +
709 // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
710 TensorShape result_shape;
711 for (int i = 0; i < batch_dims_; ++i) {
712 result_shape.AddDim(params.dim_size(i));
713 }
714 for (int i = batch_dims_; i < indices.dims(); ++i) {
715 result_shape.AddDim(indices.dim_size(i));
716 }
717 for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
718 result_shape.AddDim(params.dim_size(i));
719 }
720
721 Tensor* out = nullptr;
722 Tensor tmp;
723 if (params.dtype() == DT_VARIANT) {
724 tmp = Tensor(DT_VARIANT, result_shape);
725 c->set_output(0, tmp);
726 out = &tmp;
727 } else {
728 OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
729 }
730
731 if (N > 0) {
732 Tensor tmp_indices;
733
734 // Points to the original or updated (if batch_dims is set) indices.
735 const Tensor* op_indices = &indices;
736 if (batch_dims_ > 0) {
737 OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
738 &tmp_indices));
739 functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
740 copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
741 indices.flat<Index>());
742
743 AddBatchOffsets(c, &tmp_indices, params);
744 if (!c->status().ok()) return;
745 op_indices = &tmp_indices;
746 }
747
748 int64_t gather_dim_size = 1;
749 for (int idx = 0; idx <= batch_dims_; ++idx) {
750 gather_dim_size *= params.dim_size(idx);
751 }
752 int64_t inner_size = 1;
753 for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
754 inner_size *= params.dim_size(i);
755 }
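      // Flatten params to a [1, gather_dim, inner] view, folding any batch
      // dimensions into the gather dimension (which is why batch offsets were
      // added to the indices above), so GatherFunctor can perform the lookup
      // as a plain gather over the second dimension.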
756 auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
757 const auto indices_flat = op_indices->flat<Index>();
758 auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
759
760 functor::GatherFunctor<Device, T, Index> functor;
761 int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);
762
763 OP_REQUIRES(
764 c, bad_i < 0,
765 errors::InvalidArgument(
766 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
767 indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
768 }
769 }
770
771 private:
772 // Add the batch offset derived from params to each batch of indices.
773 // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
774 // If indexing into a params dimension of size 4, then the indices will become
775 // [0, 1, 2, 4, 5, 6]
776 void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
777 const Tensor& params) {
778 int64_t batch_size = 1; // The size of all batch dimensions.
779 for (int idx = 0; idx < batch_dims_; ++idx) {
780 batch_size *= params.dim_size(idx);
781 }
782 OP_REQUIRES(
783 ctx, batch_size != 0,
784 errors::InvalidArgument(
785 "Inner size of indices would result in batch_size of 0 and a ",
786 "division by 0 in the implementation. This is illegal"));
787
788 auto indices_flat = indices->flat<Index>();
789 int64_t const index_inner_size = indices->NumElements() / batch_size;
790 int64_t const batch_offset = params.dim_size(batch_dims_);
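    // For each batch b, shift that batch's indices by b * batch_offset so
    // they address batch b's slice of the flattened params.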
791 for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
792 ++batch_idx) {
793 for (int64_t idx = 0; idx < index_inner_size; ++idx) {
794 indices_flat(dest_idx++) += batch_offset * batch_idx;
795 }
796 }
797 }
798
799 int32 batch_dims_ = 0;
800};
801
802#define REGISTER_GATHER_FULL(dev, type, index_type) \
803 REGISTER_KERNEL_BUILDER(Name("ResourceGather") \
804 .Device(DEVICE_##dev) \
805 .HostMemory("resource") \
806 .TypeConstraint<type>("dtype") \
807 .TypeConstraint<index_type>("Tindices"), \
808 ResourceGatherOp<dev##Device, type, index_type>)
809
810#define REGISTER_GATHER_ALL_INDICES(dev, type) \
811 REGISTER_GATHER_FULL(dev, type, int32); \
812 REGISTER_GATHER_FULL(dev, type, int64_t)
813
814#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
815
816// Registration of the CPU implementations.
817TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
818TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
819
820// Registers GPU kernels.
821#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
822#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
823
824TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GATHER_GPU);
825TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);
826
827// Variant objects themselves sit on CPU, even if they contain data
828// pointing to a device.
829REGISTER_KERNEL_BUILDER(Name("ResourceGather")
830 .Device(DEVICE_DEFAULT)
831 .HostMemory("resource")
832 .HostMemory("indices")
833 .TypeConstraint<Variant>("dtype")
834 .TypeConstraint<int32>("Tindices"),
835 ResourceGatherOp<CPUDevice, Variant, int32>)
836REGISTER_KERNEL_BUILDER(Name("ResourceGather")
837 .Device(DEVICE_DEFAULT)
838 .HostMemory("resource")
839 .HostMemory("indices")
840 .TypeConstraint<Variant>("dtype")
841 .TypeConstraint<int64_t>("Tindices"),
842 ResourceGatherOp<CPUDevice, Variant, int64>)
843
844#undef REGISTER_GATHER_GPU
845#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
846
847#undef REGISTER_GATHER_CPU
848#undef REGISTER_GATHER_ALL_INDICES
849#undef REGISTER_GATHER_FULL
850
851template <typename Device, typename T, typename Index>
852class ResourceGatherNdOp : public OpKernel {
853 public:
854 explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
855
856 void Compute(OpKernelContext* c) override {
857 core::RefCountPtr<Var> v;
858 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
859 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
860 // NOTE: We hold the lock for the whole gather operation instead
861 // of increasing the reference count of v->tensor() to avoid a
862 // situation where a write to the same variable will see a
863 // reference count greater than one and make a copy of the
864 // (potentially very large) tensor buffer.
865 tf_shared_lock ml(*v->mu());
866 const Tensor& params = *v->tensor();
867 const Tensor& indices = c->input(1);
868
869 Tensor out;
870 OP_REQUIRES_OK(
871 c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
872 c->set_output(0, out);
873 }
874};
875
876#define REGISTER_GATHER_ND_FULL(dev, type, index_type) \
877 REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd") \
878 .Device(DEVICE_##dev) \
879 .HostMemory("resource") \
880 .TypeConstraint<type>("dtype") \
881 .TypeConstraint<index_type>("Tindices"), \
882 ResourceGatherNdOp<dev##Device, type, index_type>)
883
884#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
885 REGISTER_GATHER_ND_FULL(dev, type, int32); \
886 REGISTER_GATHER_ND_FULL(dev, type, int64_t)
887
888#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
889
890// Registration of the CPU implementations.
891TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
892
893// Registers GPU kernels.
894#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
895#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
896
897TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
898TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GATHER_ND_GPU);
899
900#undef REGISTER_GATHER_ND_GPU
901#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
902
903#undef REGISTER_GATHER_ND_CPU
904#undef REGISTER_GATHER_ND_ALL_INDICES
905#undef REGISTER_GATHER_ND_FULL
906
907namespace {
908
909template <typename Device>
910bool isCPUDevice() {
911 return false;
912}
913
914template <>
915bool isCPUDevice<CPUDevice>() {
916 return true;
917}
918
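// Returns true iff `updates` contains no zero-valued elements; used below to
// reject zero divisors for scatter DIV on the CPU. The Variant specialization
// always returns true: it exists so the template compiles for variant scatter
// updates (DIV is never registered for variants).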
919template <typename T>
920bool ValidateInput(const Tensor& updates) {
921 const auto updates_flat = updates.flat<T>();
922 for (int i = 0; i < updates.NumElements(); ++i) {
923 if (updates_flat(i) == T{}) return false;
924 }
925 return true;
926}
927
928template <>
929bool ValidateInput<Variant>(const Tensor& updates) {
930 return true;
931}
932
933template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
934Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
935 const Tensor& updates, Index num_indices);
936
937template <typename T, typename Index, scatter_op::UpdateOp Op>
938Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices,
939 const Tensor& updates, Index num_indices);
940
941#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
942
943template <typename T>
944Status CopyTensorToHost(OpKernelContext* c, const Tensor& device_tensor,
945 Tensor* host_tensor) {
946 AllocatorAttributes alloc_attr;
947 alloc_attr.set_on_host(true);
948 alloc_attr.set_gpu_compatible(true);
  auto stream = c->op_device_context()->stream();
  if (!stream) {
    return errors::Internal("Failed to copy indices to host");
  }
  TF_RETURN_IF_ERROR(c->allocate_temp(
      device_tensor.dtype(), device_tensor.shape(), host_tensor, alloc_attr));
  se::DeviceMemoryBase device_ptr(
      const_cast<Tensor&>(device_tensor).flat<T>().data(),
      device_tensor.flat<T>().size() * sizeof(T));
  stream->ThenMemcpy(host_tensor->flat<T>().data(), device_ptr,
                     device_tensor.NumElements() * sizeof(T));
960 return OkStatus();
961}
962
963// Copies inputs to the CPU, runs DoScatter on the CPU, then copies output
964// back to GPU. This is useful because the CPU implementation is deterministic
965// and the GPU implementation is not. Tensor inputs to this function must be on
966// the GPU.
967template <typename T, typename Index, scatter_op::UpdateOp Op>
968Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices,
969 const Tensor& updates, Index num_indices) {
970 if (!DataTypeCanUseMemcpy(params->dtype())) {
971 return errors::Unimplemented(
972 "GPU Scatter ops for dtype ", DataTypeString(params->dtype()),
973 " do not yet have a deterministic implementation");
974 }
  auto stream = c->op_device_context()->stream();
  if (!stream) {
    return errors::Internal("Failed to copy params to device");
  }
976
977 Tensor host_indices;
978 TF_RETURN_IF_ERROR(CopyTensorToHost<Index>(c, indices, &host_indices));
979 Tensor host_updates;
980 TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, updates, &host_updates));
981 Tensor host_params;
982 TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, *params, &host_params));
983
984 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
985 TF_RETURN_IF_ERROR(DoScatter<CPUDevice, T, Index, Op>(
986 c, &host_params, host_indices, host_updates, num_indices));
987
988 // Copy 'host_params' to device.
989 se::DeviceMemoryBase params_ptr(params->flat<T>().data(),
990 params->flat<T>().size() * sizeof(T));
  stream->ThenMemcpy(&params_ptr, host_params.flat<T>().data(),
                     host_params.NumElements() * sizeof(T));
996 // Deallocate host_params' buffer once the host-to-device copy is complete.
997 // host_params is captured by value in the lambda so that its buffer is only
998 // destructed once the lambda is destructed.
999 c->device()->tensorflow_accelerator_device_info()->event_mgr->ThenExecute(
1000 stream, [host_params] {});
1001 return OkStatus();
1002}
1003
1004#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1005
1006template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
1007Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
1008 const Tensor& updates, Index num_indices) {
1009#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1010 if (std::is_same<Device, GPUDevice>::value &&
1011 tensorflow::OpDeterminismRequired()) {
1012 return DoScatterOnCpu<T, Index, op>(c, params, indices, updates,
1013 num_indices);
1014 }
1015#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1016
1017 // Run on the CPU for integer types, since the GPU implementation uses
1018 // atomics, which are not supported for all integer types.
1019 if constexpr (std::is_same<Device, GPUDevice>::value &&
1020 std::is_integral<T>::value) {
1021 return DoScatterOnCpu<T, Index, op>(c, params, indices, updates,
1022 num_indices);
1023 } else {
1024 auto indices_flat = indices.flat<Index>();
1025 auto params_flat = params->flat_outer_dims<T>();
1026 if (TensorShapeUtils::IsScalar(updates.shape())) {
1027 const auto update = updates.scalar<T>();
1028
1029 functor::ScatterScalarFunctor<Device, T, Index, op> functor;
1030 const Index bad_i = functor(c, c->template eigen_device<Device>(),
1031 params_flat, update, indices_flat);
1032 if (bad_i >= 0) {
1033 return errors::InvalidArgument(
1034 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
1035 indices_flat(bad_i), " is not in [0, ", params->dim_size(0), ")");
1036 }
1037 } else {
1038 int64_t num_updates = updates.NumElements();
1039 if (!TensorShapeUtils::StartsWith(updates.shape(), indices.shape())) {
1040 return errors::InvalidArgument(
1041 "The shape of indices (", indices.shape().DebugString(),
1042 ") must be a prefix of the shape of updates (",
1043 updates.shape().DebugString(), ")");
1044 }
1045 auto updates_flat =
1046 updates.shaped<T, 2>({num_indices, num_updates / num_indices});
1047 functor::ScatterFunctor<Device, T, Index, op> functor;
1048 const Index bad_i = functor(c, c->template eigen_device<Device>(),
1049 params_flat, updates_flat, indices_flat);
1050 if (bad_i >= 0) {
1051 return errors::InvalidArgument(
1052 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
1053 indices_flat(bad_i), " is not in [0, ", params->dim_size(0), ")");
1054 }
1055 }
1056 }
1057 return OkStatus();
1058}
1059
1060} // namespace
1061
1062template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
1063class ResourceScatterUpdateOp : public OpKernel {
1064 public:
1065 explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
1066 // We use the same kernel for many operations.
1067 // Each operation has a different set of attributes defined in its nodes.
1068 Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
1069 if (!s.ok()) {
1070 use_exclusive_lock_ = false;
1071 }
1072 }
1073
1074 void Compute(OpKernelContext* c) override {
1075 core::RefCountPtr<Var> v;
1076 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1077 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
1078 const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
1079 c->input_dtype(0) == DT_STRING ||
1080 c->input_dtype(0) == DT_VARIANT;
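    // Non-POD element types (resources, strings, variants) cannot be updated
    // concurrently, so they always take the exclusive lock, as does an
    // explicit use_locking=True request.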
1081 if (is_non_pod_dtype || use_exclusive_lock_) {
1082 mutex_lock ml(*v->mu());
1083 DoCompute(c);
1084 } else {
      // For POD dtypes, we can safely run the update while holding only a
      // shared (non-exclusive) lock on the variable.
1086 tf_shared_lock ml(*v->mu());
1087 DoCompute(c);
1088 }
1089 }
1090
1091 private:
1092 bool use_exclusive_lock_;
1093
1094 void DoCompute(OpKernelContext* c) {
1095 core::RefCountPtr<Var> v;
1096 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1097 Tensor* params = v->tensor();
1098 const Tensor& indices = c->input(1);
1099 const Tensor& updates = c->input(2);
1100
1101 // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
1102 OP_REQUIRES(c,
1103 updates.dims() == 0 ||
1104 updates.dims() == indices.dims() + params->dims() - 1,
1105 errors::InvalidArgument(
1106 "Must have updates.shape = indices.shape + "
1107 "params.shape[1:] or updates.shape = [], got ",
1108 "updates.shape ", updates.shape().DebugString(),
1109 ", indices.shape ", indices.shape().DebugString(),
1110 ", params.shape ", params->shape().DebugString()));
1111
1112 // Check that we have enough index space
1113 const int64_t N_big = indices.NumElements();
1114 OP_REQUIRES(
1115 c, N_big <= std::numeric_limits<Index>::max(),
1116 errors::InvalidArgument("indices has too many elements for ",
1117 DataTypeString(DataTypeToEnum<Index>::v()),
1118 " indexing: ", N_big, " > ",
1119 std::numeric_limits<Index>::max()));
1120 const Index N = static_cast<Index>(N_big);
1121 OP_REQUIRES(
1122 c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
1123 errors::InvalidArgument("params.shape[0] too large for ",
1124 DataTypeString(DataTypeToEnum<Index>::v()),
1125 " indexing: ", params->dim_size(0), " > ",
1126 std::numeric_limits<Index>::max()));
1127
1128 // Prevent division by 0
1129 if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
1130 OP_REQUIRES(c, ValidateInput<T>(updates),
1131 errors::InvalidArgument("updates must not contain 0"));
1132 }
1133
1134 if (N > 0) {
1135 OP_REQUIRES_OK(
1136 c, DoScatter<Device, T, Index, op>(c, params, indices, updates, N));
1137 }
1138 }
1139};
1140
1141#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
1142 REGISTER_KERNEL_BUILDER( \
1143 Name(name) \
1144 .Device(DEVICE_##dev) \
1145 .HostMemory("resource") \
1146 .TypeConstraint<type>("dtype") \
1147 .TypeConstraint<index_type>("Tindices"), \
1148 ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
1149
1150#define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
1151 REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
1152 REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op);
1153
1154#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
1155 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
1156 scatter_op::UpdateOp::ADD); \
1157 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
1158 scatter_op::UpdateOp::SUB); \
1159 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
1160 scatter_op::UpdateOp::MUL); \
1161 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
1162 scatter_op::UpdateOp::DIV);
1163
1164#define REGISTER_SCATTER_UPDATE(type, dev) \
1165 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
1166 scatter_op::UpdateOp::ASSIGN);
1167
1168#define REGISTER_SCATTER_MINMAX(type, dev) \
1169 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
1170 scatter_op::UpdateOp::MIN); \
1171 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
1172 scatter_op::UpdateOp::MAX);
1173
1174// Registers CPU kernels.
1175#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
1176 REGISTER_SCATTER_ARITHMETIC(type, CPU);
1177#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
1178#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
1179
1180TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
1181TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
1182TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_CPU);
1183TF_CALL_tstring(REGISTER_SCATTER_UPDATE_CPU);
1184TF_CALL_bool(REGISTER_SCATTER_UPDATE_CPU);
1185TF_CALL_variant(REGISTER_SCATTER_UPDATE_CPU);
1186
1187// Registers GPU kernels.
1188#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1189#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
1190 REGISTER_SCATTER_ARITHMETIC(type, GPU);
1191#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
1192
1193#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
1194
1195TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
1196TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ARITHMETIC_GPU);
1197TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
1198TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_MINMAX_GPU);
1199TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
1200TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_UPDATE_GPU);
1201
1202REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1203 .Device(DEVICE_DEFAULT)
1204 .HostMemory("resource")
1205 .HostMemory("indices")
1206 .TypeConstraint<Variant>("dtype")
1207 .TypeConstraint<int32>("Tindices"),
1208 ResourceScatterUpdateOp<CPUDevice, Variant, int32,
1209 scatter_op::UpdateOp::ASSIGN>)
1210REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1211 .Device(DEVICE_GPU)
1212 .HostMemory("resource")
1213 .TypeConstraint<bool>("dtype")
1214 .TypeConstraint<int32>("Tindices"),
1215 ResourceScatterUpdateOp<GPUDevice, bool, int32,
1216 scatter_op::UpdateOp::ASSIGN>)
1217REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1218 .Device(DEVICE_DEFAULT)
1219 .HostMemory("resource")
1220 .HostMemory("indices")
1221 .TypeConstraint<Variant>("dtype")
1222 .TypeConstraint<int64_t>("Tindices"),
1223 ResourceScatterUpdateOp<CPUDevice, Variant, int64,
1224 scatter_op::UpdateOp::ASSIGN>)
1225#undef REGISTER_SCATTER_ARITHMETIC_GPU
1226#undef REGISTER_SCATTER_MINMAX_GPU
1227#undef REGISTER_SCATTER_UPDATE_GPU
1228#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1229
1230#undef REGISTER_SCATTER_ARITHMETIC
1231#undef REGISTER_SCATTER_ARITHMETIC_CPU
1232#undef REGISTER_SCATTER_MINMAX
1233#undef REGISTER_SCATTER_MINMAX_CPU
1234#undef REGISTER_SCATTER_UPDATE
1235#undef REGISTER_SCATTER_UPDATE_CPU
1236#undef REGISTER_SCATTER_KERNEL
1237#undef REGISTER_SCATTER_KERNEL_INDEX
1238
1239} // namespace tensorflow
1240