1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18
19#include <memory>
20#include <string>
21#include <typeindex>
22#include <typeinfo>
23#include <unordered_map>
24
25#include "absl/container/flat_hash_map.h"
26#include "absl/types/variant.h"
27#include "tensorflow/core/framework/common_shape_fns.h"
28#include "tensorflow/core/framework/device_attributes.pb.h"
29#include "tensorflow/core/framework/op_kernel.h"
30#include "tensorflow/core/framework/resource_base.h"
31#include "tensorflow/core/framework/resource_handle.h"
32#include "tensorflow/core/framework/tensor.h"
33#include "tensorflow/core/framework/tensor_shape.h"
34#include "tensorflow/core/framework/tensor_types.h"
35#include "tensorflow/core/framework/type_index.h"
36#include "tensorflow/core/framework/variant_tensor_data.h"
37#include "tensorflow/core/lib/core/errors.h"
38#include "tensorflow/core/lib/hash/hash.h"
39#include "tensorflow/core/platform/logging.h"
40#include "tensorflow/core/platform/macros.h"
41#include "tensorflow/core/platform/mutex.h"
42#include "tensorflow/core/platform/thread_annotations.h"
43
44namespace tensorflow {
45
// A ResourceMgr instance keeps track of named and typed resources
// grouped into containers.
//
// Each named resource is registered with ResourceMgr under a named
// "container". At any time, there is at most one instance of a resource
// given the container name, the resource type and the resource name.
//
// All resources for a given container can be dropped by one call of
// Cleanup().
//
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     std::string DebugString() const override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();  // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* found_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &found_var);
//   if (s.ok()) {
//     found_var->val.flat<float>() += grad;
//     found_var->Unref();  // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
79
80// Container used for per-step resources.
81class ScopedStepContainer {
82 public:
83 // step_id: the unique ID of this step. Doesn't have to be sequential, just
84 // has to be unique.
85 // cleanup: callback to delete a container of this name.
86 // prefix: optional string prefix to disambiguate step containers.
87 ScopedStepContainer(const int64_t step_id,
88 std::function<void(const string&)> cleanup)
89 : step_id_(step_id),
90 container_(strings::StrCat("__per_step_", step_id)),
91 cleanup_(cleanup),
92 dirty_(false) {}
93
94 ScopedStepContainer(const int64_t step_id,
95 std::function<void(const string&)> cleanup,
96 const std::string& prefix)
97 : step_id_(step_id),
98 container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
99 cleanup_(cleanup),
100 dirty_(false) {}
101
102 ~ScopedStepContainer() { CleanUp(); }
103
104 void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
105 // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
106 // clean.
107 if (dirty_) {
108 mutex_lock ml(mu_);
109 cleanup_(container_);
110 dirty_ = false;
111 }
112 }
113
114 // Pass through functions for resource lookup and creation. We do this to
115 // ensure that we can appropriately set the dirty_ bit in the
116 // ScopedStepContainer if the name of the container is used to create
117 // resources.
118
119 // Pass through to MakeResourceHandle with the container name
120 template <typename T>
121 ResourceHandle MakeResourceHandle(
122 const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
123 // Pass through to ResourceMgr::Create with the container name
124 template <typename T>
125 Status Create(ResourceMgr* rm, const std::string& name,
126 T* resource) TF_MUST_USE_RESULT;
127 // Pass through to ResourceMgr::Delete with the container name
128 template <typename T>
129 Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
130 // Pass through to ResourceMgr::Lookup with the container name
131 template <typename T>
132 Status Lookup(ResourceMgr* rm, const std::string& name,
133 T** resource) const TF_MUST_USE_RESULT;
134 // Pass through to ResourceMgr::LookupOrCreate with the container name
135 template <typename T>
136 Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
137 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
138 int64_t StepId() const { return step_id_; }
139
140 private:
141 const int64_t step_id_;
142 const std::string container_;
143 const std::function<void(const string&)> cleanup_;
144 mutex mu_;
145 mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
146};
147
148class ResourceMgr {
149 public:
150 ResourceMgr();
151 explicit ResourceMgr(const std::string& default_container);
152 ~ResourceMgr();
153
154 // Returns the default container name for *this.
155 const std::string& default_container() const { return default_container_; }
156
157 // Creates a resource "name" in the "container". The caller transfers
158 // the ownership of one ref on "resource" to *this, regardless of whether this
159 // operation succeeds or fails.
160 //
161 // REQUIRES: std::is_base_of<ResourceBase, T>
162 // REQUIRES: resource != nullptr.
163 template <typename T>
164 Status Create(const std::string& container, const std::string& name,
165 T* resource) TF_MUST_USE_RESULT;
166
167 // Creates a unowned resource "name" in the "container". The caller does NOT
168 // transfer the ownership of any ref on "resource" to *this, regardless of
169 // whether this operation succeeds or fails.
170 //
171 // After the resource is destroyed, lookups from the manager fail.
172 // The caller must call this->Delete() on the name to free up the memory
173 // entry of the name.
174 //
175 // REQUIRES: std::is_base_of<ResourceBase, T>
176 // REQUIRES: resource != nullptr.
177 template <typename T>
178 Status CreateUnowned(const std::string& container, const std::string& name,
179 T* resource) TF_MUST_USE_RESULT;
180
181 // If "container" has a resource "name", returns it in "*resource" and
182 // the caller takes the ownership of one ref on "*resource".
183 //
184 // REQUIRES: std::is_base_of<ResourceBase, T>
185 // REQUIRES: resource != nullptr
186 template <typename T, bool use_dynamic_cast = false>
187 Status Lookup(const std::string& container, const std::string& name,
188 T** resource) const TF_MUST_USE_RESULT;
189
190 // If the resource manager has a resource matching "handle", returns it in
191 // "*resource" and the caller takes the ownership of one ref on "*resource".
192 //
193 // REQUIRES: resource != nullptr
194 Status Lookup(const ResourceHandle& handle,
195 ResourceBase** resource) const TF_MUST_USE_RESULT;
196
197 // Similar to Lookup, but looks up multiple resources at once, with only a
198 // single lock acquisition. If containers_and_names[i] is uninitialized
199 // then this function does not modify resources[i].
200 template <typename T, bool use_dynamic_cast = false>
201 Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
202 containers_and_names,
203 std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
204 resources) const TF_MUST_USE_RESULT;
205
206 // If "container" has a resource "name", returns it in
207 // "*resource". Otherwise, invokes creator() to create the resource.
208 // The caller takes the ownership of one ref on "*resource".
209 //
210 // WARNING: creator() must not call any methods on ResourceMgr during its
211 // execution, because a non-reentrant lock is held during the creator() call
212 // in order to guarantee atomicity of LookupOrCreate().
213 //
214 // REQUIRES: std::is_base_of<ResourceBase, T>
215 // REQUIRES: resource != nullptr
216 template <typename T, bool use_dynamic_cast = false>
217 Status LookupOrCreate(const std::string& container, const std::string& name,
218 T** resource,
219 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
220
221 // Deletes the resource "name" from the "container".
222 //
223 // REQUIRES: std::is_base_of<ResourceBase, T>
224 template <typename T>
225 Status Delete(const std::string& container,
226 const std::string& name) TF_MUST_USE_RESULT;
227
228 // Deletes the resource pointed by "handle".
229 Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
230
231 // Deletes all resources from the "container" and removes the container.
232 Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
233
234 // Deletes all resources in all containers.
235 void Clear();
236
237 // Returns a text description for all resources.
238 std::string DebugString() const;
239
240 private:
241 typedef std::pair<uint64, StringPiece> Key;
242 struct KeyHash {
243 std::size_t operator()(const Key& k) const {
244 return Hash64(k.second.data(), k.second.size(), k.first);
245 }
246 };
247 struct KeyEqual {
248 bool operator()(const Key& x, const Key& y) const {
249 return (x.second == y.second) && (x.first == y.first);
250 }
251 };
252 struct ResourceAndName {
253 absl::variant<core::RefCountPtr<ResourceBase>, core::WeakPtr<ResourceBase>>
254 resource;
255 std::unique_ptr<std::string> name;
256
257 ResourceAndName();
258 explicit ResourceAndName(const string& name);
259 ResourceAndName(ResourceAndName&& other) noexcept;
260 ~ResourceAndName();
261
262 ResourceAndName& operator=(ResourceAndName&&) noexcept;
263
264 // Returns a strong reference to resource, or nullptr if the resource is
265 // no longer valid.
266 core::RefCountPtr<ResourceBase> GetResource() const;
267
268 private:
269 TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
270 };
271 typedef absl::flat_hash_map<Key, ResourceAndName, KeyHash, KeyEqual>
272 Container;
273
274 const std::string default_container_;
275 mutable mutex mu_;
276 absl::flat_hash_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
277
278 template <typename T, bool use_dynamic_cast = false>
279 Status LookupInternal(const std::string& container, const std::string& name,
280 T** resource) const
281 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
282 Status LookupInternal(const std::string& container, uint64 type_hash_code,
283 const std::string& name, ResourceBase** resource) const
284 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
285
286 Status DoCreate(const std::string& container, TypeIndex type,
287 const std::string& name, ResourceBase* resource,
288 bool owns_resource)
289 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
290
291 Status DoLookup(const std::string& container, TypeIndex type,
292 const std::string& name, ResourceBase** resource) const
293 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
294 Status DoLookup(const std::string& container, uint64 type_hash_code,
295 const std::string& type_name,
296 const std::string& resource_name,
297 ResourceBase** resource) const
298 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
299
300 Status DoDelete(const std::string& container, uint64 type_hash_code,
301 const std::string& resource_name,
302 const std::string& type_name) TF_MUST_USE_RESULT;
303 Status DoDelete(const std::string& container, TypeIndex type,
304 const std::string& resource_name) TF_MUST_USE_RESULT;
305
306 // Pops the ResourceAndName entry. The entry is moved from the list to
307 // the output argument `resource_and_name`.
308 Status PopResourceAndName(
309 const std::string& container, uint64 type_hash_code,
310 const std::string& resource_name, const std::string& type_name,
311 ResourceAndName& resource_and_name) TF_MUST_USE_RESULT;
312 // Inserts the type name for 'hash_code' into the hash_code to type name map.
313 Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
314 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
315
316 // Returns the type name for the 'hash_code'.
317 // Returns "<unknown>" if a resource with such a type was never inserted into
318 // the container.
319 const char* DebugTypeName(uint64 hash_code) const
320 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
321
322 // Map from type hash_code to type name.
323 std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);
324
325 TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
326};
327
328// Makes a resource handle with the specified type for a given container /
329// name.
330ResourceHandle MakeResourceHandle(
331 const std::string& container, const std::string& name,
332 const DeviceBase& device, const TypeIndex& type_index,
333 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
334 const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
335 TF_MUST_USE_RESULT;
336
337template <typename T>
338ResourceHandle MakeResourceHandle(
339 OpKernelContext* ctx, const std::string& container, const std::string& name,
340 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
341 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
342 return MakeResourceHandle(container.empty()
343 ? ctx->resource_manager()->default_container()
344 : container,
345 name, *ctx->device(), TypeIndex::Make<T>(),
346 dtypes_and_shapes, definition_stack_trace);
347}
348
349template <typename T>
350ResourceHandle MakeResourceHandle(
351 OpKernelConstruction* ctx, const std::string& container,
352 const std::string& name,
353 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
354 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
355 return MakeResourceHandle(container.empty()
356 ? ctx->resource_manager()->default_container()
357 : container,
358 name, *ctx->device(), TypeIndex::Make<T>(),
359 dtypes_and_shapes, definition_stack_trace);
360}
361
362Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
363 const std::string& container,
364 const std::string& name,
365 const TypeIndex& type_index);
366
367// Returns a resource handle from a numbered op input.
368const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
369Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
370 ResourceHandle* handle);
371
372// Create a resource pointed by a given resource handle.
373//
374// If successful, the caller transfers the ownership of one ref on `resource` to
375// `ctx->resource_mgr()`.
376template <typename T>
377Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
378
379// Looks up a resource pointed by a given resource handle.
380//
381// If the lookup is successful, the caller takes the ownership of one ref on
382// `*value`, and must call its `Unref()` method when it has finished using it.
383template <typename T, bool use_dynamic_cast = false>
384Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
385
386// Looks up a resource pointed by a given resource handle.
387//
388// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
389// requiring the caller to explicitly call `Unref()`.
390template <typename T>
391Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
392 core::RefCountPtr<T>* value);
393
394// Looks up multiple resources pointed by a sequence of resource handles. If
395// p[i] is uninitialized then values[i] is unmodified.
396template <typename T>
397Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
398 std::vector<core::RefCountPtr<T>>* values);
399
400// Looks up or creates a resource.
401//
402// If successful, the caller takes the ownership of one ref on `*value`, and
403// must call its `Unref()` method when it has finished using it. If the
404// `creator` is invoked, its reference on the created resource is transferred
405// to `ctx->resource_mgr()`.
406//
407// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
408// requiring the caller to explicitly call `Unref()`.
409template <typename T>
410Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
411 T** value, std::function<Status(T**)> creator);
412
413// Looks up or creates a resource.
414template <typename T>
415Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
416 core::RefCountPtr<T>* value,
417 std::function<Status(T**)> creator);
418
419// Destroys a resource pointed by a given resource handle.
420template <typename T>
421Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
422
423// Same as above, but uses the hash code of the type directly.
424// The type name information will be missing in the debug output when the
425// resource is not present in the container.
426Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
427
428// Policy helper to decide which container/shared_name to use for a
429// stateful kernel that accesses shared resource.
430class ContainerInfo {
431 public:
432 // Analyze the node attribute of 'ndef' and decides the container and
433 // resource name the kernel should use for accessing the shared
434 // resource.
435 //
436 // 'ndef' is expected to have node attribute "container" and
437 // "shared_name". Returns non-OK if they are not provided or they are
438 // invalid.
439 //
440 // The policy is as following:
441 // * If the attribute "container" is non-empty, it is used as is.
442 // Otherwise, uses the resource manager's default container.
443 // * If the attribute "shared_name" is non-empty, it is used as is.
444 // Otherwise, if "use_node_name_as_default" is true, the kernel's
445 // node name is used as the resource name. Otherwise, a string
446 // unique to this process is used.
447 Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
448 bool use_node_name_as_default);
449 Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
450 return Init(rmgr, ndef, false);
451 }
452
453 // The policy decides that the kernel should access the resource in
454 // resource_manager(), the resource is in the container() and its
455 // name is name(). If resource_is_private_to_kernel() is true, the
456 // kernel should delete the resource when the kernel is deleted.
457 ResourceMgr* resource_manager() const { return rmgr_; }
458 const std::string& container() const { return container_; }
459 const std::string& name() const { return name_; }
460 bool resource_is_private_to_kernel() const {
461 return resource_is_private_to_kernel_;
462 }
463
464 // Returns a readable string for *this.
465 std::string DebugString() const;
466
467 private:
468 ResourceMgr* rmgr_ = nullptr;
469 std::string container_;
470 std::string name_;
471 bool resource_is_private_to_kernel_ = false;
472};
473
474// Helper for kernels to obtain 'resource' from the
475// ctx->resource_manager().
476//
477// "input_name" specifies the kernel's ref input which gives a string
478// tensor with two elements, which specifies the container and
479// resource name.
480//
481// Returns OK if the resource is found and transfers one ref of
482// *resource to the caller. Otherwise, returns an error.
483template <typename T>
484Status GetResourceFromContext(OpKernelContext* ctx,
485 const std::string& input_name, T** resource);
486
487// Utility op kernel to check if a handle to resource type T is initialized.
488template <typename T>
489class IsResourceInitialized : public OpKernel {
490 public:
491 explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
492
493 void Compute(OpKernelContext* ctx) override;
494};
495
// Registers an op which produces just a resource handle to a resource of the
// specified type. The type will be a part of the generated op name: for a type
// `Foo` the op is registered as "FooHandleOp". It has optional string attrs
// "container" and "shared_name" (both defaulting to '') and a single scalar
// output "resource" of dtype resource.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)                 \
  REGISTER_OP(#Type "HandleOp")                           \
      .Attr("container: string = ''")                     \
      .Attr("shared_name: string = ''")                   \
      .Output("resource: resource")                       \
      .SetIsStateful()                                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
507
508// Utility op kernel to produce a handle to a resource of type T.
509template <typename T>
510class ResourceHandleOp : public OpKernel {
511 public:
512 explicit ResourceHandleOp(OpKernelConstruction* context);
513
514 void Compute(OpKernelContext* ctx) override;
515
516 bool IsExpensive() override { return false; }
517
518 private:
519 std::string container_;
520 std::string name_;
521 mutex mutex_;
522 Tensor resource_;
523 std::atomic<bool> initialized_{false};
524};
525
526// Utility op kernel to produce a handle to a resource of type T.
527template <typename T>
528class ResourceHandlesOp : public OpKernel {
529 public:
530 explicit ResourceHandlesOp(OpKernelConstruction* context);
531
532 void Compute(OpKernelContext* ctx) override;
533
534 bool IsExpensive() override { return false; }
535
536 private:
537 std::vector<string> containers_;
538 std::vector<string> names_;
539 mutex mutex_;
540 std::vector<Tensor> resources_;
541 std::atomic<bool> initialized_{false};
542};
543
// Registers a kernel for an op which produces a handle to a resource of the
// specified type: the CPU kernel for the "<Type>HandleOp" op, implemented by
// ResourceHandleOp<Type>.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
549
550// This class is used to guarantee that an anonymous resource is deleted
551// (irrespective of whether a resource deleter op is called explicitly or
552// the execution encounters an error before the op runs).
553//
554// This is achieved by wrapping an instance of this class into a variant
555// tensor which is passed as an input to a resource deleter op. If the
556// execution encounters an error before the op runs, the tensor will be
557// destroyed, essentially triggering the iterator deletion.
558// NOTE: This is not a feature-complete implementation of the DT_VARIANT
559// specification. In particular, we cannot serialize the `ResourceMgr`
560// object, so the `Encode()` and `Decode()` methods are not implemented.
561class ResourceDeleter {
562 public:
563 ResourceDeleter() : deleter_() {}
564
565 ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
566 : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
567
568 ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
569 VLOG(3) << "ResourceDeleter move constructor called.";
570 }
571
572 ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
573 VLOG(3) << "ResourceDeleter copy constructor called.";
574 }
575
576 ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
577
578 ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
579
580 virtual ~ResourceDeleter() {
581 VLOG(3) << "ResourceDeleter destructor called.";
582 }
583
584 void Encode(VariantTensorData*) const {
585 LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
586 "objects.";
587 }
588
589 bool Decode(const VariantTensorData&) {
590 LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
591 "objects";
592 return false; // Not supported.
593 }
594
595 private:
596 // Helper that performs reference counting for the parent class and deletes
597 // the iterator resource when the refcount goes to zero.
598 //
599 // NOTE: The object is borrowing a pointer to the resource manager.
600 // Consequently, the tensor containing this object should not escape the
601 // function in which was created (so that it is guaranteed that the resource
602 // manager will outlive it).
603 struct Helper {
604 Helper(ResourceHandle handle, ResourceMgr* resource_manager)
605 : handle(handle), resource_manager(resource_manager) {}
606
607 Helper(const Helper& rhs) = delete;
608 Helper(Helper&& rhs) = delete;
609
610 ~Helper() {
611 VLOG(3) << "Deleting Resource: " << handle.DebugString();
612 resource_manager->Delete(handle).IgnoreError();
613 }
614
615 ResourceHandle handle;
616 ResourceMgr* resource_manager; // not owned
617 };
618
619 std::shared_ptr<Helper> deleter_;
620};
621
622// Implementation details below.
623
624template <typename T>
625void CheckDeriveFromResourceBase() {
626 static_assert(std::is_base_of<ResourceBase, T>::value,
627 "T must derive from ResourceBase");
628}
629
630template <typename T>
631Status ResourceMgr::Create(const std::string& container,
632 const std::string& name, T* resource) {
633 CheckDeriveFromResourceBase<T>();
634 CHECK(resource != nullptr);
635 mutex_lock l(mu_);
636 return DoCreate(container, TypeIndex::Make<T>(), name, resource,
637 /* owns_resource */ true);
638}
639
640template <typename T>
641Status ResourceMgr::CreateUnowned(const std::string& container,
642 const std::string& name, T* resource) {
643 CheckDeriveFromResourceBase<T>();
644 mutex_lock l(mu_);
645 return DoCreate(container, TypeIndex::Make<T>(), name, resource,
646 /* owns_resource */ false);
647}
648
649template <typename T, bool use_dynamic_cast>
650Status ResourceMgr::Lookup(const std::string& container,
651 const std::string& name, T** resource) const {
652 CheckDeriveFromResourceBase<T>();
653 tf_shared_lock l(mu_);
654 return LookupInternal<T, use_dynamic_cast>(container, name, resource);
655}
656
657template <typename T, bool use_dynamic_cast>
658Status ResourceMgr::LookupMany(
659 absl::Span<std::pair<const string*, const string*> const>
660 containers_and_names,
661 std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
662 CheckDeriveFromResourceBase<T>();
663 tf_shared_lock l(mu_);
664 resources->resize(containers_and_names.size());
665 for (size_t i = 0; i < containers_and_names.size(); ++i) {
666 T* resource;
667 Status s = LookupInternal<T, use_dynamic_cast>(
668 *containers_and_names[i].first, *containers_and_names[i].second,
669 &resource);
670 if (s.ok()) {
671 (*resources)[i].reset(resource);
672 }
673 }
674 return OkStatus();
675}
676
677// Simple wrapper to allow conditional dynamic / static casts.
678template <typename T, bool use_dynamic_cast>
679struct TypeCastFunctor {
680 static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
681};
682
683template <typename T>
684struct TypeCastFunctor<T, true> {
685 static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
686};
687
688template <typename T, bool use_dynamic_cast>
689Status ResourceMgr::LookupInternal(const std::string& container,
690 const std::string& name,
691 T** resource) const {
692 ResourceBase* found = nullptr;
693 Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
694 if (s.ok()) {
695 // It's safe to down cast 'found' to T* since
696 // typeid(T).hash_code() is part of the map key.
697 *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
698 }
699 return s;
700}
701
702template <typename T, bool use_dynamic_cast>
703Status ResourceMgr::LookupOrCreate(const std::string& container,
704 const std::string& name, T** resource,
705 std::function<Status(T**)> creator) {
706 CheckDeriveFromResourceBase<T>();
707 *resource = nullptr;
708 Status s;
709 {
710 tf_shared_lock l(mu_);
711 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
712 if (s.ok()) return s;
713 }
714 mutex_lock l(mu_);
715 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
716 if (s.ok()) return s;
717 TF_RETURN_IF_ERROR(creator(resource));
718 s = DoCreate(container, TypeIndex::Make<T>(), name, *resource,
719 /* owns_resource */ true);
720 if (!s.ok()) {
721 return errors::Internal("LookupOrCreate failed unexpectedly");
722 }
723 (*resource)->Ref();
724 return s;
725}
726
727template <typename T>
728Status ResourceMgr::Delete(const std::string& container,
729 const std::string& name) {
730 CheckDeriveFromResourceBase<T>();
731 return DoDelete(container, TypeIndex::Make<T>(), name);
732}
733
734template <typename T>
735Status GetResourceFromContext(OpKernelContext* ctx,
736 const std::string& input_name, T** resource) {
737 DataType dtype;
738 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
739 if (dtype == DT_RESOURCE) {
740 const Tensor* handle;
741 TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
742 return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
743 }
744 std::string container;
745 std::string shared_name;
746 {
747 mutex* mu;
748 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
749 mutex_lock l(*mu);
750 Tensor tensor;
751 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
752 if (tensor.NumElements() != 2) {
753 return errors::InvalidArgument(
754 "Resource handle must have 2 elements, but had shape: ",
755 tensor.shape().DebugString());
756 }
757 container = tensor.flat<tstring>()(0);
758 shared_name = tensor.flat<tstring>()(1);
759 }
760 return ctx->resource_manager()->Lookup(container, shared_name, resource);
761}
762
763namespace internal {
764
765Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
766
767template <typename T>
768Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
769 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
770 TF_RETURN_IF_ERROR(p.ValidateType<T>());
771 return OkStatus();
772}
773
774} // namespace internal
775
776// Creates the resource pointed at by "p". The caller transfers the ownership of
777// one ref on "*value" to the resource manager in "ctx", regardless of whether
778// this operation succeeds or fails.
779template <typename T>
780Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
781 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
782 return ctx->resource_manager()->Create(p.container(), p.name(), value);
783}
784
785// Finds the resource as "*value" from the handle. If the handle is
786// ref-counting, returns the resource owned by the handle. Otherwise, looks up
787// the resource matching "p" from resource manager associated with ctx.
788// Always returns a new reference to the resource in "*value". The caller shall
789// call (*value)->Unref().
790template <typename T, bool use_dynamic_cast>
791Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
792 T** value) {
793 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
794 if (p.IsRefCounting()) {
795 TF_ASSIGN_OR_RETURN(*value, p.GetResource<T>());
796 // Transfers out a new reference.
797 (*value)->Ref();
798 return OkStatus();
799 }
800
801 return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
802 p.name(), value);
803}
804
805// Finds the resource as "*value" from the handle. This is a type-erased
806// variant of LookupResource above.
807Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
808 ResourceBase** value);
809
810// If the resource manager in "ctx" has a resource matching "p", returns it in
811// "*value".
812template <typename T>
813Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
814 core::RefCountPtr<T>* value) {
815 T* raw_ptr = nullptr;
816 TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
817 value->reset(raw_ptr);
818
819 return OkStatus();
820}
821
822// Similar to Lookup, but looks up multiple resources at once, with only a
823// single lock acquisition.
824template <typename T>
825Status LookupResources(OpKernelContext* ctx,
826 absl::Span<ResourceHandle const* const> p,
827 std::vector<core::RefCountPtr<T>>* values) {
828 std::vector<std::pair<const string*, const string*>> containers_and_names(
829 p.size());
830 for (size_t i = 0; i < p.size(); ++i) {
831 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
832 containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
833 }
834 return ctx->resource_manager()->LookupMany(containers_and_names, values);
835}
836
837// If the resource manager in "ctx" has a resource pointed at by "p", returns
838// it in "*value". Otherwise, invokes creator() to create the resource.
839// The caller takes the ownership of one ref on "*value".
840//
841// WARNING: creator() must not call any methods on the resource manager during
842// its execution, because a non-reentrant lock is held during the creator() call
843// in order to guarantee atomicity of LookupOrCreateResource().
844template <typename T>
845Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
846 T** value, std::function<Status(T**)> creator) {
847 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
848 return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
849 creator);
850}
851
852// If the resource manager in "ctx" has a resource pointed at by "p", returns
853// it in "*value". Otherwise, invokes creator() to create the resource.
854//
855// WARNING: creator() must not call any methods on the resource manager during
856// its execution, because a non-reentrant lock is held during the creator() call
857// in order to guarantee atomicity of LookupOrCreateResource().
858template <typename T>
859Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
860 core::RefCountPtr<T>* value,
861 std::function<Status(T**)> creator) {
862 T* raw_ptr = nullptr;
863 TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
864 value->reset(raw_ptr);
865
866 return OkStatus();
867}
868
869// Deletes the resource pointed by "p", using the resource manager in "ctx".
870template <typename T>
871Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
872 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
873 // This is a noop because ResourceMgr does not hold a reference.
874 // NOTE(feyu): if we can convert all resources handle to ref-counting, then
875 // DeleteResource can be removed.
876 if (p.IsRefCounting()) {
877 return OkStatus();
878 }
879 return ctx->resource_manager()->Delete<T>(p.container(), p.name());
880}
881
// Deletes the resource pointed by "p", using the resource manager in "ctx".
// Non-template overload of DeleteResource<T> above; it takes no resource type
// argument. Its definition is not part of this header section.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
884
885template <typename T>
886void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
887 Tensor* output;
888 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
889 T* object;
890 bool found;
891 if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
892 found = true;
893 object->Unref();
894 } else {
895 found = false;
896 }
897
898 output->flat<bool>()(0) = found;
899}
900
// Reads the "container" and "shared_name" attrs into container_ and name_,
// which Compute() later uses to build the handle. The attrs are read in this
// order and OP_REQUIRES_OK stops at the first one that cannot be read.
template <typename T>
ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
    : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
  OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
}
907
908template <typename T>
909void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
910 if (name_ == ResourceHandle::ANONYMOUS_NAME) {
911 AllocatorAttributes attr;
912 attr.set_on_host(true);
913 Tensor handle;
914 OP_REQUIRES_OK(
915 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
916 handle.scalar<ResourceHandle>()() = MakeResourceHandle<T>(
917 ctx, container_, name_, /*dtypes_and_shapes=*/{}, ctx->stack_trace());
918 ctx->set_output(0, handle);
919 } else {
920 if (!initialized_.load()) {
921 mutex_lock ml(mutex_);
922 // Checking again to see if another thread has initialized the resource.
923 if (!initialized_.load()) {
924 AllocatorAttributes attr;
925 attr.set_on_host(true);
926 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
927 &resource_, attr));
928 resource_.scalar<ResourceHandle>()() =
929 MakeResourceHandle<T>(ctx, container_, name_,
930 /*dtypes_and_shapes=*/{}, ctx->stack_trace());
931 initialized_.store(true);
932 }
933 }
934 ctx->set_output(0, resource_);
935 }
936}
937
938template <typename T>
939ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
940 : OpKernel(context) {
941 int n;
942 OP_REQUIRES_OK(context, context->GetAttr("N", &n));
943 OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
944 OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
945 OP_REQUIRES(
946 context, containers_.size() == n,
947 errors::InvalidArgument("Number of containers (", containers_.size(),
948 ") must be equal to N (", n, ")"));
949 OP_REQUIRES(context, names_.size() == n,
950 errors::InvalidArgument("Number of names (", containers_.size(),
951 ") must be equal to N (", n, ")"));
952 resources_.resize(n);
953}
954
955template <typename T>
956void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
957 if (!initialized_.load()) {
958 mutex_lock ml(mutex_);
959 // Checking again to see if another thread has initialized the resource.
960 if (!initialized_.load()) {
961 AllocatorAttributes attr;
962 attr.set_on_host(true);
963 for (size_t i = 0; i < resources_.size(); ++i) {
964 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
965 &resources_[i], attr));
966 ResourceHandle h =
967 MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
968 resources_[i].template scalar<ResourceHandle>()() = h;
969 }
970 initialized_.store(true);
971 }
972 }
973 for (size_t i = 0; i < resources_.size(); ++i) {
974 ctx->set_output(i, resources_[i]);
975 }
976}
977
978template <typename T>
979ResourceHandle ScopedStepContainer::MakeResourceHandle(
980 const std::string& name, const DeviceBase& device) {
981 mutex_lock ml(mu_);
982 dirty_ = true;
983 return tensorflow::MakeResourceHandle(container_, name, device,
984 TypeIndex::Make<T>(), {});
985}
986
// Looks up the resource of type T called "name" in this step container's
// container_ within "rm". Unlike Create/LookupOrCreate/MakeResourceHandle,
// this neither takes mu_ nor sets dirty_.
template <typename T>
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
                                   T** resource) const {
  return rm->Lookup<T>(container_, name, resource);
}
992
993template <typename T>
994Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
995 const std::string& name,
996 T** resource,
997 std::function<Status(T**)> creator) {
998 mutex_lock ml(mu_);
999 dirty_ = true;
1000 return rm->LookupOrCreate<T>(container_, name, resource, creator);
1001}
1002
1003template <typename T>
1004Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
1005 T* resource) {
1006 mutex_lock ml(mu_);
1007 dirty_ = true;
1008 return rm->Create<T>(container_, name, resource);
1009}
1010
// Deletes the resource of type T called "name" from this step container's
// container_ within "rm". Does not take mu_ and leaves dirty_ unchanged.
template <typename T>
Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
  return rm->Delete<T>(container_, name);
}
1015
1016} // end namespace tensorflow
1017
1018#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
1019