1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ |
18 | |
19 | #include <memory> |
20 | #include <string> |
21 | #include <typeindex> |
22 | #include <typeinfo> |
23 | #include <unordered_map> |
24 | |
25 | #include "absl/container/flat_hash_map.h" |
26 | #include "absl/types/variant.h" |
27 | #include "tensorflow/core/framework/common_shape_fns.h" |
28 | #include "tensorflow/core/framework/device_attributes.pb.h" |
29 | #include "tensorflow/core/framework/op_kernel.h" |
30 | #include "tensorflow/core/framework/resource_base.h" |
31 | #include "tensorflow/core/framework/resource_handle.h" |
32 | #include "tensorflow/core/framework/tensor.h" |
33 | #include "tensorflow/core/framework/tensor_shape.h" |
34 | #include "tensorflow/core/framework/tensor_types.h" |
35 | #include "tensorflow/core/framework/type_index.h" |
36 | #include "tensorflow/core/framework/variant_tensor_data.h" |
37 | #include "tensorflow/core/lib/core/errors.h" |
38 | #include "tensorflow/core/lib/hash/hash.h" |
39 | #include "tensorflow/core/platform/logging.h" |
40 | #include "tensorflow/core/platform/macros.h" |
41 | #include "tensorflow/core/platform/mutex.h" |
42 | #include "tensorflow/core/platform/thread_annotations.h" |
43 | |
44 | namespace tensorflow { |
45 | |
46 | // A ResourceMgr instance keeps track of named and typed resources |
47 | // grouped into containers. |
48 | // |
// Each named resource is registered with ResourceMgr under a "container"
// name. At any time, there is at most one instance of a resource for a given
// combination of container name, resource type, and resource name.
53 | // |
54 | // All resources for a given container can be dropped by one call of |
55 | // Cleanup(). |
56 | // |
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//   };
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();  // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* found_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &found_var);
//   if (s.ok()) {
//     found_var->val.flat<float>() += grad;
//     found_var->Unref();  // Or use ScopedUnref(). Only valid on success.
//   }
//   ctx->SetStatus(s);
79 | |
80 | // Container used for per-step resources. |
81 | class ScopedStepContainer { |
82 | public: |
83 | // step_id: the unique ID of this step. Doesn't have to be sequential, just |
84 | // has to be unique. |
85 | // cleanup: callback to delete a container of this name. |
86 | // prefix: optional string prefix to disambiguate step containers. |
87 | ScopedStepContainer(const int64_t step_id, |
88 | std::function<void(const string&)> cleanup) |
89 | : step_id_(step_id), |
90 | container_(strings::StrCat("__per_step_" , step_id)), |
91 | cleanup_(cleanup), |
92 | dirty_(false) {} |
93 | |
94 | ScopedStepContainer(const int64_t step_id, |
95 | std::function<void(const string&)> cleanup, |
96 | const std::string& prefix) |
97 | : step_id_(step_id), |
98 | container_(strings::StrCat("__" , prefix, "_per_step_" , step_id)), |
99 | cleanup_(cleanup), |
100 | dirty_(false) {} |
101 | |
102 | ~ScopedStepContainer() { CleanUp(); } |
103 | |
104 | void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS { |
105 | // NOTE(mrry): Avoid acquiring the mutex in the case that the container is |
106 | // clean. |
107 | if (dirty_) { |
108 | mutex_lock ml(mu_); |
109 | cleanup_(container_); |
110 | dirty_ = false; |
111 | } |
112 | } |
113 | |
114 | // Pass through functions for resource lookup and creation. We do this to |
115 | // ensure that we can appropriately set the dirty_ bit in the |
116 | // ScopedStepContainer if the name of the container is used to create |
117 | // resources. |
118 | |
119 | // Pass through to MakeResourceHandle with the container name |
120 | template <typename T> |
121 | ResourceHandle MakeResourceHandle( |
122 | const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT; |
123 | // Pass through to ResourceMgr::Create with the container name |
124 | template <typename T> |
125 | Status Create(ResourceMgr* rm, const std::string& name, |
126 | T* resource) TF_MUST_USE_RESULT; |
127 | // Pass through to ResourceMgr::Delete with the container name |
128 | template <typename T> |
129 | Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT; |
130 | // Pass through to ResourceMgr::Lookup with the container name |
131 | template <typename T> |
132 | Status Lookup(ResourceMgr* rm, const std::string& name, |
133 | T** resource) const TF_MUST_USE_RESULT; |
134 | // Pass through to ResourceMgr::LookupOrCreate with the container name |
135 | template <typename T> |
136 | Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource, |
137 | std::function<Status(T**)> creator) TF_MUST_USE_RESULT; |
138 | int64_t StepId() const { return step_id_; } |
139 | |
140 | private: |
141 | const int64_t step_id_; |
142 | const std::string container_; |
143 | const std::function<void(const string&)> cleanup_; |
144 | mutex mu_; |
145 | mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_); |
146 | }; |
147 | |
148 | class ResourceMgr { |
149 | public: |
150 | ResourceMgr(); |
151 | explicit ResourceMgr(const std::string& default_container); |
152 | ~ResourceMgr(); |
153 | |
154 | // Returns the default container name for *this. |
155 | const std::string& default_container() const { return default_container_; } |
156 | |
157 | // Creates a resource "name" in the "container". The caller transfers |
158 | // the ownership of one ref on "resource" to *this, regardless of whether this |
159 | // operation succeeds or fails. |
160 | // |
161 | // REQUIRES: std::is_base_of<ResourceBase, T> |
162 | // REQUIRES: resource != nullptr. |
163 | template <typename T> |
164 | Status Create(const std::string& container, const std::string& name, |
165 | T* resource) TF_MUST_USE_RESULT; |
166 | |
167 | // Creates a unowned resource "name" in the "container". The caller does NOT |
168 | // transfer the ownership of any ref on "resource" to *this, regardless of |
169 | // whether this operation succeeds or fails. |
170 | // |
171 | // After the resource is destroyed, lookups from the manager fail. |
172 | // The caller must call this->Delete() on the name to free up the memory |
173 | // entry of the name. |
174 | // |
175 | // REQUIRES: std::is_base_of<ResourceBase, T> |
176 | // REQUIRES: resource != nullptr. |
177 | template <typename T> |
178 | Status CreateUnowned(const std::string& container, const std::string& name, |
179 | T* resource) TF_MUST_USE_RESULT; |
180 | |
181 | // If "container" has a resource "name", returns it in "*resource" and |
182 | // the caller takes the ownership of one ref on "*resource". |
183 | // |
184 | // REQUIRES: std::is_base_of<ResourceBase, T> |
185 | // REQUIRES: resource != nullptr |
186 | template <typename T, bool use_dynamic_cast = false> |
187 | Status Lookup(const std::string& container, const std::string& name, |
188 | T** resource) const TF_MUST_USE_RESULT; |
189 | |
190 | // If the resource manager has a resource matching "handle", returns it in |
191 | // "*resource" and the caller takes the ownership of one ref on "*resource". |
192 | // |
193 | // REQUIRES: resource != nullptr |
194 | Status Lookup(const ResourceHandle& handle, |
195 | ResourceBase** resource) const TF_MUST_USE_RESULT; |
196 | |
197 | // Similar to Lookup, but looks up multiple resources at once, with only a |
198 | // single lock acquisition. If containers_and_names[i] is uninitialized |
199 | // then this function does not modify resources[i]. |
200 | template <typename T, bool use_dynamic_cast = false> |
201 | Status LookupMany(absl::Span<std::pair<const string*, const string*> const> |
202 | containers_and_names, |
203 | std::vector<std::unique_ptr<T, core::RefCountDeleter>>* |
204 | resources) const TF_MUST_USE_RESULT; |
205 | |
206 | // If "container" has a resource "name", returns it in |
207 | // "*resource". Otherwise, invokes creator() to create the resource. |
208 | // The caller takes the ownership of one ref on "*resource". |
209 | // |
210 | // WARNING: creator() must not call any methods on ResourceMgr during its |
211 | // execution, because a non-reentrant lock is held during the creator() call |
212 | // in order to guarantee atomicity of LookupOrCreate(). |
213 | // |
214 | // REQUIRES: std::is_base_of<ResourceBase, T> |
215 | // REQUIRES: resource != nullptr |
216 | template <typename T, bool use_dynamic_cast = false> |
217 | Status LookupOrCreate(const std::string& container, const std::string& name, |
218 | T** resource, |
219 | std::function<Status(T**)> creator) TF_MUST_USE_RESULT; |
220 | |
221 | // Deletes the resource "name" from the "container". |
222 | // |
223 | // REQUIRES: std::is_base_of<ResourceBase, T> |
224 | template <typename T> |
225 | Status Delete(const std::string& container, |
226 | const std::string& name) TF_MUST_USE_RESULT; |
227 | |
228 | // Deletes the resource pointed by "handle". |
229 | Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT; |
230 | |
231 | // Deletes all resources from the "container" and removes the container. |
232 | Status Cleanup(const std::string& container) TF_MUST_USE_RESULT; |
233 | |
234 | // Deletes all resources in all containers. |
235 | void Clear(); |
236 | |
237 | // Returns a text description for all resources. |
238 | std::string DebugString() const; |
239 | |
240 | private: |
241 | typedef std::pair<uint64, StringPiece> Key; |
242 | struct KeyHash { |
243 | std::size_t operator()(const Key& k) const { |
244 | return Hash64(k.second.data(), k.second.size(), k.first); |
245 | } |
246 | }; |
247 | struct KeyEqual { |
248 | bool operator()(const Key& x, const Key& y) const { |
249 | return (x.second == y.second) && (x.first == y.first); |
250 | } |
251 | }; |
252 | struct ResourceAndName { |
253 | absl::variant<core::RefCountPtr<ResourceBase>, core::WeakPtr<ResourceBase>> |
254 | resource; |
255 | std::unique_ptr<std::string> name; |
256 | |
257 | ResourceAndName(); |
258 | explicit ResourceAndName(const string& name); |
259 | ResourceAndName(ResourceAndName&& other) noexcept; |
260 | ~ResourceAndName(); |
261 | |
262 | ResourceAndName& operator=(ResourceAndName&&) noexcept; |
263 | |
264 | // Returns a strong reference to resource, or nullptr if the resource is |
265 | // no longer valid. |
266 | core::RefCountPtr<ResourceBase> GetResource() const; |
267 | |
268 | private: |
269 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName); |
270 | }; |
271 | typedef absl::flat_hash_map<Key, ResourceAndName, KeyHash, KeyEqual> |
272 | Container; |
273 | |
274 | const std::string default_container_; |
275 | mutable mutex mu_; |
276 | absl::flat_hash_map<string, Container*> containers_ TF_GUARDED_BY(mu_); |
277 | |
278 | template <typename T, bool use_dynamic_cast = false> |
279 | Status LookupInternal(const std::string& container, const std::string& name, |
280 | T** resource) const |
281 | TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
282 | Status LookupInternal(const std::string& container, uint64 type_hash_code, |
283 | const std::string& name, ResourceBase** resource) const |
284 | TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
285 | |
286 | Status DoCreate(const std::string& container, TypeIndex type, |
287 | const std::string& name, ResourceBase* resource, |
288 | bool owns_resource) |
289 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
290 | |
291 | Status DoLookup(const std::string& container, TypeIndex type, |
292 | const std::string& name, ResourceBase** resource) const |
293 | TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
294 | Status DoLookup(const std::string& container, uint64 type_hash_code, |
295 | const std::string& type_name, |
296 | const std::string& resource_name, |
297 | ResourceBase** resource) const |
298 | TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
299 | |
300 | Status DoDelete(const std::string& container, uint64 type_hash_code, |
301 | const std::string& resource_name, |
302 | const std::string& type_name) TF_MUST_USE_RESULT; |
303 | Status DoDelete(const std::string& container, TypeIndex type, |
304 | const std::string& resource_name) TF_MUST_USE_RESULT; |
305 | |
306 | // Pops the ResourceAndName entry. The entry is moved from the list to |
307 | // the output argument `resource_and_name`. |
308 | Status PopResourceAndName( |
309 | const std::string& container, uint64 type_hash_code, |
310 | const std::string& resource_name, const std::string& type_name, |
311 | ResourceAndName& resource_and_name) TF_MUST_USE_RESULT; |
312 | // Inserts the type name for 'hash_code' into the hash_code to type name map. |
313 | Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name) |
314 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; |
315 | |
316 | // Returns the type name for the 'hash_code'. |
317 | // Returns "<unknown>" if a resource with such a type was never inserted into |
318 | // the container. |
319 | const char* DebugTypeName(uint64 hash_code) const |
320 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
321 | |
322 | // Map from type hash_code to type name. |
323 | std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_); |
324 | |
325 | TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr); |
326 | }; |
327 | |
328 | // Makes a resource handle with the specified type for a given container / |
329 | // name. |
330 | ResourceHandle MakeResourceHandle( |
331 | const std::string& container, const std::string& name, |
332 | const DeviceBase& device, const TypeIndex& type_index, |
333 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, |
334 | const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) |
335 | TF_MUST_USE_RESULT; |
336 | |
337 | template <typename T> |
338 | ResourceHandle MakeResourceHandle( |
339 | OpKernelContext* ctx, const std::string& container, const std::string& name, |
340 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, |
341 | const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) { |
342 | return MakeResourceHandle(container.empty() |
343 | ? ctx->resource_manager()->default_container() |
344 | : container, |
345 | name, *ctx->device(), TypeIndex::Make<T>(), |
346 | dtypes_and_shapes, definition_stack_trace); |
347 | } |
348 | |
349 | template <typename T> |
350 | ResourceHandle MakeResourceHandle( |
351 | OpKernelConstruction* ctx, const std::string& container, |
352 | const std::string& name, |
353 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, |
354 | const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) { |
355 | return MakeResourceHandle(container.empty() |
356 | ? ctx->resource_manager()->default_container() |
357 | : container, |
358 | name, *ctx->device(), TypeIndex::Make<T>(), |
359 | dtypes_and_shapes, definition_stack_trace); |
360 | } |
361 | |
362 | Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, |
363 | const std::string& container, |
364 | const std::string& name, |
365 | const TypeIndex& type_index); |
366 | |
367 | // Returns a resource handle from a numbered op input. |
368 | const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); |
369 | Status HandleFromInput(OpKernelContext* ctx, StringPiece input, |
370 | ResourceHandle* handle); |
371 | |
372 | // Create a resource pointed by a given resource handle. |
373 | // |
374 | // If successful, the caller transfers the ownership of one ref on `resource` to |
375 | // `ctx->resource_mgr()`. |
376 | template <typename T> |
377 | Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); |
378 | |
379 | // Looks up a resource pointed by a given resource handle. |
380 | // |
381 | // If the lookup is successful, the caller takes the ownership of one ref on |
382 | // `*value`, and must call its `Unref()` method when it has finished using it. |
383 | template <typename T, bool use_dynamic_cast = false> |
384 | Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); |
385 | |
386 | // Looks up a resource pointed by a given resource handle. |
387 | // |
388 | // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid |
389 | // requiring the caller to explicitly call `Unref()`. |
390 | template <typename T> |
391 | Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, |
392 | core::RefCountPtr<T>* value); |
393 | |
394 | // Looks up multiple resources pointed by a sequence of resource handles. If |
395 | // p[i] is uninitialized then values[i] is unmodified. |
396 | template <typename T> |
397 | Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p, |
398 | std::vector<core::RefCountPtr<T>>* values); |
399 | |
400 | // Looks up or creates a resource. |
401 | // |
402 | // If successful, the caller takes the ownership of one ref on `*value`, and |
403 | // must call its `Unref()` method when it has finished using it. If the |
404 | // `creator` is invoked, its reference on the created resource is transferred |
405 | // to `ctx->resource_mgr()`. |
406 | // |
407 | // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid |
408 | // requiring the caller to explicitly call `Unref()`. |
409 | template <typename T> |
410 | Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, |
411 | T** value, std::function<Status(T**)> creator); |
412 | |
413 | // Looks up or creates a resource. |
414 | template <typename T> |
415 | Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, |
416 | core::RefCountPtr<T>* value, |
417 | std::function<Status(T**)> creator); |
418 | |
419 | // Destroys a resource pointed by a given resource handle. |
420 | template <typename T> |
421 | Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); |
422 | |
423 | // Same as above, but uses the hash code of the type directly. |
424 | // The type name information will be missing in the debug output when the |
425 | // resource is not present in the container. |
426 | Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); |
427 | |
428 | // Policy helper to decide which container/shared_name to use for a |
429 | // stateful kernel that accesses shared resource. |
430 | class ContainerInfo { |
431 | public: |
432 | // Analyze the node attribute of 'ndef' and decides the container and |
433 | // resource name the kernel should use for accessing the shared |
434 | // resource. |
435 | // |
436 | // 'ndef' is expected to have node attribute "container" and |
437 | // "shared_name". Returns non-OK if they are not provided or they are |
438 | // invalid. |
439 | // |
440 | // The policy is as following: |
441 | // * If the attribute "container" is non-empty, it is used as is. |
442 | // Otherwise, uses the resource manager's default container. |
443 | // * If the attribute "shared_name" is non-empty, it is used as is. |
444 | // Otherwise, if "use_node_name_as_default" is true, the kernel's |
445 | // node name is used as the resource name. Otherwise, a string |
446 | // unique to this process is used. |
447 | Status Init(ResourceMgr* rmgr, const NodeDef& ndef, |
448 | bool use_node_name_as_default); |
449 | Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { |
450 | return Init(rmgr, ndef, false); |
451 | } |
452 | |
453 | // The policy decides that the kernel should access the resource in |
454 | // resource_manager(), the resource is in the container() and its |
455 | // name is name(). If resource_is_private_to_kernel() is true, the |
456 | // kernel should delete the resource when the kernel is deleted. |
457 | ResourceMgr* resource_manager() const { return rmgr_; } |
458 | const std::string& container() const { return container_; } |
459 | const std::string& name() const { return name_; } |
460 | bool resource_is_private_to_kernel() const { |
461 | return resource_is_private_to_kernel_; |
462 | } |
463 | |
464 | // Returns a readable string for *this. |
465 | std::string DebugString() const; |
466 | |
467 | private: |
468 | ResourceMgr* rmgr_ = nullptr; |
469 | std::string container_; |
470 | std::string name_; |
471 | bool resource_is_private_to_kernel_ = false; |
472 | }; |
473 | |
474 | // Helper for kernels to obtain 'resource' from the |
475 | // ctx->resource_manager(). |
476 | // |
477 | // "input_name" specifies the kernel's ref input which gives a string |
478 | // tensor with two elements, which specifies the container and |
479 | // resource name. |
480 | // |
481 | // Returns OK if the resource is found and transfers one ref of |
482 | // *resource to the caller. Otherwise, returns an error. |
483 | template <typename T> |
484 | Status GetResourceFromContext(OpKernelContext* ctx, |
485 | const std::string& input_name, T** resource); |
486 | |
487 | // Utility op kernel to check if a handle to resource type T is initialized. |
488 | template <typename T> |
489 | class IsResourceInitialized : public OpKernel { |
490 | public: |
491 | explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {} |
492 | |
493 | void Compute(OpKernelContext* ctx) override; |
494 | }; |
495 | |
// Registers an op named "<Type>HandleOp" whose single output is a scalar
// resource handle to a resource of the given type. It takes optional
// "container" and "shared_name" string attributes.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)   \
  REGISTER_OP(#Type "HandleOp")             \
      .Attr("container: string = ''")       \
      .Attr("shared_name: string = ''")     \
      .Output("resource: resource")         \
      .SetIsStateful()                      \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
507 | |
508 | // Utility op kernel to produce a handle to a resource of type T. |
509 | template <typename T> |
510 | class ResourceHandleOp : public OpKernel { |
511 | public: |
512 | explicit ResourceHandleOp(OpKernelConstruction* context); |
513 | |
514 | void Compute(OpKernelContext* ctx) override; |
515 | |
516 | bool IsExpensive() override { return false; } |
517 | |
518 | private: |
519 | std::string container_; |
520 | std::string name_; |
521 | mutex mutex_; |
522 | Tensor resource_; |
523 | std::atomic<bool> initialized_{false}; |
524 | }; |
525 | |
526 | // Utility op kernel to produce a handle to a resource of type T. |
527 | template <typename T> |
528 | class ResourceHandlesOp : public OpKernel { |
529 | public: |
530 | explicit ResourceHandlesOp(OpKernelConstruction* context); |
531 | |
532 | void Compute(OpKernelContext* ctx) override; |
533 | |
534 | bool IsExpensive() override { return false; } |
535 | |
536 | private: |
537 | std::vector<string> containers_; |
538 | std::vector<string> names_; |
539 | mutex mutex_; |
540 | std::vector<Tensor> resources_; |
541 | std::atomic<bool> initialized_{false}; |
542 | }; |
543 | |
// Registers ResourceHandleOp<Type> as the CPU kernel for "<Type>HandleOp".
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                          \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU),   \
                          ResourceHandleOp<Type>)
