1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/resource_mgr.h" |
17 | |
18 | #include <atomic> |
19 | |
20 | #include "tensorflow/core/framework/device_attributes.pb.h" |
21 | #include "tensorflow/core/framework/node_def.pb.h" |
22 | #include "tensorflow/core/framework/node_def_util.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/gtl/map_util.h" |
25 | #include "tensorflow/core/lib/strings/scanner.h" |
26 | #include "tensorflow/core/lib/strings/str_util.h" |
27 | #include "tensorflow/core/lib/strings/stringprintf.h" |
28 | #include "tensorflow/core/platform/demangle.h" |
29 | #include "tensorflow/core/platform/stacktrace.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | ResourceHandle MakeResourceHandle( |
34 | const string& container, const string& name, const DeviceBase& device, |
35 | const TypeIndex& type_index, |
36 | const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes, |
37 | const absl::optional<ManagedStackTrace>& definition_stack_trace) { |
38 | ResourceHandle result; |
39 | result.set_device(device.name()); |
40 | result.set_container(container); |
41 | result.set_definition_stack_trace(definition_stack_trace); |
42 | if (name == ResourceHandle::ANONYMOUS_NAME) { |
43 | result.set_name( |
44 | strings::StrCat("_AnonymousVar" , ResourceHandle::GenerateUniqueId())); |
45 | } else { |
46 | result.set_name(name); |
47 | } |
48 | result.set_hash_code(type_index.hash_code()); |
49 | result.set_maybe_type_name(type_index.name()); |
50 | result.set_dtypes_and_shapes(dtypes_and_shapes); |
51 | return result; |
52 | } |
53 | |
54 | Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, |
55 | const string& container, const string& name, |
56 | const TypeIndex& type_index) { |
57 | Tensor* handle; |
58 | TF_RETURN_IF_ERROR( |
59 | context->allocate_output(output_index, TensorShape({}), &handle)); |
60 | handle->scalar<ResourceHandle>()() = |
61 | MakeResourceHandle(container, name, *context->device(), type_index); |
62 | return OkStatus(); |
63 | } |
64 | |
65 | namespace internal { |
66 | |
67 | Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { |
68 | if (ctx->device()->attributes().name() != p.device()) { |
69 | return errors::InvalidArgument( |
70 | "Trying to access resource " , p.name(), " located in device " , |
71 | p.device(), " from device " , ctx->device()->attributes().name()); |
72 | } |
73 | return OkStatus(); |
74 | } |
75 | |
76 | } // end namespace internal |
77 | |
78 | Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, |
79 | const string& type_name) { |
80 | auto iter = debug_type_names_.emplace(hash_code, type_name); |
81 | if (iter.first->second != type_name) { |
82 | return errors::AlreadyExists("Duplicate hash code found for type " , |
83 | type_name); |
84 | } |
85 | return OkStatus(); |
86 | } |
87 | |
88 | const char* ResourceMgr::DebugTypeName(uint64 hash_code) const { |
89 | auto type_name_iter = debug_type_names_.find(hash_code); |
90 | if (type_name_iter == debug_type_names_.end()) { |
91 | return "<unknown>" ; |
92 | } else { |
93 | return type_name_iter->second.c_str(); |
94 | } |
95 | } |
96 | |
97 | ResourceMgr::ResourceAndName::ResourceAndName() : name(nullptr) {} |
98 | |
99 | ResourceMgr::ResourceAndName::ResourceAndName(const string& name) |
100 | : name(absl::make_unique<string>(name)) {} |
101 | |
102 | core::RefCountPtr<ResourceBase> ResourceMgr::ResourceAndName::GetResource() |
103 | const { |
104 | if (absl::holds_alternative<core::RefCountPtr<ResourceBase>>(resource)) { |
105 | ResourceBase* ptr = |
106 | absl::get<core::RefCountPtr<ResourceBase>>(resource).get(); |
107 | ptr->Ref(); |
108 | return core::RefCountPtr<ResourceBase>(ptr); |
109 | } else if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(resource)) { |
110 | return absl::get<core::WeakPtr<ResourceBase>>(resource).GetNewRef(); |
111 | } else { |
112 | return nullptr; |
113 | } |
114 | } |
115 | |
116 | ResourceMgr::ResourceAndName::ResourceAndName( |
117 | ResourceAndName&& other) noexcept { |
118 | name = std::move(other.name); |
119 | resource = std::move(other.resource); |
120 | } |
121 | |
122 | ResourceMgr::ResourceAndName::~ResourceAndName() {} |
123 | |
124 | ResourceMgr::ResourceAndName& ResourceMgr::ResourceAndName::operator=( |
125 | ResourceAndName&& other) noexcept { |
126 | name = std::move(other.name); |
127 | resource = std::move(other.resource); |
128 | return *this; |
129 | } |
130 | |
131 | ResourceMgr::ResourceMgr() : default_container_("localhost" ) {} |
132 | |
133 | ResourceMgr::ResourceMgr(const string& default_container) |
134 | : default_container_(default_container) {} |
135 | |
136 | ResourceMgr::~ResourceMgr() { Clear(); } |
137 | |
138 | void ResourceMgr::Clear() { |
139 | // We do the deallocation outside of the lock to avoid a potential deadlock |
140 | // in case any of the destructors access the resource manager. |
141 | absl::flat_hash_map<string, Container*> tmp_containers; |
142 | { |
143 | mutex_lock l(mu_); |
144 | tmp_containers = std::move(containers_); |
145 | } |
146 | for (const auto& p : tmp_containers) { |
147 | delete p.second; |
148 | } |
149 | tmp_containers.clear(); |
150 | } |
151 | |
152 | string ResourceMgr::DebugString() const { |
153 | mutex_lock l(mu_); |
154 | struct Line { |
155 | const string* container; |
156 | const string type; |
157 | const string* resource; |
158 | const string detail; |
159 | }; |
160 | std::vector<Line> lines; |
161 | for (const auto& p : containers_) { |
162 | const string& container = p.first; |
163 | for (const auto& q : *p.second) { |
164 | const Key& key = q.first; |
165 | const char* type = DebugTypeName(key.first); |
166 | const core::RefCountPtr<ResourceBase> resource = q.second.GetResource(); |
167 | Line l{&container, port::Demangle(type), q.second.name.get(), |
168 | resource ? resource->DebugString() : "<nullptr>" }; |
169 | lines.push_back(l); |
170 | } |
171 | } |
172 | std::vector<string> text; |
173 | text.reserve(lines.size()); |
174 | for (const Line& line : lines) { |
175 | text.push_back(strings::Printf( |
176 | "%-20s | %-40s | %-40s | %-s" , line.container->c_str(), |
177 | line.type.c_str(), line.resource->c_str(), line.detail.c_str())); |
178 | } |
179 | std::sort(text.begin(), text.end()); |
180 | return absl::StrJoin(text, "\n" ); |
181 | } |
182 | |
183 | Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, |
184 | const string& name, ResourceBase* resource, |
185 | bool owns_resource) { |
186 | Container* container = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
187 | Container** ptr = &containers_[container_name]; |
188 | if (*ptr == nullptr) { |
189 | *ptr = new Container; |
190 | } |
191 | return *ptr; |
192 | }(); |
193 | |
194 | // NOTE: Separating out the construction of the map key and value so that the |
195 | // key can contain a StringPiece that borrows from the string in the value. |
196 | ResourceAndName resource_and_name(name); |
197 | |
198 | StringPiece borrowed_name(*resource_and_name.name); |
199 | |
200 | if (owns_resource) { |
201 | resource_and_name.resource = core::RefCountPtr<ResourceBase>(resource); |
202 | } else { |
203 | auto cleanup_fn = [this, container, type, borrowed_name]() { |
204 | mutex_lock l(mu_); |
205 | auto iter = container->find({type.hash_code(), borrowed_name}); |
206 | if (iter != container->end()) { |
207 | container->erase(iter); |
208 | } |
209 | }; |
210 | resource_and_name.resource = |
211 | core::WeakPtr<ResourceBase>(resource, cleanup_fn); |
212 | } |
213 | |
214 | Container::value_type key_and_value(Key(type.hash_code(), borrowed_name), |
215 | std::move(resource_and_name)); |
216 | |
217 | auto st = container->insert(std::move(key_and_value)); |
218 | if (st.second) { |
219 | TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name())); |
220 | return OkStatus(); |
221 | } |
222 | return errors::AlreadyExists("Resource " , container_name, "/" , name, "/" , |
223 | type.name()); |
224 | } |
225 | |
226 | Status ResourceMgr::Lookup(const ResourceHandle& handle, |
227 | ResourceBase** resource) const { |
228 | tf_shared_lock l(mu_); |
229 | return DoLookup(handle.container(), handle.hash_code(), |
230 | /*type_name=*/"ResourceBase" , handle.name(), resource); |
231 | } |
232 | |
233 | Status ResourceMgr::DoLookup(const string& container, TypeIndex type, |
234 | const string& name, |
235 | ResourceBase** resource) const { |
236 | return DoLookup(container, type.hash_code(), type.name(), name, resource); |
237 | } |
238 | |
239 | Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code, |
240 | const string& type_name, |
241 | const string& resource_name, |
242 | ResourceBase** resource) const { |
243 | const Container* b = gtl::FindPtrOrNull(containers_, container); |
244 | if (b == nullptr) { |
245 | return errors::NotFound("Container " , container, |
246 | " does not exist. (Could not find resource: " , |
247 | container, "/" , resource_name, ")" ); |
248 | } |
249 | auto iter = b->find({type_hash_code, resource_name}); |
250 | if (iter == b->end()) { |
251 | return errors::NotFound("Resource " , container, "/" , resource_name, "/" , |
252 | type_name, " does not exist." ); |
253 | } |
254 | ResourceBase* ptr = iter->second.GetResource().release(); |
255 | if (ptr == nullptr) { |
256 | return errors::NotFound("Resource " , container, "/" , resource_name, "/" , |
257 | type_name, " has been destroyed." ); |
258 | } |
259 | *resource = ptr; |
260 | return OkStatus(); |
261 | } |
262 | |
263 | Status ResourceMgr::PopResourceAndName(const string& container, |
264 | uint64 type_hash_code, |
265 | const string& resource_name, |
266 | const string& type_name, |
267 | ResourceAndName& resource_and_name) { |
268 | mutex_lock l(mu_); |
269 | Container* b = gtl::FindPtrOrNull(containers_, container); |
270 | if (b == nullptr) { |
271 | return errors::NotFound("Container " , container, " does not exist." ); |
272 | } |
273 | auto iter = b->find({type_hash_code, resource_name}); |
274 | if (iter == b->end()) { |
275 | return errors::NotFound("Resource " , container, "/" , resource_name, "/" , |
276 | type_name, " does not exist." ); |
277 | } |
278 | std::swap(resource_and_name, iter->second); |
279 | b->erase(iter); |
280 | return OkStatus(); |
281 | } |
282 | |
283 | Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, |
284 | const string& resource_name, |
285 | const string& type_name) { |
286 | ResourceAndName resource_and_name; |
287 | TF_RETURN_IF_ERROR(PopResourceAndName( |
288 | container, type_hash_code, resource_name, type_name, resource_and_name)); |
289 | |
290 | if (absl::holds_alternative<core::WeakPtr<ResourceBase>>( |
291 | resource_and_name.resource)) { |
292 | return errors::Internal( |
293 | "Cannot delete an unowned Resource " , container, "/" , resource_name, |
294 | "/" , type_name, " from ResourceMgr. " , |
295 | "This indicates ref-counting ResourceHandle is exposed to weak " |
296 | "ResourceHandle code paths." ); |
297 | } |
298 | return OkStatus(); |
299 | } |
300 | |
301 | Status ResourceMgr::DoDelete(const string& container, TypeIndex type, |
302 | const string& resource_name) { |
303 | return DoDelete(container, type.hash_code(), resource_name, type.name()); |
304 | } |
305 | |
306 | Status ResourceMgr::Delete(const ResourceHandle& handle) { |
307 | return DoDelete(handle.container(), handle.hash_code(), handle.name(), |
308 | "<unknown>" ); |
309 | } |
310 | |
311 | Status ResourceMgr::Cleanup(const string& container) { |
312 | { |
313 | tf_shared_lock l(mu_); |
314 | if (!gtl::FindOrNull(containers_, container)) { |
315 | // Nothing to cleanup. |
316 | return OkStatus(); |
317 | } |
318 | } |
319 | Container* b = nullptr; |
320 | { |
321 | mutex_lock l(mu_); |
322 | auto iter = containers_.find(container); |
323 | if (iter == containers_.end()) { |
324 | // Nothing to cleanup, it's OK (concurrent cleanup). |
325 | return OkStatus(); |
326 | } |
327 | b = iter->second; |
328 | containers_.erase(iter); |
329 | } |
330 | CHECK(b != nullptr); |
331 | delete b; |
332 | return OkStatus(); |
333 | } |
334 | |
335 | static bool IsValidContainerName(StringPiece s) { |
336 | using ::tensorflow::strings::Scanner; |
337 | return Scanner(s) |
338 | .One(Scanner::LETTER_DIGIT_DOT) |
339 | .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH) |
340 | .Eos() |
341 | .GetResult(); |
342 | } |
343 | |
344 | Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, |
345 | bool use_node_name_as_default) { |
346 | CHECK(rmgr); |
347 | rmgr_ = rmgr; |
348 | string attr_container; |
349 | TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container" , &attr_container)); |
350 | if (!attr_container.empty() && !IsValidContainerName(attr_container)) { |
351 | return errors::InvalidArgument("container contains invalid characters: " , |
352 | attr_container); |
353 | } |
354 | string attr_shared_name; |
355 | TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name" , &attr_shared_name)); |
356 | if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) { |
357 | return errors::InvalidArgument("shared_name cannot start with '_':" , |
358 | attr_shared_name); |
359 | } |
360 | if (!attr_container.empty()) { |
361 | container_ = attr_container; |
362 | } else { |
363 | container_ = rmgr_->default_container(); |
364 | } |
365 | if (!attr_shared_name.empty()) { |
366 | name_ = attr_shared_name; |
367 | } else if (use_node_name_as_default) { |
368 | name_ = ndef.name(); |
369 | } else { |
370 | resource_is_private_to_kernel_ = true; |
371 | static std::atomic<int64_t> counter(0); |
372 | name_ = strings::StrCat("_" , counter.fetch_add(1), "_" , ndef.name()); |
373 | } |
374 | return OkStatus(); |
375 | } |
376 | |
377 | string ContainerInfo::DebugString() const { |
378 | return strings::StrCat("[" , container(), "," , name(), "," , |
379 | resource_is_private_to_kernel() ? "private" : "public" , |
380 | "]" ); |
381 | } |
382 | |
383 | const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { |
384 | return ctx->input(input).flat<ResourceHandle>()(0); |
385 | } |
386 | |
387 | Status HandleFromInput(OpKernelContext* ctx, StringPiece input, |
388 | ResourceHandle* handle) { |
389 | const Tensor* tensor; |
390 | TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); |
391 | *handle = tensor->flat<ResourceHandle>()(0); |
392 | return OkStatus(); |
393 | } |
394 | |
395 | Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, |
396 | ResourceBase** value) { |
397 | TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); |
398 | if (p.IsRefCounting()) { |
399 | TF_ASSIGN_OR_RETURN(*value, p.GetResource<ResourceBase>()); |
400 | (*value)->Ref(); |
401 | return OkStatus(); |
402 | } |
403 | return ctx->resource_manager()->Lookup(p, value); |
404 | } |
405 | |
406 | Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { |
407 | TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); |
408 | if (p.IsRefCounting()) { |
409 | return OkStatus(); |
410 | } |
411 | return ctx->resource_manager()->Delete(p); |
412 | } |
413 | |
414 | } // end namespace tensorflow |
415 | |