1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Our general strategy for preventing conflicts between concurrent |
17 | // reads and writes of resource variables is to: |
18 | // * For read operations, we: |
19 | // - acquire the variable's mutex (in "shared" mode); |
20 | // - make a (shallow) copy of the Tensor object, which increments |
21 | // the reference count on the variable's TensorBuffer; |
22 | // - release the variable's mutex; |
23 | // - use the copy of the Tensor object to do the read. |
24 | // * For write operations, we: |
25 | // - acquire the variable's mutex (in "exclusive" mode); |
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
28 | // - mutate the variable's Tensor; |
29 | // - and release the variable's mutex. |
30 | // This allows several read operations to all use the same |
// TensorBuffer without needing to copy. When it comes time to write,
// it will only make a copy if there is an outstanding read using the
33 | // buffer. Write operations are serialized by the variable's mutex. |
34 | // |
35 | // For sparse operations (scatter, gather, sparse optimizer updates), |
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
38 | // modifications to the above strategy: |
39 | // * For sparse reads (gather), we hold the variable's mutex (still in |
40 | // "shared" mode) for the duration of the whole read. This means |
//   that as long as you only do sparse read operations, no write will
//   see the reference count >1.
43 | // * For sparse write operations where the user explicitly specifies |
44 | // that they want to perform the write without locks held |
45 | // (use_locking=false), we never copy even if the variable's |
46 | // reference count is >1. |
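//
// A rough sketch of the dense protocol above, using hypothetical helper
// functions (the real kernels below go through OpKernelContext and helpers
// such as PrepareToUpdateVariable):
//
//   Tensor ReadDense(Var* var) {
//     tf_shared_lock l(*var->mu());
//     return *var->tensor();  // Shallow copy; ref-counts the TensorBuffer.
//   }
//
//   void AddToDense(Var* var, const Tensor& delta) {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // An outstanding read still holds the buffer; deep-copy first.
//       *var->tensor() = tensor::DeepCopy(*var->tensor());
//     }
//     var->tensor()->flat<float>() += delta.flat<float>();  // In place.
//   }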
47 | |
48 | #define EIGEN_USE_THREADS |
49 | |
50 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
51 | #define EIGEN_USE_GPU |
52 | #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" |
53 | #include "tensorflow/core/platform/stream_executor.h" |
54 | #endif |
55 | |
56 | #include <memory> |
57 | #include <type_traits> |
58 | #include <vector> |
59 | |
60 | #include "absl/strings/str_join.h" |
61 | #include "tensorflow/core/common_runtime/device.h" |
62 | #include "tensorflow/core/framework/bounds_check.h" |
63 | #include "tensorflow/core/framework/op_kernel.h" |
64 | #include "tensorflow/core/framework/register_types.h" |
65 | #include "tensorflow/core/framework/resource_mgr.h" |
66 | #include "tensorflow/core/framework/tensor_shape.h" |
67 | #include "tensorflow/core/framework/tensor_types.h" |
68 | #include "tensorflow/core/framework/variant_op_registry.h" |
69 | #include "tensorflow/core/kernels/dense_update_functor.h" |
70 | #include "tensorflow/core/kernels/gather_functor.h" |
71 | #include "tensorflow/core/kernels/gather_nd_op.h" |
72 | #include "tensorflow/core/kernels/resource_variable_ops.h" |
73 | #include "tensorflow/core/kernels/resource_variable_util.h" |
74 | #include "tensorflow/core/kernels/scatter_functor.h" |
75 | #include "tensorflow/core/kernels/training_op_helpers.h" |
76 | #include "tensorflow/core/kernels/variable_ops.h" |
77 | #include "tensorflow/core/lib/core/errors.h" |
78 | #include "tensorflow/core/lib/core/refcount.h" |
79 | #include "tensorflow/core/platform/casts.h" |
80 | #include "tensorflow/core/platform/mem.h" |
81 | #include "tensorflow/core/platform/mutex.h" |
82 | #include "tensorflow/core/platform/types.h" |
83 | #include "tensorflow/core/util/determinism.h" |
84 | #include "tensorflow/core/util/util.h" |
85 | |
86 | namespace tensorflow { |
87 | |
88 | REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp" ).Device(DEVICE_CPU), |
89 | ResourceHandlesOp<Var>); |
90 | |
91 | ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { |
92 | OP_REQUIRES_OK(c, c->GetAttr("dtype" , &dtype_)); |
93 | } |
94 | |
95 | namespace { |
96 | |
97 | Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { |
98 | Tensor* output; |
99 | Notification n; |
100 | Status status; |
101 | AllocatorAttributes attr; |
102 | if (t->dtype() == DT_VARIANT) { |
103 | attr.set_on_host(true); |
104 | } |
105 | TF_RETURN_IF_ERROR( |
106 | ctx->allocate_output(output_idx, t->shape(), &output, attr)); |
107 | if (t->dtype() == DT_VARIANT) { |
108 | output->flat<Variant>() = t->flat<Variant>(); |
109 | } else if (ctx->op_device_context() != nullptr) { |
110 | // TODO(apassos): remove the down_cast by just returning Device* from |
111 | // OpKernelContext |
112 | Device* device = down_cast<Device*>(ctx->device()); |
113 | ctx->op_device_context()->CopyTensorInSameDevice( |
114 | t, device, output, [&n, &status](const Status& s) { |
115 | status = s; |
116 | n.Notify(); |
117 | }); |
118 | n.WaitForNotification(); |
119 | return status; |
120 | } else { |
121 | switch (t->dtype()) { |
122 | #define HANDLER(type) \ |
123 | case DataTypeToEnum<type>::value: \ |
124 | output->flat<type>() = t->flat<type>(); \ |
125 | break; |
126 | TF_CALL_ALL_TYPES(HANDLER); |
127 | #undef HANDLER |
128 | default: |
129 | return errors::Internal("Unsupported dtype" , t->dtype()); |
130 | } |
131 | } |
132 | return OkStatus(); |
133 | } |
134 | |
135 | } // namespace |
136 | |
137 | void ReadVariableOp::Compute(OpKernelContext* ctx) { |
138 | core::RefCountPtr<Var> variable; |
139 | const ResourceHandle& handle = HandleFromInput(ctx, 0); |
140 | const auto status = LookupResource(ctx, handle, &variable); |
141 | OP_REQUIRES(ctx, status.ok(), |
142 | errors::FailedPrecondition( |
143 | "Could not find variable " , handle.name(), ". " , |
144 | "This could mean that the variable has been deleted. " , |
145 | "In TF1, it can also mean the variable is uninitialized. " , |
146 | "Debug info: container=" , handle.container(), |
147 | ", status error message=" , status.error_message())); |
148 | |
149 | tf_shared_lock ml(*variable->mu()); |
150 | // We're acquiring a reference to the underlying buffer while |
151 | // holding a shared lock to guarantee ordering of reads and |
152 | // writes when in copy-on-write mode. |
153 | const Tensor* t = variable->tensor(); |
154 | if (!variable->copy_on_read_mode.load()) { |
155 | OP_REQUIRES( |
156 | ctx, dtype_ == t->dtype(), |
157 | errors::InvalidArgument( |
158 | "Trying to read variable with wrong dtype. Expected " , |
159 | DataTypeString(dtype_), " got " , DataTypeString(t->dtype()))); |
160 | ctx->set_output(0, *t); |
161 | } else { |
162 | OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t)); |
163 | } |
164 | } |
165 | |
166 | ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) { |
167 | int n; |
168 | OP_REQUIRES_OK(c, c->GetAttr("N" , &n)); |
169 | OP_REQUIRES_OK(c, c->GetAttr("dtypes" , &dtypes_)); |
170 | OP_REQUIRES(c, n == dtypes_.size(), |
171 | errors::InvalidArgument( |
172 | "Mismatched number of arguments to ReadVariablesOp (" , n, |
173 | " vs. " , dtypes_.size(), ")" )); |
174 | } |
175 | |
176 | void ReadVariablesOp::Compute(OpKernelContext* ctx) { |
177 | std::vector<core::RefCountPtr<Var>> variables(dtypes_.size()); |
178 | std::vector<const ResourceHandle*> handles(dtypes_.size()); |
179 | for (size_t i = 0; i < dtypes_.size(); ++i) { |
180 | handles[i] = &HandleFromInput(ctx, i); |
181 | } |
182 | |
183 | OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables)); |
184 | |
185 | std::vector<string> uninitialized_vars; |
186 | for (int64_t i = 0; i < variables.size(); i++) { |
187 | if (variables[i] == nullptr) { |
188 | uninitialized_vars.push_back(handles[i]->name()); |
189 | } |
190 | } |
191 | |
192 | OP_REQUIRES(ctx, uninitialized_vars.empty(), |
193 | errors::FailedPrecondition( |
194 | "In ReadVariablesOp the following variables were " |
195 | "found uninitialized: " , |
196 | absl::StrJoin(uninitialized_vars, ", " ))); |
197 | |
198 | for (size_t i = 0; i < dtypes_.size(); ++i) { |
199 | // We're acquiring a reference to the underlying buffer while |
200 | // holding a shared lock to guarantee ordering of reads and |
201 | // writes. |
202 | tf_shared_lock ml(*variables[i]->mu()); |
203 | OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(), |
204 | errors::InvalidArgument( |
205 | "Trying to read variable " , handles[i]->name(), |
206 | " from Container: " , handles[i]->container(), |
207 | " with wrong dtype. Expected " , DataTypeString(dtypes_[i]), |
208 | " got " , DataTypeString(variables[i]->tensor()->dtype()))); |
209 | if (variables[i]->copy_on_read_mode.load()) { |
210 | OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor())); |
211 | } else { |
212 | const Tensor& t = *variables[i]->tensor(); |
213 | ctx->set_output(i, t); |
214 | } |
215 | } |
216 | } |
217 | |
218 | REGISTER_KERNEL_BUILDER(Name("ReadVariableOp" ).Device(DEVICE_CPU), |
219 | ReadVariableOp); |
220 | REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp" ).Device(DEVICE_CPU), |
221 | ReadVariablesOp); |
222 | |
223 | REGISTER_KERNEL_BUILDER( |
224 | Name("ReadVariableOp" ).Device(DEVICE_DEFAULT).HostMemory("resource" ), |
225 | ReadVariableOp); |
226 | REGISTER_KERNEL_BUILDER( |
227 | Name("_ReadVariablesOp" ).Device(DEVICE_DEFAULT).HostMemory("resources" ), |
228 | ReadVariablesOp); |
229 | |
230 | VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) { |
231 | OP_REQUIRES_OK(context, context->GetAttr("container" , &container_)); |
232 | OP_REQUIRES_OK(context, context->GetAttr("shared_name" , &name_)); |
233 | |
234 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_and_shape_.dtype)); |
235 | OP_REQUIRES_OK(context, context->GetAttr("shape" , &dtype_and_shape_.shape)); |
236 | |
237 | is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME; |
238 | |
239 | // Use const_tensor_ if the variable is non-anonymous. |
240 | if (!is_anonymous_) { |
241 | AllocatorAttributes attr; |
242 | attr.set_on_host(true); |
243 | OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}), |
244 | &const_tensor_, attr)); |
245 | const_tensor_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>( |
246 | context, container_, name_, |
247 | std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_}); |
248 | } |
249 | } |
250 | |
251 | void VarHandleOp::Compute(OpKernelContext* ctx) { |
252 | if (is_anonymous_) { |
253 | Var* resource = new Var(dtype_and_shape_.dtype); |
254 | ResourceMgr* mgr = ctx->resource_manager(); |
255 | ResourceHandle handle = ResourceHandle::MakeRefCountingHandle<Var>( |
256 | resource, ctx->device()->name(), |
257 | std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_}, |
258 | ctx->stack_trace()); |
259 | // TODO(b/203901837): See if we can abolish all code paths that lookup |
260 | // anonymous variables and then stop publishing them to the manager. |
261 | OP_REQUIRES_OK(ctx, mgr->CreateUnowned<Var>(handle.container(), |
262 | handle.name(), resource)); |
263 | |
264 | AllocatorAttributes attr; |
265 | attr.set_on_host(true); |
266 | Tensor tensor; |
267 | OP_REQUIRES_OK( |
268 | ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr)); |
269 | |
270 | tensor.scalar<ResourceHandle>()() = std::move(handle); |
271 | |
272 | ctx->set_output(0, tensor); |
273 | } else { |
274 | ctx->set_output(0, const_tensor_); |
275 | } |
276 | } |
277 | |
278 | REGISTER_KERNEL_BUILDER(Name("VarHandleOp" ).Device(DEVICE_CPU), VarHandleOp); |
279 | |
280 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
281 | #define REGISTER_GPU_KERNELS(type) \ |
282 | namespace functor { \ |
283 | template <> \ |
284 | void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \ |
285 | const GPUDevice& d, typename TTypes<type>::Flat lhs, \ |
286 | typename TTypes<type>::ConstFlat rhs); \ |
287 | extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \ |
288 | } |
289 | |
290 | TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); |
291 | TF_CALL_bfloat16(REGISTER_GPU_KERNELS); |
292 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS); |
293 | TF_CALL_variant(REGISTER_GPU_KERNELS); |
294 | #undef REGISTER_GPU_KERNELS |
295 | |
296 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
297 | |
298 | #define REGISTER_DEFAULT_KERNELS(type) \ |
299 | REGISTER_KERNEL_BUILDER(Name("VarHandleOp") \ |
300 | .Device(DEVICE_DEFAULT) \ |
301 | .HostMemory("resource") \ |
302 | .TypeConstraint<type>("dtype"), \ |
303 | VarHandleOp) |
304 | TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS); |
305 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_DEFAULT_KERNELS); |
306 | TF_CALL_bfloat16(REGISTER_DEFAULT_KERNELS); |
307 | TF_CALL_variant(REGISTER_DEFAULT_KERNELS); |
308 | #undef REGISTER_DEFAULT_KERNELS |
309 | |
310 | REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp" ) |
311 | .Device(DEVICE_DEFAULT) |
312 | .HostMemory("resources" ) |
313 | .TypeConstraint("dtypes" , |
314 | {DT_INT64, DT_COMPLEX64, |
315 | DT_COMPLEX128, DT_HALF, DT_FLOAT, |
316 | DT_DOUBLE, DT_BOOL, DT_VARIANT}), |
317 | ResourceHandlesOp<Var>); |
318 | |
319 | REGISTER_KERNEL_BUILDER( |
320 | Name("VariableShape" ).Device(DEVICE_CPU).TypeConstraint<int32>("out_type" ), |
321 | VariableShapeOp<int32>); |
322 | REGISTER_KERNEL_BUILDER(Name("VariableShape" ) |
323 | .Device(DEVICE_CPU) |
324 | .TypeConstraint<int64_t>("out_type" ), |
325 | VariableShapeOp<int64_t>); |
326 | |
327 | REGISTER_KERNEL_BUILDER(Name("VariableShape" ) |
328 | .Device(DEVICE_DEFAULT) |
329 | .TypeConstraint<int32>("out_type" ) |
330 | .HostMemory("output" ) |
331 | .HostMemory("input" ), |
332 | VariableShapeOp<int32>); |
333 | REGISTER_KERNEL_BUILDER(Name("VariableShape" ) |
334 | .Device(DEVICE_DEFAULT) |
335 | .TypeConstraint<int64_t>("out_type" ) |
336 | .HostMemory("output" ) |
337 | .HostMemory("input" ), |
338 | VariableShapeOp<int64_t>); |
339 | |
340 | DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx) |
341 | : OpKernel(ctx) { |
342 | OP_REQUIRES_OK(ctx, |
343 | ctx->GetAttr("ignore_lookup_error" , &ignore_lookup_error_)); |
344 | } |
345 | |
346 | void DestroyResourceOp::Compute(OpKernelContext* ctx) { |
347 | const ResourceHandle& p = HandleFromInput(ctx, 0); |
348 | Status status = DeleteResource(ctx, p); |
349 | if (ignore_lookup_error_ && errors::IsNotFound(status)) { |
350 | return; |
351 | } |
352 | OP_REQUIRES_OK(ctx, status); |
353 | } |
354 | |
355 | REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp" ).Device(DEVICE_CPU), |
356 | DestroyResourceOp); |
357 | REGISTER_KERNEL_BUILDER( |
358 | Name("DestroyResourceOp" ).Device(DEVICE_DEFAULT).HostMemory("resource" ), |
359 | DestroyResourceOp); |
360 | |
361 | void DisableCopyOnReadOp::Compute(OpKernelContext* ctx) { |
362 | core::RefCountPtr<Var> variable; |
363 | const ResourceHandle& handle = HandleFromInput(ctx, 0); |
364 | const auto status = LookupResource(ctx, handle, &variable); |
365 | OP_REQUIRES(ctx, status.ok(), |
366 | errors::FailedPrecondition( |
367 | "Could not find variable " , handle.name(), ". " , |
368 | "This could mean that the variable has been deleted. " , |
369 | "In TF1, it can also mean the variable is uninitialized. " , |
370 | "Debug info: container=" , handle.container(), |
371 | ", status error message=" , status.error_message())); |
372 | // If the variable is currently in copy-on-read mode, its refcount is 1 |
373 | if (variable->copy_on_read_mode.load()) { |
374 | // Obtain an exclusive lock on the variable and change the access mode |
375 | mutex_lock ml(*variable->mu()); |
376 | variable->copy_on_read_mode.store(false); |
377 | } |
378 | } |
379 | |
380 | REGISTER_KERNEL_BUILDER(Name("DisableCopyOnRead" ).Device(DEVICE_CPU), |
381 | DisableCopyOnReadOp); |
382 | REGISTER_KERNEL_BUILDER( |
383 | Name("DisableCopyOnRead" ).Device(DEVICE_DEFAULT).HostMemory("resource" ), |
384 | DisableCopyOnReadOp); |
385 | |
386 | template <typename Device, typename T> |
387 | class AssignVariableOp : public OpKernel { |
388 | public: |
389 | explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) { |
390 | OP_REQUIRES_OK(c, c->GetAttr("dtype" , &dtype_)); |
391 | if (!c->GetAttr("_grappler_relax_allocator_constraints" , |
392 | &relax_constraints_) |
393 | .ok()) { |
394 | relax_constraints_ = false; |
395 | } |
396 | if (c->HasAttr("validate_shape" )) { |
397 | OP_REQUIRES_OK(c, c->GetAttr("validate_shape" , &validate_shape_)); |
398 | } |
399 | } |
400 | |
401 | void Compute(OpKernelContext* context) override { |
402 | OP_REQUIRES(context, dtype_ == context->input(1).dtype(), |
403 | errors::InvalidArgument( |
404 | "Variable and value dtypes don't match; respectively, " , |
405 | DataTypeString(dtype_), " and " , |
406 | DataTypeString(context->input(1).dtype()))); |
407 | core::RefCountPtr<Var> variable; |
408 | const Tensor& value = context->input(1); |
409 | // Note: every resource-variable-manipulating op assumes copy-on-write |
410 | // semantics, and creates a copy of the variable's Tensor if its refcount is |
411 | // bigger than 1 when we try to modify it. This means we never need to copy |
412 | // the original tensor for AssignVariableOp; even if there are other live |
413 | // users of it we know none can modify it so this is always safe (even in |
414 | // esoteric cases where the same tensor is used to initialize multiple |
415 | // variables or the tensor is a constant this is safe, as future writes will |
416 | // trigger copies). |
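    // For example (a hypothetical scenario): if the same `value` tensor is
    // used to initialize two variables, both variables alias one TensorBuffer
    // afterwards, and the first mutating op on either variable then sees a
    // refcount > 1 and deep-copies before writing.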
417 | OP_REQUIRES_OK(context, LookupOrCreateResource<Var>( |
418 | context, HandleFromInput(context, 0), &variable, |
419 | [this, &value](Var** ptr) { |
420 | *ptr = new Var(dtype_); |
421 | *(*ptr)->tensor() = value; |
422 | (*ptr)->is_initialized = true; |
423 | return OkStatus(); |
424 | })); |
425 | mutex_lock ml(*variable->mu()); |
426 | // (variable->tensor()->dtype() == DT_INVALID && !variable->is_initialized) |
427 | // check below is to allow an XLA specific situation wherein update can |
428 | // happen first by the AssignVariableOp, |
429 | // in which case the variable is still uninitialized. |
430 | // When using TF-XLA, this scenario is possible when the execution uses the |
431 | // 'fallback' path (which essentially invokes Tensorflow ops via |
432 | // partitioned_call). |
433 | OP_REQUIRES(context, |
434 | (variable->tensor()->dtype() == DT_INVALID && |
435 | !variable->is_initialized) || |
436 | variable->tensor()->dtype() == dtype_, |
437 | errors::InvalidArgument( |
438 | "Trying to assign variable with wrong dtype. Expected " , |
439 | DataTypeString(variable->tensor()->dtype()), " got " , |
440 | DataTypeString(dtype_))); |
441 | if (validate_shape_) { |
442 | OP_REQUIRES( |
443 | context, |
444 | (!variable->is_initialized || |
445 | variable->tensor()->shape().IsSameSize(value.shape())), |
446 | errors::InvalidArgument( |
447 | "Trying to assign to variable with tensor with wrong shape." |
448 | " Expected " , |
449 | variable->tensor()->shape().DebugString(), " got " , |
450 | value.shape().DebugString())); |
451 | } |
452 | if (variable->copy_on_read_mode.load()) { |
453 | AllocatorAttributes attr; |
454 | attr.set_gpu_compatible(true); |
455 | attr.set_nic_compatible(true); |
456 | OP_REQUIRES_OK(context, |
457 | context->allocate_temp(value.dtype(), value.shape(), |
458 | variable->tensor(), attr)); |
459 | functor::DenseUpdate<Device, T, ASSIGN> copy_functor; |
460 | copy_functor(context->eigen_device<Device>(), |
461 | variable->tensor()->flat<T>(), value.flat<T>()); |
462 | } else { |
463 | *variable->tensor() = value; |
464 | } |
465 | variable->is_initialized = true; |
466 | } |
467 | |
468 | private: |
469 | DataType dtype_; |
470 | bool relax_constraints_; |
471 | bool validate_shape_ = false; |
472 | }; |
473 | |
474 | template <typename Device> |
475 | class AssignVariableOp<Device, Variant> : public OpKernel { |
476 | public: |
477 | explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) { |
478 | OP_REQUIRES_OK(c, c->GetAttr("dtype" , &dtype_)); |
479 | OP_REQUIRES(c, dtype_ == DT_VARIANT, |
480 | errors::Internal("Variant kernel called with dtype: " , |
481 | DataTypeString(dtype_))); |
482 | } |
483 | |
484 | void Compute(OpKernelContext* context) override { |
485 | const Tensor& value = context->input(1); |
486 | core::RefCountPtr<Var> variable; |
487 | OP_REQUIRES_OK(context, LookupOrCreateResource<Var>( |
488 | context, HandleFromInput(context, 0), &variable, |
489 | [](Var** ptr) { |
490 | // Created on host. |
491 | *ptr = new Var(DT_VARIANT); |
492 | return OkStatus(); |
493 | })); |
494 | |
495 | // For purposes of forwarding DT_VARIANT, we want the least |
496 | // restrictive attr; we already know the input is on host. |
497 | AllocatorAttributes attr; |
498 | |
    // Copying is unnecessary if we are the last user of the value
    // tensor: we can just adopt the input tensor's buffer instead.
501 | // Note that Variant objects themselves always reside on host. |
502 | // |
503 | // We nevertheless want to signal to the runtime that the tensor |
504 | // should reside in memory of the associated device, as Variant |
505 | // tensors may be marked as sitting on either CPU or GPU. This |
506 | // helps to elide one or more copies. |
507 | std::unique_ptr<Tensor> input_alias = context->forward_input( |
508 | 1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT, |
509 | value.shape(), |
510 | DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */, |
511 | attr); |
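    // If the input buffer cannot be adopted (e.g. it is still referenced
    // elsewhere), forward_input() returns nullptr and we fall through to the
    // copy path below.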
512 | |
513 | mutex_lock ml(*variable->mu()); |
514 | OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT, |
515 | errors::InvalidArgument( |
516 | "Trying to assign variable with wrong dtype. Expected " , |
517 | DataTypeString(variable->tensor()->dtype()), " got " , |
518 | DataTypeString(DT_VARIANT))); |
519 | variable->is_initialized = true; |
520 | *variable->tensor() = Tensor(DT_VARIANT, value.shape()); |
521 | |
522 | if (input_alias) { |
523 | *variable->tensor() = *input_alias; |
524 | return; |
525 | } |
526 | |
527 | // Need to copy, but maybe we can re-use variable's buffer? |
528 | if (!variable->tensor()->RefCountIsOne() || |
529 | !variable->tensor()->shape().IsSameSize(value.shape())) { |
530 | // Allocation of DT_VARIANT is always on host. |
531 | attr.set_on_host(true); |
532 | OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(), |
533 | variable->tensor(), attr)); |
534 | } |
535 | |
536 | const auto elements_in = value.flat<Variant>(); |
537 | auto elements_out = variable->tensor()->flat<Variant>(); |
538 | for (int64_t i = 0; i < elements_in.size(); ++i) { |
539 | elements_out(i) = elements_in(i); |
540 | } |
541 | } |
542 | |
543 | private: |
544 | DataType dtype_; |
545 | }; |
546 | |
547 | #define REGISTER_KERNELS(type) \ |
548 | REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ |
549 | .Device(DEVICE_CPU) \ |
550 | .TypeConstraint<type>("dtype"), \ |
551 | AssignVariableOp<Eigen::ThreadPoolDevice, type>); |
552 | |
553 | TF_CALL_ALL_TYPES(REGISTER_KERNELS); |
554 | TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); |
555 | #undef REGISTER_KERNELS |
556 | |
557 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
558 | #define REGISTER_GPU_KERNELS(type) \ |
559 | REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ |
560 | .Device(DEVICE_GPU) \ |
561 | .TypeConstraint<type>("dtype") \ |
562 | .HostMemory("resource"), \ |
563 | AssignVariableOp<GPUDevice, type>); |
564 | |
565 | TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); |
566 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS); |
567 | TF_CALL_bfloat16(REGISTER_GPU_KERNELS); |
568 | #undef REGISTER_GPU_KERNELS |
569 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
570 | |
571 | #define REGISTER_KERNELS(type) \ |
572 | REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ |
573 | .Device(DEVICE_DEFAULT) \ |
574 | .TypeConstraint<type>("dtype") \ |
575 | .HostMemory("resource"), \ |
576 | AssignVariableOp<CPUDevice, type>); |
577 | |
578 | TF_CALL_ALL_TYPES(REGISTER_KERNELS); |
579 | TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); |
580 | #undef REGISTER_KERNELS |
581 | |
582 | template <typename Device, typename T, DenseUpdateType Op> |
583 | class AssignUpdateVariableOp : public OpKernel { |
584 | public: |
585 | explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {} |
586 | |
587 | void Compute(OpKernelContext* context) override { |
588 | core::RefCountPtr<Var> variable; |
589 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), |
590 | &variable)); |
591 | |
592 | const Tensor& value = context->input(1); |
593 | // TODO(apassos): We could possibly avoid the copy done by |
594 | // PrepareToUpdateVariable() for commutative operations like Op == |
595 | // ADD if value's refcount was 1. |
596 | mutex_lock ml(*variable->mu()); |
597 | Tensor* var_tensor = variable->tensor(); |
598 | OP_REQUIRES_OK(context, ValidateAssignUpdateVariableOpShapes( |
599 | var_tensor->shape(), value.shape())); |
600 | OP_REQUIRES_OK( |
601 | context, PrepareToUpdateVariable<Device, T>( |
602 | context, var_tensor, variable->copy_on_read_mode.load())); |
603 | functor::DenseUpdate<Device, T, Op> update_functor; |
604 | update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(), |
605 | value.flat<T>()); |
606 | } |
607 | }; |
608 | |
609 | #define REGISTER_KERNELS(type) \ |
610 | REGISTER_KERNEL_BUILDER( \ |
611 | Name("AssignAddVariableOp") \ |
612 | .Device(DEVICE_CPU) \ |
613 | .TypeConstraint<type>("dtype"), \ |
614 | AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \ |
615 | REGISTER_KERNEL_BUILDER( \ |
616 | Name("AssignSubVariableOp") \ |
617 | .Device(DEVICE_CPU) \ |
618 | .TypeConstraint<type>("dtype"), \ |
619 | AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>); |
620 | |
621 | TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); |
622 | #undef REGISTER_KERNELS |
623 | |
624 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
625 | #define REGISTER_GPU_KERNELS(type) \ |
626 | REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \ |
627 | .Device(DEVICE_GPU) \ |
628 | .HostMemory("resource") \ |
629 | .TypeConstraint<type>("dtype"), \ |
630 | AssignUpdateVariableOp<GPUDevice, type, ADD>); \ |
631 | REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp") \ |
632 | .Device(DEVICE_GPU) \ |
633 | .HostMemory("resource") \ |
634 | .TypeConstraint<type>("dtype"), \ |
635 | AssignUpdateVariableOp<GPUDevice, type, SUB>); |
636 | |
637 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); |
638 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS); |
639 | #undef REGISTER_GPU_KERNELS |
640 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
641 | |
642 | class VarIsInitializedOp : public OpKernel { |
643 | public: |
644 | explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {} |
645 | |
646 | void Compute(OpKernelContext* context) override { |
647 | Tensor* output = nullptr; |
648 | OP_REQUIRES_OK(context, |
649 | context->allocate_output(0, TensorShape({}), &output)); |
650 | auto output_tensor = output->tensor<bool, 0>(); |
651 | core::RefCountPtr<Var> variable; |
652 | Status s = LookupResource(context, HandleFromInput(context, 0), &variable); |
653 | if (!s.ok()) { |
654 | output_tensor() = false; |
655 | return; |
656 | } |
657 | mutex_lock ml(*variable->mu()); |
658 | output_tensor() = variable->is_initialized; |
659 | } |
660 | }; |
661 | |
662 | REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp" ).Device(DEVICE_CPU), |
663 | VarIsInitializedOp); |
664 | |
665 | REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp" ) |
666 | .Device(DEVICE_DEFAULT) |
667 | .HostMemory("resource" ) |
668 | .HostMemory("is_initialized" ), |
669 | VarIsInitializedOp); |
670 | |
671 | template <typename Device, typename T, typename Index> |
672 | class ResourceGatherOp : public OpKernel { |
673 | public: |
674 | explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) { |
675 | OP_REQUIRES_OK(c, c->GetAttr("batch_dims" , &batch_dims_)); |
676 | } |
677 | |
678 | void Compute(OpKernelContext* c) override { |
679 | core::RefCountPtr<Var> v; |
680 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
681 | OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); |
682 | // NOTE: We hold the lock for the whole gather operation instead |
683 | // of increasing the reference count of v->tensor() to avoid a |
684 | // situation where a write to the same variable will see a |
685 | // reference count greater than one and make a copy of the |
686 | // (potentially very large) tensor buffer. |
687 | tf_shared_lock ml(*v->mu()); |
688 | const Tensor& params = *v->tensor(); |
689 | const Tensor& indices = c->input(1); |
690 | OP_REQUIRES( |
691 | c, TensorShapeUtils::IsVectorOrHigher(params.shape()), |
692 | errors::InvalidArgument("params must be at least 1 dimensional" )); |
693 | OP_REQUIRES( |
694 | c, params.shape().dims() >= batch_dims_, |
695 | errors::InvalidArgument("params must have at least " , batch_dims_, |
696 | " (batch_dims) dimensions but it has shape " , |
697 | params.shape().DebugString())); |
698 | |
699 | // Check that we have enough index space |
700 | const int64_t N = indices.NumElements(); |
701 | OP_REQUIRES( |
702 | c, params.dim_size(0) <= std::numeric_limits<Index>::max(), |
703 | errors::InvalidArgument("params.shape[0] too large for " , |
704 | DataTypeString(DataTypeToEnum<Index>::v()), |
705 | " indexing: " , params.dim_size(0), " > " , |
706 | std::numeric_limits<Index>::max())); |
707 | |
708 | // The result shape is params.shape[:batch_dims] + |
709 | // indices.shape[batch_dims:] + params.shape[batch_dims+1:]. |
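    // For example, with batch_dims_ = 1, params.shape = [2, 4, 5] and
    // indices.shape = [2, 3], the result shape is [2] + [3] + [5] = [2, 3, 5].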
710 | TensorShape result_shape; |
711 | for (int i = 0; i < batch_dims_; ++i) { |
712 | result_shape.AddDim(params.dim_size(i)); |
713 | } |
714 | for (int i = batch_dims_; i < indices.dims(); ++i) { |
715 | result_shape.AddDim(indices.dim_size(i)); |
716 | } |
717 | for (int i = batch_dims_ + 1; i < params.dims(); ++i) { |
718 | result_shape.AddDim(params.dim_size(i)); |
719 | } |
720 | |
721 | Tensor* out = nullptr; |
722 | Tensor tmp; |
723 | if (params.dtype() == DT_VARIANT) { |
724 | tmp = Tensor(DT_VARIANT, result_shape); |
725 | c->set_output(0, tmp); |
726 | out = &tmp; |
727 | } else { |
728 | OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); |
729 | } |
730 | |
731 | if (N > 0) { |
732 | Tensor tmp_indices; |
733 | |
734 | // Points to the original or updated (if batch_dims is set) indices. |
735 | const Tensor* op_indices = &indices; |
736 | if (batch_dims_ > 0) { |
737 | OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(), |
738 | &tmp_indices)); |
739 | functor::DenseUpdate<Device, Index, ASSIGN> copy_functor; |
740 | copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(), |
741 | indices.flat<Index>()); |
742 | |
743 | AddBatchOffsets(c, &tmp_indices, params); |
744 | if (!c->status().ok()) return; |
745 | op_indices = &tmp_indices; |
746 | } |
747 | |
748 | int64_t gather_dim_size = 1; |
749 | for (int idx = 0; idx <= batch_dims_; ++idx) { |
750 | gather_dim_size *= params.dim_size(idx); |
751 | } |
752 | int64_t inner_size = 1; |
753 | for (int i = batch_dims_ + 1; i < params.dims(); ++i) { |
754 | inner_size *= params.dim_size(i); |
755 | } |
756 | auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size}); |
757 | const auto indices_flat = op_indices->flat<Index>(); |
758 | auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N}); |
759 | |
760 | functor::GatherFunctor<Device, T, Index> functor; |
761 | int64_t bad_i = functor(c, params_flat, indices_flat, out_flat); |
762 | |
763 | OP_REQUIRES( |
764 | c, bad_i < 0, |
765 | errors::InvalidArgument( |
766 | "indices" , SliceDebugString(indices.shape(), bad_i), " = " , |
767 | indices_flat(bad_i), " is not in [0, " , params.dim_size(0), ")" )); |
768 | } |
769 | } |
770 | |
771 | private: |
772 | // Add the batch offset derived from params to each batch of indices. |
773 | // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]] |
774 | // If indexing into a params dimension of size 4, then the indices will become |
775 | // [0, 1, 2, 4, 5, 6] |
776 | void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices, |
777 | const Tensor& params) { |
778 | int64_t batch_size = 1; // The size of all batch dimensions. |
779 | for (int idx = 0; idx < batch_dims_; ++idx) { |
780 | batch_size *= params.dim_size(idx); |
781 | } |
782 | OP_REQUIRES( |
783 | ctx, batch_size != 0, |
784 | errors::InvalidArgument( |
785 | "Inner size of indices would result in batch_size of 0 and a " , |
786 | "division by 0 in the implementation. This is illegal" )); |
787 | |
788 | auto indices_flat = indices->flat<Index>(); |
789 | int64_t const index_inner_size = indices->NumElements() / batch_size; |
790 | int64_t const batch_offset = params.dim_size(batch_dims_); |
791 | for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size; |
792 | ++batch_idx) { |
793 | for (int64_t idx = 0; idx < index_inner_size; ++idx) { |
794 | indices_flat(dest_idx++) += batch_offset * batch_idx; |
795 | } |
796 | } |
797 | } |
798 | |
799 | int32 batch_dims_ = 0; |
800 | }; |
801 | |
802 | #define REGISTER_GATHER_FULL(dev, type, index_type) \ |
803 | REGISTER_KERNEL_BUILDER(Name("ResourceGather") \ |
804 | .Device(DEVICE_##dev) \ |
805 | .HostMemory("resource") \ |
806 | .TypeConstraint<type>("dtype") \ |
807 | .TypeConstraint<index_type>("Tindices"), \ |
808 | ResourceGatherOp<dev##Device, type, index_type>) |
809 | |
810 | #define REGISTER_GATHER_ALL_INDICES(dev, type) \ |
811 | REGISTER_GATHER_FULL(dev, type, int32); \ |
812 | REGISTER_GATHER_FULL(dev, type, int64_t) |
813 | |
814 | #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type) |
815 | |
816 | // Registration of the CPU implementations. |
817 | TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); |
818 | TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); |
819 | |
820 | // Registers GPU kernels. |
821 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
822 | #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) |
823 | |
824 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GATHER_GPU); |
825 | TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU); |
826 | |
827 | // Variant objects themselves sit on CPU, even if they contain data |
828 | // pointing to a device. |
829 | REGISTER_KERNEL_BUILDER(Name("ResourceGather" ) |
830 | .Device(DEVICE_DEFAULT) |
831 | .HostMemory("resource" ) |
832 | .HostMemory("indices" ) |
833 | .TypeConstraint<Variant>("dtype" ) |
834 | .TypeConstraint<int32>("Tindices" ), |
835 | ResourceGatherOp<CPUDevice, Variant, int32>) |
836 | REGISTER_KERNEL_BUILDER(Name("ResourceGather" ) |
837 | .Device(DEVICE_DEFAULT) |
838 | .HostMemory("resource" ) |
839 | .HostMemory("indices" ) |
840 | .TypeConstraint<Variant>("dtype" ) |
841 | .TypeConstraint<int64_t>("Tindices" ), |
842 | ResourceGatherOp<CPUDevice, Variant, int64>) |
843 | |
844 | #undef REGISTER_GATHER_GPU |
845 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
846 | |
847 | #undef REGISTER_GATHER_CPU |
848 | #undef REGISTER_GATHER_ALL_INDICES |
849 | #undef REGISTER_GATHER_FULL |
850 | |
851 | template <typename Device, typename T, typename Index> |
852 | class ResourceGatherNdOp : public OpKernel { |
853 | public: |
854 | explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {} |
855 | |
856 | void Compute(OpKernelContext* c) override { |
857 | core::RefCountPtr<Var> v; |
858 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
859 | OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); |
860 | // NOTE: We hold the lock for the whole gather operation instead |
861 | // of increasing the reference count of v->tensor() to avoid a |
862 | // situation where a write to the same variable will see a |
863 | // reference count greater than one and make a copy of the |
864 | // (potentially very large) tensor buffer. |
865 | tf_shared_lock ml(*v->mu()); |
866 | const Tensor& params = *v->tensor(); |
867 | const Tensor& indices = c->input(1); |
868 | |
869 | Tensor out; |
870 | OP_REQUIRES_OK( |
871 | c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out)); |
872 | c->set_output(0, out); |
873 | } |
874 | }; |
875 | |
876 | #define REGISTER_GATHER_ND_FULL(dev, type, index_type) \ |
877 | REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd") \ |
878 | .Device(DEVICE_##dev) \ |
879 | .HostMemory("resource") \ |
880 | .TypeConstraint<type>("dtype") \ |
881 | .TypeConstraint<index_type>("Tindices"), \ |
882 | ResourceGatherNdOp<dev##Device, type, index_type>) |
883 | |
884 | #define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \ |
885 | REGISTER_GATHER_ND_FULL(dev, type, int32); \ |
886 | REGISTER_GATHER_ND_FULL(dev, type, int64_t) |
887 | |
888 | #define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type) |
889 | |
890 | // Registration of the CPU implementations. |
891 | TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); |
892 | |
893 | // Registers GPU kernels. |
894 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
895 | #define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type) |
896 | |
897 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU); |
898 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GATHER_ND_GPU); |
899 | |
900 | #undef REGISTER_GATHER_ND_GPU |
901 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
902 | |
903 | #undef REGISTER_GATHER_ND_CPU |
904 | #undef REGISTER_GATHER_ND_ALL_INDICES |
905 | #undef REGISTER_GATHER_ND_FULL |
906 | |
907 | namespace { |
908 | |
909 | template <typename Device> |
910 | bool isCPUDevice() { |
911 | return false; |
912 | } |
913 | |
914 | template <> |
915 | bool isCPUDevice<CPUDevice>() { |
916 | return true; |
917 | } |
918 | |
919 | template <typename T> |
920 | bool ValidateInput(const Tensor& updates) { |
921 | const auto updates_flat = updates.flat<T>(); |
922 | for (int i = 0; i < updates.NumElements(); ++i) { |
923 | if (updates_flat(i) == T{}) return false; |
924 | } |
925 | return true; |
926 | } |
927 | |
928 | template <> |
929 | bool ValidateInput<Variant>(const Tensor& updates) { |
930 | return true; |
931 | } |
932 | |
933 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
934 | Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices, |
935 | const Tensor& updates, Index num_indices); |
936 | |
937 | template <typename T, typename Index, scatter_op::UpdateOp Op> |
938 | Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices, |
939 | const Tensor& updates, Index num_indices); |
940 | |
941 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
942 | |
943 | template <typename T> |
944 | Status CopyTensorToHost(OpKernelContext* c, const Tensor& device_tensor, |
945 | Tensor* host_tensor) { |
946 | AllocatorAttributes alloc_attr; |
947 | alloc_attr.set_on_host(true); |
948 | alloc_attr.set_gpu_compatible(true); |
949 | auto stream = c->op_device_context()->stream(); |
950 | TF_RETURN_IF_ERROR(c->allocate_temp( |
951 | device_tensor.dtype(), device_tensor.shape(), host_tensor, alloc_attr)); |
952 | se::DeviceMemoryBase device_ptr( |
953 | const_cast<Tensor&>(device_tensor).flat<T>().data(), |
954 | device_tensor.flat<T>().size() * sizeof(T)); |
955 | stream->ThenMemcpy(host_tensor->flat<T>().data(), device_ptr, |
956 | device_tensor.NumElements() * sizeof(T)); |
  if (!stream->ok()) {
    return errors::Internal("Failed to copy indices to host");
959 | } |
960 | return OkStatus(); |
961 | } |
962 | |
963 | // Copies inputs to the CPU, runs DoScatter on the CPU, then copies output |
964 | // back to GPU. This is useful because the CPU implementation is deterministic |
965 | // and the GPU implementation is not. Tensor inputs to this function must be on |
966 | // the GPU. |
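// Determinism is requested e.g. via the Python API
// tf.config.experimental.enable_op_determinism(), which makes
// OpDeterminismRequired() below return true.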
967 | template <typename T, typename Index, scatter_op::UpdateOp Op> |
968 | Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices, |
969 | const Tensor& updates, Index num_indices) { |
970 | if (!DataTypeCanUseMemcpy(params->dtype())) { |
971 | return errors::Unimplemented( |
972 | "GPU Scatter ops for dtype " , DataTypeString(params->dtype()), |
973 | " do not yet have a deterministic implementation" ); |
974 | } |
975 | auto stream = c->op_device_context()->stream(); |
976 | |
977 | Tensor host_indices; |
978 | TF_RETURN_IF_ERROR(CopyTensorToHost<Index>(c, indices, &host_indices)); |
979 | Tensor host_updates; |
980 | TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, updates, &host_updates)); |
981 | Tensor host_params; |
982 | TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, *params, &host_params)); |
983 | |
984 | TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); |
985 | TF_RETURN_IF_ERROR(DoScatter<CPUDevice, T, Index, Op>( |
986 | c, &host_params, host_indices, host_updates, num_indices)); |
987 | |
988 | // Copy 'host_params' to device. |
989 | se::DeviceMemoryBase params_ptr(params->flat<T>().data(), |
990 | params->flat<T>().size() * sizeof(T)); |
991 | stream->ThenMemcpy(¶ms_ptr, host_params.flat<T>().data(), |
992 | host_params.NumElements() * sizeof(T)); |
  if (!stream->ok()) {
    return errors::Internal("Failed to copy params to device");
995 | } |
996 | // Deallocate host_params' buffer once the host-to-device copy is complete. |
997 | // host_params is captured by value in the lambda so that its buffer is only |
998 | // destructed once the lambda is destructed. |
999 | c->device()->tensorflow_accelerator_device_info()->event_mgr->ThenExecute( |
1000 | stream, [host_params] {}); |
1001 | return OkStatus(); |
1002 | } |
1003 | |
1004 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1005 | |
1006 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
1007 | Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices, |
1008 | const Tensor& updates, Index num_indices) { |
1009 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1010 | if (std::is_same<Device, GPUDevice>::value && |
1011 | tensorflow::OpDeterminismRequired()) { |
1012 | return DoScatterOnCpu<T, Index, op>(c, params, indices, updates, |
1013 | num_indices); |
1014 | } |
1015 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1016 | |
1017 | // Run on the CPU for integer types, since the GPU implementation uses |
1018 | // atomics, which are not supported for all integer types. |
1019 | if constexpr (std::is_same<Device, GPUDevice>::value && |
1020 | std::is_integral<T>::value) { |
1021 | return DoScatterOnCpu<T, Index, op>(c, params, indices, updates, |
1022 | num_indices); |
1023 | } else { |
1024 | auto indices_flat = indices.flat<Index>(); |
1025 | auto params_flat = params->flat_outer_dims<T>(); |
1026 | if (TensorShapeUtils::IsScalar(updates.shape())) { |
1027 | const auto update = updates.scalar<T>(); |
1028 | |
1029 | functor::ScatterScalarFunctor<Device, T, Index, op> functor; |
1030 | const Index bad_i = functor(c, c->template eigen_device<Device>(), |
1031 | params_flat, update, indices_flat); |
1032 | if (bad_i >= 0) { |
1033 | return errors::InvalidArgument( |
1034 | "indices" , SliceDebugString(indices.shape(), bad_i), " = " , |
1035 | indices_flat(bad_i), " is not in [0, " , params->dim_size(0), ")" ); |
1036 | } |
1037 | } else { |
1038 | int64_t num_updates = updates.NumElements(); |
1039 | if (!TensorShapeUtils::StartsWith(updates.shape(), indices.shape())) { |
1040 | return errors::InvalidArgument( |
1041 | "The shape of indices (" , indices.shape().DebugString(), |
1042 | ") must be a prefix of the shape of updates (" , |
1043 | updates.shape().DebugString(), ")" ); |
1044 | } |
1045 | auto updates_flat = |
1046 | updates.shaped<T, 2>({num_indices, num_updates / num_indices}); |
1047 | functor::ScatterFunctor<Device, T, Index, op> functor; |
1048 | const Index bad_i = functor(c, c->template eigen_device<Device>(), |
1049 | params_flat, updates_flat, indices_flat); |
1050 | if (bad_i >= 0) { |
1051 | return errors::InvalidArgument( |
1052 | "indices" , SliceDebugString(indices.shape(), bad_i), " = " , |
1053 | indices_flat(bad_i), " is not in [0, " , params->dim_size(0), ")" ); |
1054 | } |
1055 | } |
1056 | } |
1057 | return OkStatus(); |
1058 | } |
1059 | |
1060 | } // namespace |
1061 | |
1062 | template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> |
1063 | class ResourceScatterUpdateOp : public OpKernel { |
1064 | public: |
1065 | explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { |
1066 | // We use the same kernel for many operations. |
1067 | // Each operation has a different set of attributes defined in its nodes. |
1068 | Status s = c->GetAttr("use_locking" , &use_exclusive_lock_); |
1069 | if (!s.ok()) { |
1070 | use_exclusive_lock_ = false; |
1071 | } |
1072 | } |
1073 | |
1074 | void Compute(OpKernelContext* c) override { |
1075 | core::RefCountPtr<Var> v; |
1076 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
1077 | OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); |
1078 | const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE || |
1079 | c->input_dtype(0) == DT_STRING || |
1080 | c->input_dtype(0) == DT_VARIANT; |
1081 | if (is_non_pod_dtype || use_exclusive_lock_) { |
1082 | mutex_lock ml(*v->mu()); |
1083 | DoCompute(c); |
1084 | } else { |
      // For POD dtypes, we can safely run the update under a shared lock
      // instead of an exclusive one.
1086 | tf_shared_lock ml(*v->mu()); |
1087 | DoCompute(c); |
1088 | } |
1089 | } |
1090 | |
1091 | private: |
1092 | bool use_exclusive_lock_; |
1093 | |
1094 | void DoCompute(OpKernelContext* c) { |
1095 | core::RefCountPtr<Var> v; |
1096 | OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); |
1097 | Tensor* params = v->tensor(); |
1098 | const Tensor& indices = c->input(1); |
1099 | const Tensor& updates = c->input(2); |
1100 | |
1101 | // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:]) |
1102 | OP_REQUIRES(c, |
1103 | updates.dims() == 0 || |
1104 | updates.dims() == indices.dims() + params->dims() - 1, |
1105 | errors::InvalidArgument( |
1106 | "Must have updates.shape = indices.shape + " |
1107 | "params.shape[1:] or updates.shape = [], got " , |
1108 | "updates.shape " , updates.shape().DebugString(), |
1109 | ", indices.shape " , indices.shape().DebugString(), |
1110 | ", params.shape " , params->shape().DebugString())); |
1111 | |
1112 | // Check that we have enough index space |
1113 | const int64_t N_big = indices.NumElements(); |
1114 | OP_REQUIRES( |
1115 | c, N_big <= std::numeric_limits<Index>::max(), |
1116 | errors::InvalidArgument("indices has too many elements for " , |
1117 | DataTypeString(DataTypeToEnum<Index>::v()), |
1118 | " indexing: " , N_big, " > " , |
1119 | std::numeric_limits<Index>::max())); |
1120 | const Index N = static_cast<Index>(N_big); |
1121 | OP_REQUIRES( |
1122 | c, params->dim_size(0) <= std::numeric_limits<Index>::max(), |
1123 | errors::InvalidArgument("params.shape[0] too large for " , |
1124 | DataTypeString(DataTypeToEnum<Index>::v()), |
1125 | " indexing: " , params->dim_size(0), " > " , |
1126 | std::numeric_limits<Index>::max())); |
1127 | |
1128 | // Prevent division by 0 |
1129 | if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) { |
1130 | OP_REQUIRES(c, ValidateInput<T>(updates), |
1131 | errors::InvalidArgument("updates must not contain 0" )); |
1132 | } |
1133 | |
1134 | if (N > 0) { |
1135 | OP_REQUIRES_OK( |
1136 | c, DoScatter<Device, T, Index, op>(c, params, indices, updates, N)); |
1137 | } |
1138 | } |
1139 | }; |
1140 | |
1141 | #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \ |
1142 | REGISTER_KERNEL_BUILDER( \ |
1143 | Name(name) \ |
1144 | .Device(DEVICE_##dev) \ |
1145 | .HostMemory("resource") \ |
1146 | .TypeConstraint<type>("dtype") \ |
1147 | .TypeConstraint<index_type>("Tindices"), \ |
1148 | ResourceScatterUpdateOp<dev##Device, type, index_type, op>) |
1149 | |
1150 | #define REGISTER_SCATTER_KERNEL(type, dev, name, op) \ |
1151 | REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ |
1152 | REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op); |
1153 | |
1154 | #define REGISTER_SCATTER_ARITHMETIC(type, dev) \ |
1155 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ |
1156 | scatter_op::UpdateOp::ADD); \ |
1157 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \ |
1158 | scatter_op::UpdateOp::SUB); \ |
1159 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \ |
1160 | scatter_op::UpdateOp::MUL); \ |
1161 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \ |
1162 | scatter_op::UpdateOp::DIV); |
1163 | |
1164 | #define REGISTER_SCATTER_UPDATE(type, dev) \ |
1165 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \ |
1166 | scatter_op::UpdateOp::ASSIGN); |
1167 | |
1168 | #define REGISTER_SCATTER_MINMAX(type, dev) \ |
1169 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \ |
1170 | scatter_op::UpdateOp::MIN); \ |
1171 | REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \ |
1172 | scatter_op::UpdateOp::MAX); |
1173 | |
1174 | // Registers CPU kernels. |
1175 | #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ |
1176 | REGISTER_SCATTER_ARITHMETIC(type, CPU); |
1177 | #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); |
1178 | #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); |
1179 | |
1180 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); |
1181 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); |
1182 | TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_CPU); |
1183 | TF_CALL_tstring(REGISTER_SCATTER_UPDATE_CPU); |
1184 | TF_CALL_bool(REGISTER_SCATTER_UPDATE_CPU); |
1185 | TF_CALL_variant(REGISTER_SCATTER_UPDATE_CPU); |
1186 | |
1187 | // Registers GPU kernels. |
1188 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1189 | #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ |
1190 | REGISTER_SCATTER_ARITHMETIC(type, GPU); |
1191 | #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); |
1192 | |
1193 | #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); |
1194 | |
1195 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU); |
1196 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_ARITHMETIC_GPU); |
1197 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU); |
1198 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_MINMAX_GPU); |
1199 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU); |
1200 | TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_SCATTER_UPDATE_GPU); |
1201 | |
1202 | REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate" ) |
1203 | .Device(DEVICE_DEFAULT) |
1204 | .HostMemory("resource" ) |
1205 | .HostMemory("indices" ) |
1206 | .TypeConstraint<Variant>("dtype" ) |
1207 | .TypeConstraint<int32>("Tindices" ), |
1208 | ResourceScatterUpdateOp<CPUDevice, Variant, int32, |
1209 | scatter_op::UpdateOp::ASSIGN>) |
1210 | REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate" ) |
1211 | .Device(DEVICE_GPU) |
1212 | .HostMemory("resource" ) |
1213 | .TypeConstraint<bool>("dtype" ) |
1214 | .TypeConstraint<int32>("Tindices" ), |
1215 | ResourceScatterUpdateOp<GPUDevice, bool, int32, |
1216 | scatter_op::UpdateOp::ASSIGN>) |
1217 | REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate" ) |
1218 | .Device(DEVICE_DEFAULT) |
1219 | .HostMemory("resource" ) |
1220 | .HostMemory("indices" ) |
1221 | .TypeConstraint<Variant>("dtype" ) |
1222 | .TypeConstraint<int64_t>("Tindices" ), |
1223 | ResourceScatterUpdateOp<CPUDevice, Variant, int64, |
1224 | scatter_op::UpdateOp::ASSIGN>) |
1225 | #undef REGISTER_SCATTER_ARITHMETIC_GPU |
1226 | #undef REGISTER_SCATTER_MINMAX_GPU |
1227 | #undef REGISTER_SCATTER_UPDATE_GPU |
1228 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1229 | |
1230 | #undef REGISTER_SCATTER_ARITHMETIC |
1231 | #undef REGISTER_SCATTER_ARITHMETIC_CPU |
1232 | #undef REGISTER_SCATTER_MINMAX |
1233 | #undef REGISTER_SCATTER_MINMAX_CPU |
1234 | #undef REGISTER_SCATTER_UPDATE |
1235 | #undef REGISTER_SCATTER_UPDATE_CPU |
1236 | #undef REGISTER_SCATTER_KERNEL |
1237 | #undef REGISTER_SCATTER_KERNEL_INDEX |
1238 | |
1239 | } // namespace tensorflow |
1240 | |