549 | |
550 | // This class is used to guarantee that an anonymous resource is deleted |
551 | // (irrespective of whether a resource deleter op is called explicitly or |
552 | // the execution encounters an error before the op runs). |
553 | // |
554 | // This is achieved by wrapping an instance of this class into a variant |
555 | // tensor which is passed as an input to a resource deleter op. If the |
556 | // execution encounters an error before the op runs, the tensor will be |
557 | // destroyed, essentially triggering the iterator deletion. |
558 | // NOTE: This is not a feature-complete implementation of the DT_VARIANT |
559 | // specification. In particular, we cannot serialize the `ResourceMgr` |
560 | // object, so the `Encode()` and `Decode()` methods are not implemented. |
561 | class ResourceDeleter { |
562 | public: |
563 | ResourceDeleter() : deleter_() {} |
564 | |
565 | ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager) |
566 | : deleter_(std::make_shared<Helper>(handle, resource_manager)) {} |
567 | |
568 | ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) { |
569 | VLOG(3) << "ResourceDeleter move constructor called." ; |
570 | } |
571 | |
572 | ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) { |
573 | VLOG(3) << "ResourceDeleter copy constructor called." ; |
574 | } |
575 | |
576 | ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete; |
577 | |
578 | ResourceDeleter& operator=(ResourceDeleter&& rhs) = default; |
579 | |
580 | virtual ~ResourceDeleter() { |
581 | VLOG(3) << "ResourceDeleter destructor called." ; |
582 | } |
583 | |
584 | void Encode(VariantTensorData*) const { |
585 | LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter " |
586 | "objects." ; |
587 | } |
588 | |
589 | bool Decode(const VariantTensorData&) { |
590 | LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter " |
591 | "objects" ; |
592 | return false; // Not supported. |
593 | } |
594 | |
595 | private: |
596 | // Helper that performs reference counting for the parent class and deletes |
597 | // the iterator resource when the refcount goes to zero. |
598 | // |
599 | // NOTE: The object is borrowing a pointer to the resource manager. |
600 | // Consequently, the tensor containing this object should not escape the |
601 | // function in which was created (so that it is guaranteed that the resource |
602 | // manager will outlive it). |
603 | struct Helper { |
604 | Helper(ResourceHandle handle, ResourceMgr* resource_manager) |
605 | : handle(handle), resource_manager(resource_manager) {} |
606 | |
607 | Helper(const Helper& rhs) = delete; |
608 | Helper(Helper&& rhs) = delete; |
609 | |
610 | ~Helper() { |
611 | VLOG(3) << "Deleting Resource: " << handle.DebugString(); |
612 | resource_manager->Delete(handle).IgnoreError(); |
613 | } |
614 | |
615 | ResourceHandle handle; |
616 | ResourceMgr* resource_manager; // not owned |
617 | }; |
618 | |
619 | std::shared_ptr<Helper> deleter_; |
620 | }; |
621 | |
622 | // Implementation details below. |
623 | |
624 | template <typename T> |
625 | void CheckDeriveFromResourceBase() { |
626 | static_assert(std::is_base_of<ResourceBase, T>::value, |
627 | "T must derive from ResourceBase" ); |
628 | } |
629 | |
630 | template <typename T> |
631 | Status ResourceMgr::Create(const std::string& container, |
632 | const std::string& name, T* resource) { |
633 | CheckDeriveFromResourceBase<T>(); |
634 | CHECK(resource != nullptr); |
635 | mutex_lock l(mu_); |
636 | return DoCreate(container, TypeIndex::Make<T>(), name, resource, |
637 | /* owns_resource */ true); |
638 | } |
639 | |
640 | template <typename T> |
641 | Status ResourceMgr::CreateUnowned(const std::string& container, |
642 | const std::string& name, T* resource) { |
643 | CheckDeriveFromResourceBase<T>(); |
644 | mutex_lock l(mu_); |
645 | return DoCreate(container, TypeIndex::Make<T>(), name, resource, |
646 | /* owns_resource */ false); |
647 | } |
648 | |
649 | template <typename T, bool use_dynamic_cast> |
650 | Status ResourceMgr::Lookup(const std::string& container, |
651 | const std::string& name, T** resource) const { |
652 | CheckDeriveFromResourceBase<T>(); |
653 | tf_shared_lock l(mu_); |
654 | return LookupInternal<T, use_dynamic_cast>(container, name, resource); |
655 | } |
656 | |
657 | template <typename T, bool use_dynamic_cast> |
658 | Status ResourceMgr::LookupMany( |
659 | absl::Span<std::pair<const string*, const string*> const> |
660 | containers_and_names, |
661 | std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const { |
662 | CheckDeriveFromResourceBase<T>(); |
663 | tf_shared_lock l(mu_); |
664 | resources->resize(containers_and_names.size()); |
665 | for (size_t i = 0; i < containers_and_names.size(); ++i) { |
666 | T* resource; |
667 | Status s = LookupInternal<T, use_dynamic_cast>( |
668 | *containers_and_names[i].first, *containers_and_names[i].second, |
669 | &resource); |
670 | if (s.ok()) { |
671 | (*resources)[i].reset(resource); |
672 | } |
673 | } |
674 | return OkStatus(); |
675 | } |
676 | |
677 | // Simple wrapper to allow conditional dynamic / static casts. |
678 | template <typename T, bool use_dynamic_cast> |
679 | struct TypeCastFunctor { |
680 | static T* Cast(ResourceBase* r) { return static_cast<T*>(r); } |
681 | }; |
682 | |
683 | template <typename T> |
684 | struct TypeCastFunctor<T, true> { |
685 | static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); } |
686 | }; |
687 | |
688 | template <typename T, bool use_dynamic_cast> |
689 | Status ResourceMgr::LookupInternal(const std::string& container, |
690 | const std::string& name, |
691 | T** resource) const { |
692 | ResourceBase* found = nullptr; |
693 | Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found); |
694 | if (s.ok()) { |
695 | // It's safe to down cast 'found' to T* since |
696 | // typeid(T).hash_code() is part of the map key. |
697 | *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found); |
698 | } |
699 | return s; |
700 | } |
701 | |
702 | template <typename T, bool use_dynamic_cast> |
703 | Status ResourceMgr::LookupOrCreate(const std::string& container, |
704 | const std::string& name, T** resource, |
705 | std::function<Status(T**)> creator) { |
706 | CheckDeriveFromResourceBase<T>(); |
707 | *resource = nullptr; |
708 | Status s; |
709 | { |
710 | tf_shared_lock l(mu_); |
711 | s = LookupInternal<T, use_dynamic_cast>(container, name, resource); |
712 | if (s.ok()) return s; |
713 | } |
714 | mutex_lock l(mu_); |
715 | s = LookupInternal<T, use_dynamic_cast>(container, name, resource); |
716 | if (s.ok()) return s; |
717 | TF_RETURN_IF_ERROR(creator(resource)); |
718 | s = DoCreate(container, TypeIndex::Make<T>(), name, *resource, |
719 | /* owns_resource */ true); |
720 | if (!s.ok()) { |
721 | return errors::Internal("LookupOrCreate failed unexpectedly" ); |
722 | } |
723 | (*resource)->Ref(); |
724 | return s; |
725 | } |
726 | |
727 | template <typename T> |
728 | Status ResourceMgr::Delete(const std::string& container, |
729 | const std::string& name) { |
730 | CheckDeriveFromResourceBase<T>(); |
731 | return DoDelete(container, TypeIndex::Make<T>(), name); |
732 | } |
733 | |
734 | template <typename T> |
735 | Status GetResourceFromContext(OpKernelContext* ctx, |
736 | const std::string& input_name, T** resource) { |
737 | DataType dtype; |
738 | TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); |
739 | if (dtype == DT_RESOURCE) { |
740 | const Tensor* handle; |
741 | TF_RETURN_IF_ERROR(ctx->input(input_name, &handle)); |
742 | return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource); |
743 | } |
744 | std::string container; |
745 | std::string shared_name; |
746 | { |
747 | mutex* mu; |
748 | TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); |
749 | mutex_lock l(*mu); |
750 | Tensor tensor; |
751 | TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); |
752 | if (tensor.NumElements() != 2) { |
753 | return errors::InvalidArgument( |
754 | "Resource handle must have 2 elements, but had shape: " , |
755 | tensor.shape().DebugString()); |
756 | } |
757 | container = tensor.flat<tstring>()(0); |
758 | shared_name = tensor.flat<tstring>()(1); |
759 | } |
760 | return ctx->resource_manager()->Lookup(container, shared_name, resource); |
761 | } |
762 | |
763 | namespace internal { |
764 | |
765 | Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); |
766 | |
767 | template <typename T> |
768 | Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { |
769 | TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); |
770 | TF_RETURN_IF_ERROR(p.ValidateType<T>()); |
771 | return OkStatus(); |
772 | } |
773 | |
774 | } // namespace internal |
775 | |
776 | // Creates the resource pointed at by "p". The caller transfers the ownership of |
777 | // one ref on "*value" to the resource manager in "ctx", regardless of whether |
778 | // this operation succeeds or fails. |
779 | template <typename T> |
780 | Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { |
781 | TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); |
782 | return ctx->resource_manager()->Create(p.container(), p.name(), value); |
783 | } |
784 | |
785 | // Finds the resource as "*value" from the handle. If the handle is |
786 | // ref-counting, returns the resource owned by the handle. Otherwise, looks up |
787 | // the resource matching "p" from resource manager associated with ctx. |
788 | // Always returns a new reference to the resource in "*value". The caller shall |
789 | // call (*value)->Unref(). |
790 | template <typename T, bool use_dynamic_cast> |
791 | Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, |
792 | T** value) { |
793 | TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); |
794 | if (p.IsRefCounting()) { |
795 | TF_ASSIGN_OR_RETURN(*value, p.GetResource<T>()); |
796 | // Transfers out a new reference. |
797 | (*value)->Ref(); |
798 | return OkStatus(); |
799 | } |
800 | |
801 | return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(), |
802 | p.name(), value); |
803 | } |
804 | |
// Finds the resource referenced by "p" and returns it as "*value". This is a
// type-erased variant of the templated LookupResource: the result is exposed
// only as a ResourceBase, so no static type T is required.
// NOTE(review): presumably follows the same reference contract as the
// templated overload (caller must (*value)->Unref()); confirm in the definition.
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                      ResourceBase** value);
809 | |
810 | // If the resource manager in "ctx" has a resource matching "p", returns it in |
811 | // "*value". |
812 | template <typename T> |
813 | Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, |
814 | core::RefCountPtr<T>* value) { |
815 | T* raw_ptr = nullptr; |
816 | TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr)); |
817 | value->reset(raw_ptr); |
818 | |
819 | return OkStatus(); |
820 | } |
821 | |
822 | // Similar to Lookup, but looks up multiple resources at once, with only a |
823 | // single lock acquisition. |
824 | template <typename T> |
825 | Status LookupResources(OpKernelContext* ctx, |
826 | absl::Span<ResourceHandle const* const> p, |
827 | std::vector<core::RefCountPtr<T>>* values) { |
828 | std::vector<std::pair<const string*, const string*>> containers_and_names( |
829 | p.size()); |
830 | for (size_t i = 0; i < p.size(); ++i) { |
831 | TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i])); |
832 | containers_and_names[i] = {&p[i]->container(), &p[i]->name()}; |
833 | } |
834 | return ctx->resource_manager()->LookupMany(containers_and_names, values); |
835 | } |
836 | |
837 | // If the resource manager in "ctx" has a resource pointed at by "p", returns |
838 | // it in "*value". Otherwise, invokes creator() to create the resource. |
839 | // The caller takes the ownership of one ref on "*value". |
840 | // |
841 | // WARNING: creator() must not call any methods on the resource manager during |
842 | // its execution, because a non-reentrant lock is held during the creator() call |
843 | // in order to guarantee atomicity of LookupOrCreateResource(). |
844 | template <typename T> |
845 | Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, |
846 | T** value, std::function<Status(T**)> creator) { |
847 | TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); |
848 | return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, |
849 | creator); |
850 | } |
851 | |
852 | // If the resource manager in "ctx" has a resource pointed at by "p", returns |
853 | // it in "*value". Otherwise, invokes creator() to create the resource. |
854 | // |
855 | // WARNING: creator() must not call any methods on the resource manager during |
856 | // its execution, because a non-reentrant lock is held during the creator() call |
857 | // in order to guarantee atomicity of LookupOrCreateResource(). |
858 | template <typename T> |
859 | Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, |
860 | core::RefCountPtr<T>* value, |
861 | std::function<Status(T**)> creator) { |
862 | T* raw_ptr = nullptr; |
863 | TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator)); |
864 | value->reset(raw_ptr); |
865 | |
866 | return OkStatus(); |
867 | } |
868 | |
869 | // Deletes the resource pointed by "p", using the resource manager in "ctx". |
870 | template <typename T> |
871 | Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { |
872 | TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); |
873 | // This is a noop because ResourceMgr does not hold a reference. |
874 | // NOTE(feyu): if we can convert all resources handle to ref-counting, then |
875 | // DeleteResource can be removed. |
876 | if (p.IsRefCounting()) { |
877 | return OkStatus(); |
878 | } |
879 | return ctx->resource_manager()->Delete<T>(p.container(), p.name()); |
880 | } |
881 | |
// Deletes the resource pointed by "p", using the resource manager in "ctx".
// Type-erased variant of DeleteResource<T>: it takes no static resource type.
// NOTE(review): whether it also validates the handle and skips ref-counting
// handles like the templated version is not visible here -- confirm.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
884 | |
885 | template <typename T> |
886 | void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) { |
887 | Tensor* output; |
888 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); |
889 | T* object; |
890 | bool found; |
891 | if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) { |
892 | found = true; |
893 | object->Unref(); |
894 | } else { |
895 | found = false; |
896 | } |
897 | |
898 | output->flat<bool>()(0) = found; |
899 | } |
900 | |
901 | template <typename T> |
902 | ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context) |
903 | : OpKernel(context) { |
904 | OP_REQUIRES_OK(context, context->GetAttr("container" , &container_)); |
905 | OP_REQUIRES_OK(context, context->GetAttr("shared_name" , &name_)); |
906 | } |
907 | |
908 | template <typename T> |
909 | void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) { |
910 | if (name_ == ResourceHandle::ANONYMOUS_NAME) { |
911 | AllocatorAttributes attr; |
912 | attr.set_on_host(true); |
913 | Tensor handle; |
914 | OP_REQUIRES_OK( |
915 | ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr)); |
916 | handle.scalar<ResourceHandle>()() = MakeResourceHandle<T>( |
917 | ctx, container_, name_, /*dtypes_and_shapes=*/{}, ctx->stack_trace()); |
918 | ctx->set_output(0, handle); |
919 | } else { |
920 | if (!initialized_.load()) { |
921 | mutex_lock ml(mutex_); |
922 | // Checking again to see if another thread has initialized the resource. |
923 | if (!initialized_.load()) { |
924 | AllocatorAttributes attr; |
925 | attr.set_on_host(true); |
926 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), |
927 | &resource_, attr)); |
928 | resource_.scalar<ResourceHandle>()() = |
929 | MakeResourceHandle<T>(ctx, container_, name_, |
930 | /*dtypes_and_shapes=*/{}, ctx->stack_trace()); |
931 | initialized_.store(true); |
932 | } |
933 | } |
934 | ctx->set_output(0, resource_); |
935 | } |
936 | } |
937 | |
938 | template <typename T> |
939 | ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context) |
940 | : OpKernel(context) { |
941 | int n; |
942 | OP_REQUIRES_OK(context, context->GetAttr("N" , &n)); |
943 | OP_REQUIRES_OK(context, context->GetAttr("containers" , &containers_)); |
944 | OP_REQUIRES_OK(context, context->GetAttr("shared_names" , &names_)); |
945 | OP_REQUIRES( |
946 | context, containers_.size() == n, |
947 | errors::InvalidArgument("Number of containers (" , containers_.size(), |
948 | ") must be equal to N (" , n, ")" )); |
949 | OP_REQUIRES(context, names_.size() == n, |
950 | errors::InvalidArgument("Number of names (" , containers_.size(), |
951 | ") must be equal to N (" , n, ")" )); |
952 | resources_.resize(n); |
953 | } |
954 | |
955 | template <typename T> |
956 | void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) { |
957 | if (!initialized_.load()) { |
958 | mutex_lock ml(mutex_); |
959 | // Checking again to see if another thread has initialized the resource. |
960 | if (!initialized_.load()) { |
961 | AllocatorAttributes attr; |
962 | attr.set_on_host(true); |
963 | for (size_t i = 0; i < resources_.size(); ++i) { |
964 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), |
965 | &resources_[i], attr)); |
966 | ResourceHandle h = |
967 | MakeResourceHandle<T>(ctx, containers_[i], names_[i]); |
968 | resources_[i].template scalar<ResourceHandle>()() = h; |
969 | } |
970 | initialized_.store(true); |
971 | } |
972 | } |
973 | for (size_t i = 0; i < resources_.size(); ++i) { |
974 | ctx->set_output(i, resources_[i]); |
975 | } |
976 | } |
977 | |
978 | template <typename T> |
979 | ResourceHandle ScopedStepContainer::MakeResourceHandle( |
980 | const std::string& name, const DeviceBase& device) { |
981 | mutex_lock ml(mu_); |
982 | dirty_ = true; |
983 | return tensorflow::MakeResourceHandle(container_, name, device, |
984 | TypeIndex::Make<T>(), {}); |
985 | } |
986 | |
987 | template <typename T> |
988 | Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name, |
989 | T** resource) const { |
990 | return rm->Lookup<T>(container_, name, resource); |
991 | } |
992 | |
993 | template <typename T> |
994 | Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, |
995 | const std::string& name, |
996 | T** resource, |
997 | std::function<Status(T**)> creator) { |
998 | mutex_lock ml(mu_); |
999 | dirty_ = true; |
1000 | return rm->LookupOrCreate<T>(container_, name, resource, creator); |
1001 | } |
1002 | |
1003 | template <typename T> |
1004 | Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name, |
1005 | T* resource) { |
1006 | mutex_lock ml(mu_); |
1007 | dirty_ = true; |
1008 | return rm->Create<T>(container_, name, resource); |
1009 | } |
1010 | |
1011 | template <typename T> |
1012 | Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) { |
1013 | return rm->Delete<T>(container_, name); |
1014 | } |
1015 | |
1016 | } // end namespace tensorflow |
1017 | |
1018 | #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ |
1019 | |