1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/kernels/checkpoint_callback_manager.h" |
16 | |
17 | #include <string> |
18 | #include <utility> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "absl/strings/str_cat.h" |
22 | #include "absl/strings/string_view.h" |
23 | #include "tensorflow/core/platform/env.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | #include "tensorflow/core/platform/mutex.h" |
26 | #include "tensorflow/core/platform/path.h" |
27 | #include "tensorflow/core/platform/regexp.h" |
28 | #include "tensorflow/core/platform/status.h" |
29 | #include "tensorflow/core/platform/statusor.h" |
30 | #include "tensorflow/core/platform/stringpiece.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | |
33 | namespace tensorflow { |
34 | namespace checkpoint { |
35 | |
// String constant naming the checkpoint callback manager resource.
// NOTE(review): presumably the key under which a CheckpointCallbackManager is
// registered/looked up as a resource; no use is visible in this file —
// confirm against callers.
const absl::string_view kCheckpointCallbackManagerResourceName =
    "checkpoint_callback_manager";
38 | |
39 | namespace { |
40 | |
41 | const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$" ; |
42 | const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$" ; |
43 | const absl::string_view kCheckpointDirRegex = "-[0-9]*$" ; |
44 | const absl::string_view kCheckpointTempDirSuffix = "_temp" ; |
45 | |
46 | void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id, |
47 | absl::string_view checkpoint_dir, |
48 | absl::string_view file_extension, |
49 | SaveCallback callback) { |
50 | const std::string file_path = io::JoinPath( |
51 | checkpoint_dir, absl::StrCat(checkpoint_id, "." , file_extension)); |
52 | |
53 | // If the file already exists, we are done. |
54 | if (Env::Default()->FileExists(file_path).ok()) { |
55 | return; |
56 | } |
57 | LOG(INFO) << "Calling a save callback: file_extension = " << file_extension |
58 | << ", checkpoint_id = " << checkpoint_id; |
59 | // The callback should return a string to store. |
60 | StatusOr<std::string> save_content = callback(checkpoint_id); |
61 | if (!save_content.ok()) { |
62 | LOG(WARNING) << save_content.status(); |
63 | return; |
64 | } |
65 | |
66 | // An empty string means nothing to be saved. |
67 | if (save_content->empty()) { |
68 | return; |
69 | } |
70 | |
71 | Status write_status = |
72 | WriteStringToFile(Env::Default(), file_path, *save_content); |
73 | if (!write_status.ok()) { |
74 | LOG(WARNING) << write_status; |
75 | } else { |
76 | LOG(INFO) << "A CheckpointCallbackManager has been written to " |
77 | << file_path; |
78 | } |
79 | } |
80 | |
81 | void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id, |
82 | absl::string_view checkpoint_dir, |
83 | absl::string_view file_extension, |
84 | RestoreCallback callback) { |
85 | const std::string file_path = io::JoinPath( |
86 | checkpoint_dir, absl::StrCat(checkpoint_id, "." , file_extension)); |
87 | if (!Env::Default()->FileExists(file_path).ok()) { |
88 | return; |
89 | } |
90 | std::string payload; |
91 | Status read_status = ReadFileToString(Env::Default(), file_path, &payload); |
92 | if (!read_status.ok()) { |
93 | LOG(WARNING) << "Failed to read: " << read_status; |
94 | return; |
95 | } |
96 | |
97 | LOG(INFO) << "Calling a restore callback: file_extension = " << file_extension |
98 | << ", checkpoint_id = " << checkpoint_id; |
99 | Status callback_status = callback(checkpoint_id, payload); |
100 | if (!callback_status.ok()) { |
101 | LOG(WARNING) << callback_status; |
102 | } |
103 | } |
104 | |
105 | } // namespace |
106 | |
107 | // Examples: |
108 | // "/foo/bar/checkpoint-1_temp/part-00000-of-00001" --> |
109 | // ("checkpoint-1", "/foo/bar"); |
110 | // "/foo/bar/checkpoint-2/part-00000-of-00001" --> |
111 | // ("checkpoint-2", "/foo/bar"); |
112 | // "/foo/bar/checkpoint-3" --> ("checkpoint-3", "/foo/bar"); |
113 | // "/foo/bar" --> NotFound error |
114 | StatusOr<std::pair<std::string, std::string>> |
115 | CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix( |
116 | absl::string_view prefix) { |
117 | for (absl::string_view path = prefix;; path = io::Dirname(path)) { |
118 | absl::string_view basename = io::Basename(path); |
119 | |
120 | // Failed to find checkpoint_id |
121 | if (basename.empty()) break; |
122 | |
123 | // Skip known checkpoint file: e.g., part-00000-of-00001 |
124 | if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue; |
125 | |
126 | // With _temp suffix: e.g., checkpoint-1_temp |
127 | if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) { |
128 | // Trim suffix, "_temp". |
129 | return std::make_pair( |
130 | std::string(basename.substr( |
131 | 0, basename.length() - kCheckpointTempDirSuffix.length())), |
132 | std::string(io::Dirname(path))); |
133 | } |
134 | |
135 | // Without _temp suffix: e.g., checkpoint-1 |
136 | if (RE2::PartialMatch(basename, kCheckpointDirRegex)) { |
137 | return std::make_pair(std::string(basename), |
138 | std::string(io::Dirname(path))); |
139 | } |
140 | } |
141 | return errors::NotFound( |
142 | absl::StrCat("Failed to find a checkpoint id. prefix = " , prefix)); |
143 | } |
144 | |
145 | Status CheckpointCallbackManager::RegisterSaveCallback( |
146 | absl::string_view file_extension, SaveCallback callback) { |
147 | SaveCallback lazy_callback = nullptr; |
148 | std::string checkpoint_id; |
149 | std::string checkpoint_dir; |
150 | { |
151 | mutex_lock l(mu_); |
152 | if (!save_callbacks_.try_emplace(file_extension, std::move(callback)) |
153 | .second) { |
154 | return errors::AlreadyExists("A callback already exists." ); |
155 | } |
156 | |
157 | // If last_saved_checkpoint_id_and_dir_ is not empty, |
158 | // tries to trigger save callback lazily. |
159 | if (!last_saved_checkpoint_id_and_dir_.first.empty()) { |
160 | lazy_callback = save_callbacks_[file_extension]; |
161 | checkpoint_id = last_saved_checkpoint_id_and_dir_.first; |
162 | checkpoint_dir = last_saved_checkpoint_id_and_dir_.second; |
163 | } |
164 | } |
165 | |
166 | if (lazy_callback != nullptr) { |
167 | TriggerSaveCallbackIfFileNotExist(checkpoint_id, checkpoint_dir, |
168 | file_extension, lazy_callback); |
169 | } |
170 | return OkStatus(); |
171 | } |
172 | |
173 | bool CheckpointCallbackManager::DoesSaveCallbackExist( |
174 | absl::string_view file_extension) { |
175 | tf_shared_lock l(mu_); |
176 | return save_callbacks_.contains(file_extension); |
177 | } |
178 | |
179 | Status CheckpointCallbackManager::RegisterRestoreCallback( |
180 | absl::string_view file_extension, RestoreCallback callback) { |
181 | RestoreCallback lazy_callback = nullptr; |
182 | std::string checkpoint_id; |
183 | std::string checkpoint_dir; |
184 | { |
185 | mutex_lock l(mu_); |
186 | if (!restore_callbacks_.try_emplace(file_extension, std::move(callback)) |
187 | .second) { |
188 | return errors::AlreadyExists("A callback already exists." ); |
189 | } |
190 | |
191 | // If last_restored_checkpoint_id_and_dir_ is not empty, |
192 | // tries to trigger restore callback lazily. |
193 | if (!last_restored_checkpoint_id_and_dir_.first.empty()) { |
194 | lazy_callback = restore_callbacks_[file_extension]; |
195 | checkpoint_id = last_restored_checkpoint_id_and_dir_.first; |
196 | checkpoint_dir = last_restored_checkpoint_id_and_dir_.second; |
197 | } |
198 | } |
199 | |
200 | if (lazy_callback != nullptr) { |
201 | TriggerRestoreCallbackIfFileExists(checkpoint_id, checkpoint_dir, |
202 | file_extension, lazy_callback); |
203 | } |
204 | return OkStatus(); |
205 | } |
206 | |
207 | bool CheckpointCallbackManager::DoesRestoreCallbackExist( |
208 | absl::string_view file_extension) { |
209 | tf_shared_lock l(mu_); |
210 | return restore_callbacks_.contains(file_extension); |
211 | } |
212 | |
213 | void CheckpointCallbackManager::Save(absl::string_view prefix) { |
214 | StatusOr<std::pair<std::string, std::string>> id_and_dir = |
215 | GetCheckpointIdAndPathFromPrefix(prefix); |
216 | if (!id_and_dir.ok()) { |
217 | return; |
218 | } |
219 | |
220 | // Create a copy to avoid holding lock while calling a callback. |
221 | absl::flat_hash_map<std::string, SaveCallback> copy_of_save_callbacks; |
222 | { |
223 | mutex_lock l(mu_); |
224 | last_saved_checkpoint_id_and_dir_ = *id_and_dir; |
225 | copy_of_save_callbacks = save_callbacks_; |
226 | } |
227 | |
228 | for (const auto& name_and_callback : copy_of_save_callbacks) { |
229 | TriggerSaveCallbackIfFileNotExist(id_and_dir->first, id_and_dir->second, |
230 | name_and_callback.first, |
231 | name_and_callback.second); |
232 | } |
233 | } |
234 | |
235 | void CheckpointCallbackManager::Restore(absl::string_view prefix) { |
236 | StatusOr<std::pair<std::string, std::string>> id_and_dir = |
237 | GetCheckpointIdAndPathFromPrefix(prefix); |
238 | if (!id_and_dir.ok()) { |
239 | return; |
240 | } |
241 | |
242 | // Create a copy to avoid holding lock while calling a callback. |
243 | absl::flat_hash_map<std::string, RestoreCallback> copy_of_restore_callbacks; |
244 | { |
245 | mutex_lock l(mu_); |
246 | last_restored_checkpoint_id_and_dir_ = *id_and_dir; |
247 | copy_of_restore_callbacks = restore_callbacks_; |
248 | } |
249 | |
250 | for (const auto& name_and_callback : copy_of_restore_callbacks) { |
251 | TriggerRestoreCallbackIfFileExists(id_and_dir->first, id_and_dir->second, |
252 | name_and_callback.first, |
253 | name_and_callback.second); |
254 | } |
255 | } |
256 | |
257 | } // namespace checkpoint |
258 | } // namespace tensorflow |
259 | |