1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/kernels/checkpoint_callback_manager.h"
16
17#include <string>
18#include <utility>
19
20#include "absl/container/flat_hash_map.h"
21#include "absl/strings/str_cat.h"
22#include "absl/strings/string_view.h"
23#include "tensorflow/core/platform/env.h"
24#include "tensorflow/core/platform/errors.h"
25#include "tensorflow/core/platform/mutex.h"
26#include "tensorflow/core/platform/path.h"
27#include "tensorflow/core/platform/regexp.h"
28#include "tensorflow/core/platform/status.h"
29#include "tensorflow/core/platform/statusor.h"
30#include "tensorflow/core/platform/stringpiece.h"
31#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34namespace checkpoint {
35
36const absl::string_view kCheckpointCallbackManagerResourceName =
37 "checkpoint_callback_manager";
38
39namespace {
40
41const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$";
42const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$";
43const absl::string_view kCheckpointDirRegex = "-[0-9]*$";
44const absl::string_view kCheckpointTempDirSuffix = "_temp";
45
46void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,
47 absl::string_view checkpoint_dir,
48 absl::string_view file_extension,
49 SaveCallback callback) {
50 const std::string file_path = io::JoinPath(
51 checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
52
53 // If the file already exists, we are done.
54 if (Env::Default()->FileExists(file_path).ok()) {
55 return;
56 }
57 LOG(INFO) << "Calling a save callback: file_extension = " << file_extension
58 << ", checkpoint_id = " << checkpoint_id;
59 // The callback should return a string to store.
60 StatusOr<std::string> save_content = callback(checkpoint_id);
61 if (!save_content.ok()) {
62 LOG(WARNING) << save_content.status();
63 return;
64 }
65
66 // An empty string means nothing to be saved.
67 if (save_content->empty()) {
68 return;
69 }
70
71 Status write_status =
72 WriteStringToFile(Env::Default(), file_path, *save_content);
73 if (!write_status.ok()) {
74 LOG(WARNING) << write_status;
75 } else {
76 LOG(INFO) << "A CheckpointCallbackManager has been written to "
77 << file_path;
78 }
79}
80
81void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id,
82 absl::string_view checkpoint_dir,
83 absl::string_view file_extension,
84 RestoreCallback callback) {
85 const std::string file_path = io::JoinPath(
86 checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
87 if (!Env::Default()->FileExists(file_path).ok()) {
88 return;
89 }
90 std::string payload;
91 Status read_status = ReadFileToString(Env::Default(), file_path, &payload);
92 if (!read_status.ok()) {
93 LOG(WARNING) << "Failed to read: " << read_status;
94 return;
95 }
96
97 LOG(INFO) << "Calling a restore callback: file_extension = " << file_extension
98 << ", checkpoint_id = " << checkpoint_id;
99 Status callback_status = callback(checkpoint_id, payload);
100 if (!callback_status.ok()) {
101 LOG(WARNING) << callback_status;
102 }
103}
104
105} // namespace
106
107// Examples:
108// "/foo/bar/checkpoint-1_temp/part-00000-of-00001" -->
109// ("checkpoint-1", "/foo/bar");
110// "/foo/bar/checkpoint-2/part-00000-of-00001" -->
111// ("checkpoint-2", "/foo/bar");
112// "/foo/bar/checkpoint-3" --> ("checkpoint-3", "/foo/bar");
113// "/foo/bar" --> NotFound error
114StatusOr<std::pair<std::string, std::string>>
115CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix(
116 absl::string_view prefix) {
117 for (absl::string_view path = prefix;; path = io::Dirname(path)) {
118 absl::string_view basename = io::Basename(path);
119
120 // Failed to find checkpoint_id
121 if (basename.empty()) break;
122
123 // Skip known checkpoint file: e.g., part-00000-of-00001
124 if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue;
125
126 // With _temp suffix: e.g., checkpoint-1_temp
127 if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) {
128 // Trim suffix, "_temp".
129 return std::make_pair(
130 std::string(basename.substr(
131 0, basename.length() - kCheckpointTempDirSuffix.length())),
132 std::string(io::Dirname(path)));
133 }
134
135 // Without _temp suffix: e.g., checkpoint-1
136 if (RE2::PartialMatch(basename, kCheckpointDirRegex)) {
137 return std::make_pair(std::string(basename),
138 std::string(io::Dirname(path)));
139 }
140 }
141 return errors::NotFound(
142 absl::StrCat("Failed to find a checkpoint id. prefix = ", prefix));
143}
144
145Status CheckpointCallbackManager::RegisterSaveCallback(
146 absl::string_view file_extension, SaveCallback callback) {
147 SaveCallback lazy_callback = nullptr;
148 std::string checkpoint_id;
149 std::string checkpoint_dir;
150 {
151 mutex_lock l(mu_);
152 if (!save_callbacks_.try_emplace(file_extension, std::move(callback))
153 .second) {
154 return errors::AlreadyExists("A callback already exists.");
155 }
156
157 // If last_saved_checkpoint_id_and_dir_ is not empty,
158 // tries to trigger save callback lazily.
159 if (!last_saved_checkpoint_id_and_dir_.first.empty()) {
160 lazy_callback = save_callbacks_[file_extension];
161 checkpoint_id = last_saved_checkpoint_id_and_dir_.first;
162 checkpoint_dir = last_saved_checkpoint_id_and_dir_.second;
163 }
164 }
165
166 if (lazy_callback != nullptr) {
167 TriggerSaveCallbackIfFileNotExist(checkpoint_id, checkpoint_dir,
168 file_extension, lazy_callback);
169 }
170 return OkStatus();
171}
172
173bool CheckpointCallbackManager::DoesSaveCallbackExist(
174 absl::string_view file_extension) {
175 tf_shared_lock l(mu_);
176 return save_callbacks_.contains(file_extension);
177}
178
179Status CheckpointCallbackManager::RegisterRestoreCallback(
180 absl::string_view file_extension, RestoreCallback callback) {
181 RestoreCallback lazy_callback = nullptr;
182 std::string checkpoint_id;
183 std::string checkpoint_dir;
184 {
185 mutex_lock l(mu_);
186 if (!restore_callbacks_.try_emplace(file_extension, std::move(callback))
187 .second) {
188 return errors::AlreadyExists("A callback already exists.");
189 }
190
191 // If last_restored_checkpoint_id_and_dir_ is not empty,
192 // tries to trigger restore callback lazily.
193 if (!last_restored_checkpoint_id_and_dir_.first.empty()) {
194 lazy_callback = restore_callbacks_[file_extension];
195 checkpoint_id = last_restored_checkpoint_id_and_dir_.first;
196 checkpoint_dir = last_restored_checkpoint_id_and_dir_.second;
197 }
198 }
199
200 if (lazy_callback != nullptr) {
201 TriggerRestoreCallbackIfFileExists(checkpoint_id, checkpoint_dir,
202 file_extension, lazy_callback);
203 }
204 return OkStatus();
205}
206
207bool CheckpointCallbackManager::DoesRestoreCallbackExist(
208 absl::string_view file_extension) {
209 tf_shared_lock l(mu_);
210 return restore_callbacks_.contains(file_extension);
211}
212
213void CheckpointCallbackManager::Save(absl::string_view prefix) {
214 StatusOr<std::pair<std::string, std::string>> id_and_dir =
215 GetCheckpointIdAndPathFromPrefix(prefix);
216 if (!id_and_dir.ok()) {
217 return;
218 }
219
220 // Create a copy to avoid holding lock while calling a callback.
221 absl::flat_hash_map<std::string, SaveCallback> copy_of_save_callbacks;
222 {
223 mutex_lock l(mu_);
224 last_saved_checkpoint_id_and_dir_ = *id_and_dir;
225 copy_of_save_callbacks = save_callbacks_;
226 }
227
228 for (const auto& name_and_callback : copy_of_save_callbacks) {
229 TriggerSaveCallbackIfFileNotExist(id_and_dir->first, id_and_dir->second,
230 name_and_callback.first,
231 name_and_callback.second);
232 }
233}
234
235void CheckpointCallbackManager::Restore(absl::string_view prefix) {
236 StatusOr<std::pair<std::string, std::string>> id_and_dir =
237 GetCheckpointIdAndPathFromPrefix(prefix);
238 if (!id_and_dir.ok()) {
239 return;
240 }
241
242 // Create a copy to avoid holding lock while calling a callback.
243 absl::flat_hash_map<std::string, RestoreCallback> copy_of_restore_callbacks;
244 {
245 mutex_lock l(mu_);
246 last_restored_checkpoint_id_and_dir_ = *id_and_dir;
247 copy_of_restore_callbacks = restore_callbacks_;
248 }
249
250 for (const auto& name_and_callback : copy_of_restore_callbacks) {
251 TriggerRestoreCallbackIfFileExists(id_and_dir->first, id_and_dir->second,
252 name_and_callback.first,
253 name_and_callback.second);
254 }
255}
256
257} // namespace checkpoint
258} // namespace tensorflow
259