1#include <torch/csrc/distributed/c10d/FileStore.hpp>
2
3#include <fcntl.h>
4#include <sys/stat.h>
5#include <cassert>
6#include <cstdint>
7
8#ifdef _WIN32
9#include <c10/util/win32-headers.h>
10#include <fileapi.h>
11#include <io.h>
12#else
13#include <sys/file.h>
14#include <unistd.h>
15#endif
16
17#include <chrono>
18#include <cstdio>
19#include <functional>
20#include <iostream>
21#include <limits>
22#include <sstream>
23#include <system_error>
24#include <thread>
25#include <utility>
26
27#include <c10/util/Exception.h>
28
// Throws std::system_error built from the current errno when `rv` (the result
// of a system call) is negative. The optional trailing argument is used as the
// what() prefix, typically the name of the failing call.
// Wrapped in do { ... } while (0) so the expansion is a single statement and
// is safe to use inside unbraced if/else.
#define SYSASSERT(rv, ...)                                                   \
  do {                                                                       \
    if ((rv) < 0) {                                                          \
      throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
    }                                                                        \
  } while (0)
33
#ifdef _WIN32
#define LOCK_EX 0x00000001
#define LOCK_SH 0x00000010
#define LOCK_UN 0x00000100

// Minimal emulation of POSIX flock(2) for Windows, built on LockFileEx /
// UnlockFileEx over the first byte of the file (offset 0, length 1).
// `op` must be one of LOCK_EX, LOCK_SH or LOCK_UN.
// Returns 0 on success and -1 on failure. When `fd` has no OS handle it
// returns -1 without touching errno itself. Otherwise it sets errno to EINVAL
// for an unsupported `op` or a failed lock/unlock call.
// Declared `static` because it is private to this translation unit and must
// not clash with symbols from other object files.
static int flock_(int fd, int op) {
  HANDLE hdl = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
  DWORD low = 1, high = 0;
  OVERLAPPED offset = {0, 0, 0, 0, NULL};

  if (reinterpret_cast<intptr_t>(hdl) < 0)
    return -1;

  switch (op) {
    case LOCK_EX:
      if (LockFileEx(hdl, LOCKFILE_EXCLUSIVE_LOCK, 0, low, high, &offset))
        return 0;
      break;
    case LOCK_SH:
      if (LockFileEx(hdl, 0, 0, low, high, &offset))
        return 0;
      break;
    case LOCK_UN:
      if (UnlockFileEx(hdl, 0, low, high, &offset) != 0)
        return 0;
      break;
    default:
      break;
  }
  errno = EINVAL;
  return -1;
}
#endif
67
68namespace c10d {
69
70namespace {
71
72template <typename F>
73typename c10::invoke_result_t<F> syscall(F fn) {
74 while (true) {
75 auto rv = fn();
76 if (rv == -1) {
77 if (errno == EINTR) {
78 continue;
79 }
80 }
81 return rv;
82 }
83}
84
85// For a comprehensive overview of file locking methods,
86// see: https://gavv.github.io/blog/file-locks/.
87// We stick to flock(2) here because we don't care about
88// locking byte ranges and don't want locks to be process-wide.
89
90// RAII wrapper around flock(2)
91class Lock {
92 public:
93 explicit Lock(int fd, int operation) : fd_(fd) {
94 flock(operation);
95 }
96
97 ~Lock() {
98 unlock();
99 }
100
101 Lock(const Lock& that) = delete;
102
103 Lock& operator=(Lock&& other) noexcept {
104 if (this != &other) {
105 fd_ = other.fd_;
106 other.fd_ = -1;
107 }
108 return *this;
109 }
110
111 Lock(Lock&& other) noexcept {
112 *this = std::move(other);
113 }
114
115 void unlock() {
116 if (fd_ >= 0) {
117 flock(LOCK_UN);
118 fd_ = -1;
119 }
120 }
121
122 protected:
123 int fd_{-1};
124
125 void flock(int operation) {
126#ifdef _WIN32
127 auto rv = syscall(std::bind(::flock_, fd_, operation));
128#else
129 auto rv = syscall([this, operation] { return ::flock(fd_, operation); });
130#endif
131 SYSASSERT(rv, "flock");
132 }
133};
134
135class File {
136 public:
137 explicit File(
138 const std::string& path,
139 int flags,
140 std::chrono::milliseconds timeout) {
141 const auto start = std::chrono::steady_clock::now();
142 while (true) {
143#ifdef _WIN32
144 fd_ = syscall(std::bind(
145 ::open, path.c_str(), flags | _O_BINARY, _S_IREAD | _S_IWRITE));
146#else
147 fd_ = syscall([capture0 = path.c_str(), flags] {
148 return ::open(capture0, flags, 0644);
149 });
150#endif
151 // Only retry when the file doesn't exist, since we are waiting for the
152 // file to be created in this case to address the following issue:
153 // https://github.com/pytorch/pytorch/issues/13750
154 if (fd_ >= 0 || errno != ENOENT) {
155 break;
156 }
157 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
158 std::chrono::steady_clock::now() - start);
159 if (timeout != c10d::Store::kNoTimeout && elapsed > timeout) {
160 break;
161 }
162 std::this_thread::sleep_for(std::chrono::milliseconds(10));
163 }
164 SYSASSERT(fd_, "open(" + path + ")");
165 }
166
167 ~File() {
168 ::close(fd_);
169 }
170
171 Lock lockShared() {
172 return Lock(fd_, LOCK_SH);
173 }
174
175 Lock lockExclusive() {
176 return Lock(fd_, LOCK_EX);
177 }
178
179 off_t seek(off_t offset, int whence) {
180 auto rv =
181 syscall([this, offset, whence] { return lseek(fd_, offset, whence); });
182 SYSASSERT(rv, "lseek");
183 return rv;
184 }
185
186 off_t tell() {
187 auto rv = syscall([this] { return lseek(fd_, 0, SEEK_CUR); });
188 SYSASSERT(rv, "lseek");
189 return rv;
190 }
191
192 off_t size() {
193 auto pos = tell();
194 auto size = seek(0, SEEK_END);
195 seek(pos, SEEK_SET);
196 return size;
197 }
198
199 void write(const void* buf, size_t count) {
200 while (count > 0) {
201 auto rv =
202 syscall([this, buf, count] { return ::write(fd_, buf, count); });
203 SYSASSERT(rv, "write");
204 buf = (uint8_t*)buf + rv;
205 count -= rv;
206 }
207 }
208
209 void read(void* buf, size_t count) {
210 while (count > 0) {
211 auto rv = syscall([this, buf, count] { return ::read(fd_, buf, count); });
212 SYSASSERT(rv, "read");
213 buf = (uint8_t*)buf + rv;
214 count -= rv;
215 }
216 }
217
218 void write(const std::string& str) {
219 uint32_t len = str.size();
220 assert(str.size() <= std::numeric_limits<decltype(len)>::max());
221 write(&len, sizeof(len));
222 write(str.c_str(), len);
223 }
224
225 void write(const std::vector<uint8_t>& data) {
226 uint32_t len = data.size();
227 assert(data.size() <= std::numeric_limits<decltype(len)>::max());
228 write(&len, sizeof(len));
229 write(data.data(), len);
230 }
231
232 void read(std::string& str) {
233 uint32_t len = 0;
234 read(&len, sizeof(len));
235 std::vector<uint8_t> buf(len);
236 read(buf.data(), len);
237 str.assign(buf.begin(), buf.end());
238 }
239
240 void read(std::vector<uint8_t>& data) {
241 uint32_t len = 0;
242 read(&len, sizeof(len));
243 data.resize(len);
244 read(data.data(), len);
245 }
246
247 protected:
248 int fd_;
249};
250
251off_t refresh(
252 File& file,
253 off_t pos,
254 std::unordered_map<std::string, std::vector<uint8_t>>& cache,
255 const std::string deletePrefix) {
256 auto size = file.size();
257 if (size != pos) {
258 std::string tmpKey;
259 std::vector<uint8_t> tmpValue;
260 file.seek(pos, SEEK_SET);
261 while (size > pos) {
262 file.read(tmpKey);
263 file.read(tmpValue);
264 if (tmpKey.compare(0, deletePrefix.size(), deletePrefix) == 0) {
265 cache.erase(tmpKey.substr(deletePrefix.size()));
266 } else {
267 cache[tmpKey] = std::move(tmpValue);
268 }
269 pos = file.tell();
270 }
271 }
272 file.seek(0, SEEK_SET);
273 return pos;
274}
275
276} // namespace
277
278FileStore::FileStore(std::string path, int numWorkers)
279 : Store(),
280 path_(std::move(path)),
281
282 numWorkers_(numWorkers),
283 cleanupKey_("cleanup/"),
284 refCountKey_("refcount/"),
285 regularPrefix_("/"),
286 deletePrefix_("-") {
287 addHelper(refCountKey_, 1);
288}
289
290FileStore::~FileStore() {
291 // If the file does not exist - exit.
292 // This can happen when FileStore is invoked from python language which has
293 // GC. If python code has directory cleanup procedure, the race condition may
294 // occur between that code and this deconstructor. As a result, we check for
295 // file existense before cleanup
296#ifdef _WIN32
297 int res = syscall(std::bind(::_access, path_.c_str(), 0));
298#else
299 int res =
300 syscall([filepath = path_.c_str()] { return ::access(filepath, F_OK); });
301#endif
302 if (res == -1) {
303 return;
304 }
305
306 // cleanup key will be different from all rest keys since all rest keys will
307 // have a regular prefix.
308 auto numFinishedWorker = addHelper(cleanupKey_, 1);
309 auto refCount = addHelper(refCountKey_, -1);
310 // The last worker cleans up the file. If numWorkers was not initialized to
311 // a specific postive value (i.e. meaning that there was not a fixed number
312 // of workers), we don't attempt to clean.
313 // Clean up the file if number of references is 0.
314 if (refCount == 0 && numWorkers_ >= 0 && numFinishedWorker >= numWorkers_) {
315 // Best effort removal without checking the return
316 ::remove(path_.c_str());
317 }
318}
319
320void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
321 std::string regKey = regularPrefix_ + key;
322 std::unique_lock<std::mutex> l(activeFileOpLock_);
323 File file(path_, O_RDWR | O_CREAT, timeout_);
324 auto lock = file.lockExclusive();
325 file.seek(0, SEEK_END);
326 file.write(regKey);
327 file.write(value);
328}
329
330std::vector<uint8_t> FileStore::compareSet(
331 const std::string& key,
332 const std::vector<uint8_t>& expectedValue,
333 const std::vector<uint8_t>& desiredValue) {
334 std::string regKey = regularPrefix_ + key;
335 std::unique_lock<std::mutex> l(activeFileOpLock_);
336 File file(path_, O_RDWR | O_CREAT, timeout_);
337 auto lock = file.lockExclusive();
338 // Always refresh since even though the key exists in the cache,
339 // it might be outdated
340 pos_ = refresh(file, pos_, cache_, deletePrefix_);
341 if ((cache_.count(regKey) == 0 && expectedValue.empty()) ||
342 (cache_.count(regKey) != 0 && cache_[regKey] == expectedValue)) {
343 // if the key does not exist and currentValue arg is empty or
344 // the key does exist and current value is what is expected, then set it
345 file.seek(0, SEEK_END);
346 file.write(regKey);
347 file.write(desiredValue);
348 return desiredValue;
349 } else if (cache_.count(regKey) == 0) {
350 // if the key does not exist
351 return expectedValue;
352 }
353 // key exists but current value is not expected
354 return cache_[regKey];
355}
356
357std::vector<uint8_t> FileStore::get(const std::string& key) {
358 std::string regKey = regularPrefix_ + key;
359 const auto start = std::chrono::steady_clock::now();
360 while (true) {
361 std::unique_lock<std::mutex> l(activeFileOpLock_);
362 File file(path_, O_RDONLY, timeout_);
363 auto lock = file.lockShared();
364 auto size = file.size();
365 if (cache_.count(regKey) == 0 && size == pos_) {
366 // No new entries; release the shared lock and sleep for a bit
367 lock.unlock();
368 l.unlock();
369 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
370 std::chrono::steady_clock::now() - start);
371 if (timeout_ != kNoTimeout && elapsed > timeout_) {
372 auto err = c10::str(
373 "Timeout waiting for key: ",
374 key,
375 " after ",
376 timeout_.count(),
377 " ms");
378 TORCH_CHECK(false, err);
379 }
380 std::this_thread::sleep_for(std::chrono::milliseconds(10));
381 continue;
382 }
383 // Always refresh since even though the key exists in the cache,
384 // it might be outdated
385 pos_ = refresh(file, pos_, cache_, deletePrefix_);
386 if (cache_.count(regKey) != 0) {
387 return cache_[regKey];
388 }
389 }
390}
391
392int64_t FileStore::addHelper(const std::string& key, int64_t i) {
393 std::unique_lock<std::mutex> l(activeFileOpLock_);
394 File file(path_, O_RDWR | O_CREAT, timeout_);
395 auto lock = file.lockExclusive();
396 pos_ = refresh(file, pos_, cache_, deletePrefix_);
397
398 const auto& value = cache_[key];
399 int64_t ti = i;
400 if (!value.empty()) {
401 auto buf = reinterpret_cast<const char*>(value.data());
402 auto len = value.size();
403 ti += std::stoll(std::string(buf, len));
404 }
405 // Always seek to the end to write
406 file.seek(0, SEEK_END);
407 // File cursor is at the end of the file now, and we have an
408 // exclusive lock, so we can write the new value.
409 file.write(key);
410 file.write(std::to_string(ti));
411 return ti;
412}
413
414int64_t FileStore::add(const std::string& key, int64_t value) {
415 std::string regKey = regularPrefix_ + key;
416 return addHelper(regKey, value);
417}
418
419int64_t FileStore::getNumKeys() {
420 std::unique_lock<std::mutex> l(activeFileOpLock_);
421 File file(path_, O_RDONLY, timeout_);
422 auto lock = file.lockShared();
423 pos_ = refresh(file, pos_, cache_, deletePrefix_);
424 return cache_.size();
425}
426
427bool FileStore::deleteKey(const std::string& key) {
428 std::string deleteKey = deletePrefix_ + regularPrefix_ + key;
429 std::unique_lock<std::mutex> l(activeFileOpLock_);
430 File file(path_, O_RDWR, timeout_);
431 auto lock = file.lockExclusive();
432 file.seek(0, SEEK_END);
433 file.write(deleteKey);
434 file.write(std::vector<uint8_t>{});
435 return true;
436}
437
438bool FileStore::check(const std::vector<std::string>& keys) {
439 std::unique_lock<std::mutex> l(activeFileOpLock_);
440 File file(path_, O_RDONLY, timeout_);
441 auto lock = file.lockShared();
442 pos_ = refresh(file, pos_, cache_, deletePrefix_);
443
444 for (const auto& key : keys) {
445 std::string regKey = regularPrefix_ + key;
446 if (cache_.count(regKey) == 0) {
447 return false;
448 }
449 }
450
451 return true;
452}
453
454void FileStore::wait(const std::vector<std::string>& keys) {
455 wait(keys, timeout_);
456}
457
458void FileStore::wait(
459 const std::vector<std::string>& keys,
460 const std::chrono::milliseconds& timeout) {
461 // Not using inotify because it doesn't work on many
462 // shared filesystems (such as NFS).
463 const auto start = std::chrono::steady_clock::now();
464 while (!check(keys)) {
465 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
466 std::chrono::steady_clock::now() - start);
467 if (timeout != kNoTimeout && elapsed > timeout) {
468 TORCH_CHECK(false, "Wait timeout");
469 }
470
471 /* sleep override */
472 std::this_thread::sleep_for(std::chrono::milliseconds(10));
473 }
474}
475
476} // namespace c10d
477