1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ |
18 | |
19 | #include <string> |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/op_requires.h" |
23 | #include "tensorflow/core/framework/resource_mgr.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | #include "tensorflow/core/platform/mutex.h" |
27 | #include "tensorflow/core/platform/refcount.h" |
28 | #include "tensorflow/core/platform/thread_annotations.h" |
29 | #include "tensorflow/core/platform/types.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // ResourceOpKernel<T> is a virtual base class for resource op implementing |
34 | // interface type T. The inherited op looks up the resource name (determined by |
35 | // ContainerInfo), and creates a new resource if necessary. |
36 | // |
37 | // Requirements: |
38 | // - Op must be marked as stateful. |
39 | // - Op must have `container` and `shared_name` attributes. Empty `container` |
40 | // means using the default container. Empty `shared_name` means private |
41 | // resource. |
42 | // - Subclass must override CreateResource(). |
43 | // - Subclass is encouraged to override VerifyResource(). |
44 | template <typename T> |
45 | class ResourceOpKernel : public OpKernel { |
46 | public: |
47 | explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { |
48 | has_resource_type_ = (context->output_type(0) == DT_RESOURCE); |
49 | if (!has_resource_type_) { |
50 | // The resource variant of the op may be placed on non-CPU devices, but |
51 | // this allocation is always on the host. Fortunately we don't need it in |
52 | // the resource case. |
53 | OP_REQUIRES_OK(context, context->allocate_temp( |
54 | DT_STRING, TensorShape({2}), &tensor_)); |
55 | } |
56 | } |
57 | |
58 | // The resource is deleted from the resource manager only when it is private |
59 | // to kernel. Ideally the resource should be deleted when it is no longer held |
60 | // by anyone, but it would break backward compatibility. |
61 | ~ResourceOpKernel() override { |
62 | if (cinfo_.resource_is_private_to_kernel()) { |
63 | if (!cinfo_.resource_manager() |
64 | ->template Delete<T>(cinfo_.container(), cinfo_.name()) |
65 | .ok()) { |
66 | // Do nothing; the resource can have been deleted by session resets. |
67 | } |
68 | } |
69 | } |
70 | |
71 | void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) { |
72 | mutex_lock l(mu_); |
73 | core::RefCountPtr<T> resource_ref_ptr = weak_resource_.GetNewRef(); |
74 | if (resource_ref_ptr == nullptr) { |
75 | ResourceMgr* mgr = context->resource_manager(); |
76 | OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); |
77 | |
78 | T* resource; |
79 | OP_REQUIRES_OK(context, |
80 | mgr->LookupOrCreate<T>( |
81 | cinfo_.container(), cinfo_.name(), &resource, |
82 | [this](T** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
83 | Status s = CreateResource(ret); |
84 | if (!s.ok() && *ret != nullptr) { |
85 | CHECK((*ret)->Unref()); |
86 | } |
87 | return s; |
88 | })); |
89 | // Here the code releases the reference to the resource created by this op |
90 | // and only holds a WeakPtr to the resource. This way the lifetime of the |
91 | // resource is owned by the container; otherwise the container may be |
92 | // cleared (e.g. a Session::Reset()) but the resource lives on inside this |
93 | // op, causing later lookups in the container by handle to fail. |
94 | core::ScopedUnref resource_unref(resource); |
95 | OP_REQUIRES_OK(context, VerifyResource(resource)); |
96 | weak_resource_ = core::WeakPtr<T>(resource); |
97 | // TODO(b/243544755): delete after scam migrates ResourceKernelOp |
98 | // subclasses to get_resource() in TF 2.11. |
99 | resource_ = resource; |
100 | |
101 | if (!has_resource_type_) { |
102 | auto h = tensor_.template flat<tstring>(); |
103 | h(0) = cinfo_.container(); |
104 | h(1) = cinfo_.name(); |
105 | } |
106 | } |
107 | if (has_resource_type_) { |
108 | OP_REQUIRES_OK(context, MakeResourceHandleToOutput( |
109 | context, 0, cinfo_.container(), cinfo_.name(), |
110 | TypeIndex::Make<T>())); |
111 | } else { |
112 | context->set_output_ref(0, &mu_, &tensor_); |
113 | } |
114 | } |
115 | |
116 | protected: |
117 | // Variables accessible from subclasses. |
118 | mutex mu_; |
119 | ContainerInfo cinfo_ TF_GUARDED_BY(mu_); |
120 | // TODO(b/243544755): delete after scam migrates ResourceKernelOp subclasses |
121 | // to get_resource() in TF 2.11. |
122 | ABSL_DEPRECATED("Use get_resource() instead." ) |
123 | T* resource_ TF_GUARDED_BY(mu_) = nullptr; |
124 | |
125 | core::RefCountPtr<T> get_resource() TF_LOCKS_EXCLUDED(mu_) { |
126 | mutex_lock lock(mu_); |
127 | return weak_resource_.GetNewRef(); |
128 | } |
129 | |
130 | private: |
131 | core::WeakPtr<T> weak_resource_ TF_GUARDED_BY(mu_) = |
132 | core::WeakPtr<T>(nullptr); |
133 | |
134 | // Must return a T descendant allocated with new that ResourceOpKernel will |
135 | // take ownership of. |
136 | virtual Status CreateResource(T** resource) |
137 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; |
138 | |
139 | // During the first Compute(), resource is either created or looked up using |
140 | // shared_name. In the latter case, the resource found should be verified if |
141 | // it is compatible with this op's configuration. The verification may fail in |
142 | // cases such as two graphs asking queues of the same shared name to have |
143 | // inconsistent capacities. |
144 | virtual Status VerifyResource(T* resource) { return OkStatus(); } |
145 | |
146 | Tensor tensor_ TF_GUARDED_BY(mu_); |
147 | |
148 | // Is the output of the operator of type DT_RESOURCE? |
149 | bool has_resource_type_; |
150 | }; |
151 | } // namespace tensorflow |
152 | |
153 | #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ |
154 | |