1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/resource_mgr.h"
17
18#include <atomic>
19
20#include "tensorflow/core/framework/device_attributes.pb.h"
21#include "tensorflow/core/framework/node_def.pb.h"
22#include "tensorflow/core/framework/node_def_util.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/gtl/map_util.h"
25#include "tensorflow/core/lib/strings/scanner.h"
26#include "tensorflow/core/lib/strings/str_util.h"
27#include "tensorflow/core/lib/strings/stringprintf.h"
28#include "tensorflow/core/platform/demangle.h"
29#include "tensorflow/core/platform/stacktrace.h"
30
31namespace tensorflow {
32
33ResourceHandle MakeResourceHandle(
34 const string& container, const string& name, const DeviceBase& device,
35 const TypeIndex& type_index,
36 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes,
37 const absl::optional<ManagedStackTrace>& definition_stack_trace) {
38 ResourceHandle result;
39 result.set_device(device.name());
40 result.set_container(container);
41 result.set_definition_stack_trace(definition_stack_trace);
42 if (name == ResourceHandle::ANONYMOUS_NAME) {
43 result.set_name(
44 strings::StrCat("_AnonymousVar", ResourceHandle::GenerateUniqueId()));
45 } else {
46 result.set_name(name);
47 }
48 result.set_hash_code(type_index.hash_code());
49 result.set_maybe_type_name(type_index.name());
50 result.set_dtypes_and_shapes(dtypes_and_shapes);
51 return result;
52}
53
54Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
55 const string& container, const string& name,
56 const TypeIndex& type_index) {
57 Tensor* handle;
58 TF_RETURN_IF_ERROR(
59 context->allocate_output(output_index, TensorShape({}), &handle));
60 handle->scalar<ResourceHandle>()() =
61 MakeResourceHandle(container, name, *context->device(), type_index);
62 return OkStatus();
63}
64
65namespace internal {
66
67Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
68 if (ctx->device()->attributes().name() != p.device()) {
69 return errors::InvalidArgument(
70 "Trying to access resource ", p.name(), " located in device ",
71 p.device(), " from device ", ctx->device()->attributes().name());
72 }
73 return OkStatus();
74}
75
76} // end namespace internal
77
78Status ResourceMgr::InsertDebugTypeName(uint64 hash_code,
79 const string& type_name) {
80 auto iter = debug_type_names_.emplace(hash_code, type_name);
81 if (iter.first->second != type_name) {
82 return errors::AlreadyExists("Duplicate hash code found for type ",
83 type_name);
84 }
85 return OkStatus();
86}
87
88const char* ResourceMgr::DebugTypeName(uint64 hash_code) const {
89 auto type_name_iter = debug_type_names_.find(hash_code);
90 if (type_name_iter == debug_type_names_.end()) {
91 return "<unknown>";
92 } else {
93 return type_name_iter->second.c_str();
94 }
95}
96
97ResourceMgr::ResourceAndName::ResourceAndName() : name(nullptr) {}
98
99ResourceMgr::ResourceAndName::ResourceAndName(const string& name)
100 : name(absl::make_unique<string>(name)) {}
101
102core::RefCountPtr<ResourceBase> ResourceMgr::ResourceAndName::GetResource()
103 const {
104 if (absl::holds_alternative<core::RefCountPtr<ResourceBase>>(resource)) {
105 ResourceBase* ptr =
106 absl::get<core::RefCountPtr<ResourceBase>>(resource).get();
107 ptr->Ref();
108 return core::RefCountPtr<ResourceBase>(ptr);
109 } else if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(resource)) {
110 return absl::get<core::WeakPtr<ResourceBase>>(resource).GetNewRef();
111 } else {
112 return nullptr;
113 }
114}
115
116ResourceMgr::ResourceAndName::ResourceAndName(
117 ResourceAndName&& other) noexcept {
118 name = std::move(other.name);
119 resource = std::move(other.resource);
120}
121
122ResourceMgr::ResourceAndName::~ResourceAndName() {}
123
124ResourceMgr::ResourceAndName& ResourceMgr::ResourceAndName::operator=(
125 ResourceAndName&& other) noexcept {
126 name = std::move(other.name);
127 resource = std::move(other.resource);
128 return *this;
129}
130
131ResourceMgr::ResourceMgr() : default_container_("localhost") {}
132
133ResourceMgr::ResourceMgr(const string& default_container)
134 : default_container_(default_container) {}
135
136ResourceMgr::~ResourceMgr() { Clear(); }
137
138void ResourceMgr::Clear() {
139 // We do the deallocation outside of the lock to avoid a potential deadlock
140 // in case any of the destructors access the resource manager.
141 absl::flat_hash_map<string, Container*> tmp_containers;
142 {
143 mutex_lock l(mu_);
144 tmp_containers = std::move(containers_);
145 }
146 for (const auto& p : tmp_containers) {
147 delete p.second;
148 }
149 tmp_containers.clear();
150}
151
152string ResourceMgr::DebugString() const {
153 mutex_lock l(mu_);
154 struct Line {
155 const string* container;
156 const string type;
157 const string* resource;
158 const string detail;
159 };
160 std::vector<Line> lines;
161 for (const auto& p : containers_) {
162 const string& container = p.first;
163 for (const auto& q : *p.second) {
164 const Key& key = q.first;
165 const char* type = DebugTypeName(key.first);
166 const core::RefCountPtr<ResourceBase> resource = q.second.GetResource();
167 Line l{&container, port::Demangle(type), q.second.name.get(),
168 resource ? resource->DebugString() : "<nullptr>"};
169 lines.push_back(l);
170 }
171 }
172 std::vector<string> text;
173 text.reserve(lines.size());
174 for (const Line& line : lines) {
175 text.push_back(strings::Printf(
176 "%-20s | %-40s | %-40s | %-s", line.container->c_str(),
177 line.type.c_str(), line.resource->c_str(), line.detail.c_str()));
178 }
179 std::sort(text.begin(), text.end());
180 return absl::StrJoin(text, "\n");
181}
182
183Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type,
184 const string& name, ResourceBase* resource,
185 bool owns_resource) {
186 Container* container = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
187 Container** ptr = &containers_[container_name];
188 if (*ptr == nullptr) {
189 *ptr = new Container;
190 }
191 return *ptr;
192 }();
193
194 // NOTE: Separating out the construction of the map key and value so that the
195 // key can contain a StringPiece that borrows from the string in the value.
196 ResourceAndName resource_and_name(name);
197
198 StringPiece borrowed_name(*resource_and_name.name);
199
200 if (owns_resource) {
201 resource_and_name.resource = core::RefCountPtr<ResourceBase>(resource);
202 } else {
203 auto cleanup_fn = [this, container, type, borrowed_name]() {
204 mutex_lock l(mu_);
205 auto iter = container->find({type.hash_code(), borrowed_name});
206 if (iter != container->end()) {
207 container->erase(iter);
208 }
209 };
210 resource_and_name.resource =
211 core::WeakPtr<ResourceBase>(resource, cleanup_fn);
212 }
213
214 Container::value_type key_and_value(Key(type.hash_code(), borrowed_name),
215 std::move(resource_and_name));
216
217 auto st = container->insert(std::move(key_and_value));
218 if (st.second) {
219 TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name()));
220 return OkStatus();
221 }
222 return errors::AlreadyExists("Resource ", container_name, "/", name, "/",
223 type.name());
224}
225
226Status ResourceMgr::Lookup(const ResourceHandle& handle,
227 ResourceBase** resource) const {
228 tf_shared_lock l(mu_);
229 return DoLookup(handle.container(), handle.hash_code(),
230 /*type_name=*/"ResourceBase", handle.name(), resource);
231}
232
233Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
234 const string& name,
235 ResourceBase** resource) const {
236 return DoLookup(container, type.hash_code(), type.name(), name, resource);
237}
238
239Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code,
240 const string& type_name,
241 const string& resource_name,
242 ResourceBase** resource) const {
243 const Container* b = gtl::FindPtrOrNull(containers_, container);
244 if (b == nullptr) {
245 return errors::NotFound("Container ", container,
246 " does not exist. (Could not find resource: ",
247 container, "/", resource_name, ")");
248 }
249 auto iter = b->find({type_hash_code, resource_name});
250 if (iter == b->end()) {
251 return errors::NotFound("Resource ", container, "/", resource_name, "/",
252 type_name, " does not exist.");
253 }
254 ResourceBase* ptr = iter->second.GetResource().release();
255 if (ptr == nullptr) {
256 return errors::NotFound("Resource ", container, "/", resource_name, "/",
257 type_name, " has been destroyed.");
258 }
259 *resource = ptr;
260 return OkStatus();
261}
262
263Status ResourceMgr::PopResourceAndName(const string& container,
264 uint64 type_hash_code,
265 const string& resource_name,
266 const string& type_name,
267 ResourceAndName& resource_and_name) {
268 mutex_lock l(mu_);
269 Container* b = gtl::FindPtrOrNull(containers_, container);
270 if (b == nullptr) {
271 return errors::NotFound("Container ", container, " does not exist.");
272 }
273 auto iter = b->find({type_hash_code, resource_name});
274 if (iter == b->end()) {
275 return errors::NotFound("Resource ", container, "/", resource_name, "/",
276 type_name, " does not exist.");
277 }
278 std::swap(resource_and_name, iter->second);
279 b->erase(iter);
280 return OkStatus();
281}
282
283Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code,
284 const string& resource_name,
285 const string& type_name) {
286 ResourceAndName resource_and_name;
287 TF_RETURN_IF_ERROR(PopResourceAndName(
288 container, type_hash_code, resource_name, type_name, resource_and_name));
289
290 if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(
291 resource_and_name.resource)) {
292 return errors::Internal(
293 "Cannot delete an unowned Resource ", container, "/", resource_name,
294 "/", type_name, " from ResourceMgr. ",
295 "This indicates ref-counting ResourceHandle is exposed to weak "
296 "ResourceHandle code paths.");
297 }
298 return OkStatus();
299}
300
301Status ResourceMgr::DoDelete(const string& container, TypeIndex type,
302 const string& resource_name) {
303 return DoDelete(container, type.hash_code(), resource_name, type.name());
304}
305
306Status ResourceMgr::Delete(const ResourceHandle& handle) {
307 return DoDelete(handle.container(), handle.hash_code(), handle.name(),
308 "<unknown>");
309}
310
311Status ResourceMgr::Cleanup(const string& container) {
312 {
313 tf_shared_lock l(mu_);
314 if (!gtl::FindOrNull(containers_, container)) {
315 // Nothing to cleanup.
316 return OkStatus();
317 }
318 }
319 Container* b = nullptr;
320 {
321 mutex_lock l(mu_);
322 auto iter = containers_.find(container);
323 if (iter == containers_.end()) {
324 // Nothing to cleanup, it's OK (concurrent cleanup).
325 return OkStatus();
326 }
327 b = iter->second;
328 containers_.erase(iter);
329 }
330 CHECK(b != nullptr);
331 delete b;
332 return OkStatus();
333}
334
335static bool IsValidContainerName(StringPiece s) {
336 using ::tensorflow::strings::Scanner;
337 return Scanner(s)
338 .One(Scanner::LETTER_DIGIT_DOT)
339 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)
340 .Eos()
341 .GetResult();
342}
343
344Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
345 bool use_node_name_as_default) {
346 CHECK(rmgr);
347 rmgr_ = rmgr;
348 string attr_container;
349 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
350 if (!attr_container.empty() && !IsValidContainerName(attr_container)) {
351 return errors::InvalidArgument("container contains invalid characters: ",
352 attr_container);
353 }
354 string attr_shared_name;
355 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name));
356 if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) {
357 return errors::InvalidArgument("shared_name cannot start with '_':",
358 attr_shared_name);
359 }
360 if (!attr_container.empty()) {
361 container_ = attr_container;
362 } else {
363 container_ = rmgr_->default_container();
364 }
365 if (!attr_shared_name.empty()) {
366 name_ = attr_shared_name;
367 } else if (use_node_name_as_default) {
368 name_ = ndef.name();
369 } else {
370 resource_is_private_to_kernel_ = true;
371 static std::atomic<int64_t> counter(0);
372 name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name());
373 }
374 return OkStatus();
375}
376
377string ContainerInfo::DebugString() const {
378 return strings::StrCat("[", container(), ",", name(), ",",
379 resource_is_private_to_kernel() ? "private" : "public",
380 "]");
381}
382
383const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
384 return ctx->input(input).flat<ResourceHandle>()(0);
385}
386
387Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
388 ResourceHandle* handle) {
389 const Tensor* tensor;
390 TF_RETURN_IF_ERROR(ctx->input(input, &tensor));
391 *handle = tensor->flat<ResourceHandle>()(0);
392 return OkStatus();
393}
394
395Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
396 ResourceBase** value) {
397 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
398 if (p.IsRefCounting()) {
399 TF_ASSIGN_OR_RETURN(*value, p.GetResource<ResourceBase>());
400 (*value)->Ref();
401 return OkStatus();
402 }
403 return ctx->resource_manager()->Lookup(p, value);
404}
405
406Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
407 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
408 if (p.IsRefCounting()) {
409 return OkStatus();
410 }
411 return ctx->resource_manager()->Delete(p);
412}
413
414} // end namespace tensorflow
415