1 | #include <ATen/ThreadLocalState.h> |
---|---|
2 | #include <c10/util/Optional.h> |
3 | #include <torch/csrc/distributed/c10d/sequence_num.hpp> |
4 | |
5 | #include <c10/util/Logging.h> |
6 | |
7 | namespace c10d { |
8 | SequenceNum::SequenceNum() : num_(c10::nullopt) {} |
9 | |
10 | SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} |
11 | |
12 | SequenceNum::SequenceNum(const SequenceNum& other) { |
13 | if (!other.isSet()) { |
14 | num_ = c10::nullopt; |
15 | } else { |
16 | num_ = other.get(); |
17 | } |
18 | } |
19 | |
20 | uint64_t SequenceNum::get() const { |
21 | std::lock_guard<std::mutex> lock(lock_); |
22 | return *num_; |
23 | } |
24 | |
25 | void SequenceNum::increment() { |
26 | std::lock_guard<std::mutex> lock(lock_); |
27 | TORCH_CHECK(num_ != c10::nullopt); |
28 | num_ = ++(*num_); |
29 | } |
30 | |
31 | // Implemented without above get() and increment() so we don't repeatedly lock |
32 | // and unblock. |
33 | uint64_t SequenceNum::getAndIncrement() { |
34 | uint64_t curVal = 0; |
35 | std::lock_guard<std::mutex> lock(lock_); |
36 | TORCH_CHECK(num_ != c10::nullopt); |
37 | curVal = *num_; |
38 | num_ = ++(*num_); |
39 | return curVal; |
40 | } |
41 | |
42 | void SequenceNum::set(const uint64_t num) { |
43 | std::lock_guard<std::mutex> lock(lock_); |
44 | num_ = num; |
45 | } |
46 | |
47 | bool SequenceNum::isSet() const { |
48 | std::lock_guard<std::mutex> lock(lock_); |
49 | return num_ != c10::nullopt; |
50 | } |
51 | |
52 | SequenceNum& SequenceNum::operator=(const SequenceNum& other) { |
53 | std::lock_guard<std::mutex> lock(lock_); |
54 | if (!other.isSet()) { |
55 | num_ = c10::nullopt; |
56 | } else { |
57 | num_ = other.get(); |
58 | } |
59 | return *this; |
60 | } |
61 | |
62 | } // namespace c10d |
63 |