1 | #include <torch/csrc/distributed/rpc/types.h> |
2 | |
3 | namespace torch { |
4 | namespace distributed { |
5 | namespace rpc { |
6 | |
// Per-thread switch gating JIT pickling of RRefs. Pickling an RRef is only
// meant to be permitted inside the scope of an RPC call; in any other context
// (for example when a model is saved with torch.save()) an RRef must not be
// pickled directly, which is why the switch starts out off.
static thread_local bool allowJitRRefPickle = false;

namespace {
// Single write point for the thread-local switch above.
void storeJitRRefPickleFlag(bool allowed) {
  allowJitRRefPickle = allowed;
}
} // namespace

// Reports whether JIT RRef pickling is currently permitted on this thread.
bool getAllowJitRRefPickle() {
  const bool allowed = allowJitRRefPickle;
  return allowed;
}

// Permits JIT RRef pickling on the calling thread.
void enableJitRRefPickle() {
  storeJitRRefPickleFlag(true);
}

// Forbids JIT RRef pickling on the calling thread.
void disableJitRRefPickle() {
  storeJitRRefPickleFlag(false);
}
23 | |
24 | static_assert( |
25 | std::numeric_limits<local_id_t>::max() <= |
26 | std::numeric_limits<int64_t>::max(), |
27 | "The max value of local_id_t must be within the range of int64_t" ); |
28 | static_assert( |
29 | std::numeric_limits<worker_id_t>::max() <= |
30 | std::numeric_limits<int64_t>::max(), |
31 | "The max value of worker_id_t must be within the range of int64_t" ); |
32 | |
33 | /////////////////////////// JitRRefPickleGuard /////////////////////////// |
34 | JitRRefPickleGuard::JitRRefPickleGuard() { |
35 | allowJitRRefPickle = true; |
36 | } |
37 | JitRRefPickleGuard::~JitRRefPickleGuard() { |
38 | allowJitRRefPickle = false; |
39 | } |
40 | |
41 | /////////////////////////// GloballyUniqueId /////////////////////////// |
42 | |
43 | GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId) |
44 | : createdOn_(createdOn), localId_(localId) {} |
45 | |
46 | bool GloballyUniqueId::operator==(const GloballyUniqueId& other) const { |
47 | return createdOn_ == other.createdOn_ && localId_ == other.localId_; |
48 | } |
49 | |
50 | bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const { |
51 | return createdOn_ != other.createdOn_ || localId_ != other.localId_; |
52 | } |
53 | |
54 | at::IValue GloballyUniqueId::toIValue() const { |
55 | return c10::ivalue::Tuple::create( |
56 | {static_cast<int64_t>(createdOn_), static_cast<int64_t>(localId_)}); |
57 | } |
58 | |
59 | GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) { |
60 | TORCH_INTERNAL_ASSERT( |
61 | ivalue.isTuple(), |
62 | "GloballyUniqueId::fromIValue expected ivalue to be a tuple." ); |
63 | const auto& ivalues = ivalue.toTupleRef().elements(); |
64 | TORCH_CHECK( |
65 | ivalues.size() == 2, |
66 | "Constructing GloballyUniqueId from ivalue " |
67 | "expects a GenericList of two elements, but got " , |
68 | ivalues.size()); |
69 | |
70 | TORCH_CHECK( |
71 | ivalues[0].toInt() <= std::numeric_limits<worker_id_t>::max(), |
72 | "GloballyUniqueId createdOn out of range, got " , |
73 | ivalues[0].toInt()); |
74 | worker_id_t createdOn = ivalues[0].toInt(); |
75 | |
76 | TORCH_CHECK( |
77 | ivalues[1].toInt() <= std::numeric_limits<local_id_t>::max(), |
78 | "GloballyUniqueId localId out of range, got " , |
79 | ivalues[1].toInt()); |
80 | local_id_t localId = ivalues[1].toInt(); |
81 | |
82 | return GloballyUniqueId(createdOn, localId); |
83 | } |
84 | |
85 | std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) { |
86 | return os << "GloballyUniqueId(created_on=" << globalId.createdOn_ |
87 | << ", local_id=" << globalId.localId_ << ")" ; |
88 | } |
89 | |
90 | /////////////////////////// SerializedPyObj /////////////////////////// |
91 | |
92 | std::vector<at::IValue> SerializedPyObj::toIValues() && { |
93 | std::vector<at::IValue> ivalues; |
94 | ivalues.reserve(tensors_.size() + 1); |
95 | for (auto& tensor : tensors_) { |
96 | ivalues.emplace_back(std::move(tensor)); |
97 | } |
98 | ivalues.emplace_back(std::move(payload_)); |
99 | return ivalues; |
100 | } |
101 | |
102 | SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) { |
103 | std::string payload = values.back().toStringRef(); |
104 | values.pop_back(); |
105 | std::vector<at::Tensor> tensors; |
106 | tensors.reserve(values.size()); |
107 | for (auto& value : values) { |
108 | tensors.emplace_back(value.toTensor()); |
109 | } |
110 | return SerializedPyObj(std::move(payload), std::move(tensors)); |
111 | } |
112 | |
113 | } // namespace rpc |
114 | } // namespace distributed |
115 | } // namespace torch |
116 | |