1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/tensor_slice_reader_cache.h" |
17 | |
18 | #include <utility> |
19 | |
20 | #include "tensorflow/core/platform/logging.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | namespace checkpoint { |
25 | |
26 | TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {} |
27 | TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() { |
28 | delete cache_; |
29 | cache_ = nullptr; |
30 | } |
31 | |
32 | const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader( |
33 | const string& filepattern, |
34 | TensorSliceReader::OpenTableFunction open_function, |
35 | int preferred_shard) const { |
36 | mutex_lock l(mu_); |
37 | if (!cache_) { |
38 | cache_ = new TensorSliceReaderCache; |
39 | } |
40 | return cache_->GetReader(filepattern, std::move(open_function), |
41 | preferred_shard); |
42 | } |
43 | |
44 | TensorSliceReaderCache::TensorSliceReaderCache() {} |
45 | |
46 | TensorSliceReaderCache::~TensorSliceReaderCache() { |
47 | for (const auto& pair : readers_) { |
48 | delete pair.second.second; |
49 | } |
50 | } |
51 | |
52 | const TensorSliceReader* TensorSliceReaderCache::GetReader( |
53 | const string& filepattern, |
54 | TensorSliceReader::OpenTableFunction open_function, int preferred_shard) { |
55 | mutex_lock l(mu_); |
56 | |
57 | #if defined(__GXX_RTTI) || defined(_CPPRTTI) |
58 | // Get the function pointer from the open_function value. |
59 | TensorSliceReaderCache::OpenFuncType* func_ptr = |
60 | open_function.target<TensorSliceReaderCache::OpenFuncType>(); |
61 | #else // __GXX_RTTI |
62 | // When RTTI is disabled, we will hard-code func_ptr to be zero, |
63 | // since we cannot figure out the target type for open_function. |
64 | // TODO(jiayq): find a more elegant way to possibly enable cache again. |
65 | TensorSliceReaderCache::OpenFuncType* func_ptr = nullptr; |
66 | #endif // _GXX_RTTI |
67 | |
68 | if (!func_ptr) { |
69 | // We could not get the pointer, no caching is possible. |
70 | LOG(WARNING) << "Caching disabled because the open function is a lambda or " |
71 | "RTTI is not enabled in this build." ; |
72 | return nullptr; |
73 | } |
74 | |
75 | // Wait if another thread is already trying to open the same files. |
76 | while (still_opening_.find(filepattern) != still_opening_.end()) { |
77 | cv_.wait(l); |
78 | } |
79 | |
80 | TensorSliceReader* reader = nullptr; |
81 | if (readers_.find(filepattern) == readers_.end()) { |
82 | VLOG(1) << "Creating new TensorSliceReader for " << filepattern; |
83 | still_opening_.insert(filepattern); |
84 | // Release the lock temporary as constructing TensorSliceReader is |
85 | // expensive. |
86 | mu_.unlock(); |
87 | TensorSliceReader* tmp_reader( |
88 | new TensorSliceReader(filepattern, open_function, preferred_shard)); |
89 | // Acquire the lock again. |
90 | mu_.lock(); |
91 | if (tmp_reader->status().ok()) { |
92 | reader = tmp_reader; |
93 | readers_[filepattern] = std::make_pair(*func_ptr, reader); |
94 | } else { |
95 | delete tmp_reader; |
96 | } |
97 | CHECK_EQ(size_t{1}, still_opening_.erase(filepattern)); |
98 | VLOG(1) << "Cached TensorSliceReader for " << filepattern << ": " << reader; |
99 | } else { |
100 | auto cached_val = readers_[filepattern]; |
101 | if (cached_val.first == *func_ptr) { |
102 | reader = cached_val.second; |
103 | VLOG(1) << "Using cached TensorSliceReader for " << filepattern << ": " |
104 | << reader; |
105 | } else { |
106 | LOG(WARNING) << "Caching disabled because the checkpoint file " |
107 | << "is being opened with two different open functions: " |
108 | << filepattern; |
109 | } |
110 | } |
111 | |
112 | cv_.notify_all(); |
113 | return reader; |
114 | } |
115 | |
116 | } // namespace checkpoint |
117 | |
118 | } // namespace tensorflow |
119 | |