1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ |
18 | |
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "tensorflow/core/framework/resource_base.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/intrusive_ptr.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/managed_stack_trace.h"
31 | |
32 | namespace tensorflow { |
33 | |
// Forward declaration only: ResourceHandle is the native counterpart of this
// proto and is kept separate so that code using handles (e.g. kernels) does
// not need to depend on the generated proto code.
class ResourceHandleProto;
35 | |
36 | // Class representing a handle to a tensorflow resource. Handles are |
37 | // not valid across executions, but can be serialized back and forth from within |
38 | // a single run (except for those created from MakeRefCountingHandle i.e. whose |
39 | // resource_ field is not empty). |
40 | // |
41 | // This is the native C++ class equivalent of ResourceHandleProto. They are |
42 | // separate so that kernels do not need to depend on protos. |
43 | class ResourceHandle { |
44 | public: |
45 | ResourceHandle(); |
46 | ResourceHandle(const ResourceHandleProto& proto); |
47 | ~ResourceHandle(); |
48 | |
49 | // Use this factory method if the `proto` comes from user controlled input, to |
50 | // prevent a denial of service. |
51 | static Status BuildResourceHandle(const ResourceHandleProto& proto, |
52 | ResourceHandle* out); |
53 | |
54 | // Unique name for the device containing the resource. |
55 | const std::string& device() const { return device_; } |
56 | |
57 | void set_device(const std::string& device) { device_ = device; } |
58 | |
59 | // Container in which this resource is placed. |
60 | const std::string& container() const { return container_; } |
61 | void set_container(const std::string& container) { container_ = container; } |
62 | |
63 | // Unique name of this resource. |
64 | const std::string& name() const { return name_; } |
65 | void set_name(const std::string& name) { name_ = name; } |
66 | |
67 | // Hash code for the type of the resource. Is only valid in the same device |
68 | // and in the same execution. |
69 | uint64 hash_code() const { return hash_code_; } |
70 | void set_hash_code(uint64 hash_code) { hash_code_ = hash_code; } |
71 | |
72 | // For debug-only, the name of the type pointed to by this handle, if |
73 | // available. |
74 | const std::string& maybe_type_name() const { return maybe_type_name_; } |
75 | void set_maybe_type_name(const std::string& value) { |
76 | maybe_type_name_ = value; |
77 | } |
78 | |
79 | // Data types and shapes for the underlying resource. |
80 | std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes() const { |
81 | return dtypes_and_shapes_; |
82 | } |
83 | void set_dtypes_and_shapes( |
84 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) { |
85 | dtypes_and_shapes_ = dtypes_and_shapes; |
86 | } |
87 | |
88 | void set_definition_stack_trace( |
89 | const absl::optional<ManagedStackTrace>& definition_stack_trace) { |
90 | definition_stack_trace_ = definition_stack_trace; |
91 | } |
92 | |
93 | const absl::optional<ManagedStackTrace>& definition_stack_trace() const { |
94 | return definition_stack_trace_; |
95 | } |
96 | |
97 | // Conversion to and from ResourceHandleProto |
98 | void AsProto(ResourceHandleProto* proto) const; |
99 | Status FromProto(const ResourceHandleProto& proto); |
100 | |
101 | // Serialization via ResourceHandleProto |
102 | std::string SerializeAsString() const; |
103 | bool ParseFromString(const std::string& s); |
104 | |
105 | std::string DebugString() const; |
106 | |
107 | std::string SummarizeValue() const; |
108 | |
109 | // GUID for anonymous resources. Resources with this shared_name will have |
110 | // their shared_name replaced with a GUID at creation time |
111 | static constexpr const char* ANONYMOUS_NAME = |
112 | "cd2c89b7-88b7-44c8-ad83-06c2a9158347" ; |
113 | |
114 | // Creates a `ResourceHandle` that holds a pointer to a resource and takes |
115 | // ownership of it. Normally a `ResourceHandle` only contains the name (and |
116 | // some other metadata) of the resource. When created via this function, |
117 | // the handle will own the resource, in the sense that it will destroy the |
118 | // resource automatically when the resource is no longer needed. It does this |
119 | // via automatic ref-counting on the resource: when the handle is copied, it |
120 | // will call `Ref` on the resource (remember that all resources inherit from |
121 | // `ResourceBase` which inherits from `RefCounted`), and when the handle is |
122 | // destroyed, it will call `Unref` on the resource. When the last handle goes |
123 | // out of scope, the resource's ref-count will go down to zero and the |
124 | // resource will be destroyed. When calling this function, the `resource` |
125 | // argument should have a ref-count of one (which is the case when the |
126 | // resource is newly created). |
127 | // |
128 | // For those familiar with `ResourceMgr`, when you create a handle by the |
129 | // `MakeResourceHandle` function in resource_mgr.h, the handle doesn't hold a |
130 | // strong reference to the resource, and the resource is owned by the |
131 | // resource manager whose strong reference must be manually deleted by |
132 | // calling `ResourceMgr::Delete`. In contrast, a handle created by this |
133 | // function holds a strong reference to the resource. The resource manager |
134 | // does not hold a strong reference to the resource. |
135 | template <typename T> |
136 | static ResourceHandle MakeRefCountingHandle( |
137 | T* resource, const string& device_name, |
138 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, |
139 | const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) { |
140 | return MakeRefCountingHandle(resource, device_name, TypeIndex::Make<T>(), |
141 | dtypes_and_shapes, definition_stack_trace); |
142 | } |
143 | |
144 | static ResourceHandle MakeRefCountingHandle( |
145 | ResourceBase* resource, const string& device_name, |
146 | const TypeIndex& type_index, |
147 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}, |
148 | const absl::optional<ManagedStackTrace>& definition_stack_trace = {}); |
149 | |
150 | // Pointer to the resource. |
151 | const core::IntrusivePtr<ResourceBase>& resource() const { return resource_; } |
152 | |
153 | // Gets the resource pointer in `handle` as `T*`, or an error if the actual |
154 | // resource type is not `T`. |
155 | template <typename T> |
156 | StatusOr<T*> GetResource() const { |
157 | TF_RETURN_IF_ERROR(ValidateType<T>()); |
158 | return down_cast<T*>(resource_.get()); |
159 | } |
160 | |
161 | // Returns True if the resource handle is ref-counting. |
162 | // See MakeRefCountingHandle. |
163 | bool IsRefCounting() const { return resource_.get() != nullptr; } |
164 | |
165 | // Validates that the resource type in `handle` is `T`. |
166 | template <typename T> |
167 | Status ValidateType() const { |
168 | return ValidateType(TypeIndex::Make<T>()); |
169 | } |
170 | |
171 | Status ValidateType(const TypeIndex& type_index) const; |
172 | |
173 | // Generates unique IDs (e.g. for names of anonymous variables) |
174 | static int64_t GenerateUniqueId(); |
175 | |
176 | private: |
177 | std::string device_; |
178 | std::string container_; |
179 | std::string name_; |
180 | uint64 hash_code_ = 0; |
181 | std::string maybe_type_name_; |
182 | std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes_; |
183 | absl::optional<ManagedStackTrace> definition_stack_trace_; |
184 | // A smart pointer to the actual resource. When this field is not empty, the |
185 | // handle is in a "ref-counting" mode, owning the resource; otherwise it's in |
186 | // a "weak-ref" mode, only containing the name of the resource (conceptually a |
187 | // weak reference). |
188 | core::IntrusivePtr<ResourceBase> resource_; |
189 | static std::atomic<int64_t> current_id_; |
190 | }; |
191 | |
192 | // For backwards compatibility for when this was a proto |
193 | std::string ProtoDebugString(const ResourceHandle& handle); |
194 | |
195 | // Encodes a list of ResourceHandle protos in the given StringListEncoder. |
196 | void EncodeResourceHandleList(const ResourceHandle* p, int64_t n, |
197 | std::unique_ptr<port::StringListEncoder> e); |
198 | |
199 | // Decodes a list of ResourceHandle protos from the given StringListDecoder. |
200 | bool DecodeResourceHandleList(std::unique_ptr<port::StringListDecoder> d, |
201 | ResourceHandle* ps, int64_t n); |
202 | |
203 | } // namespace tensorflow |
204 | |
205 | #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ |
206 | |