1#include <torch/csrc/distributed/rpc/types.h>
2
3namespace torch {
4namespace distributed {
5namespace rpc {
6
// Per-thread switch that gates JIT pickling of RRefs. It starts out false and
// is meant to be true only while an RPC call is in progress on this thread;
// in any other scope (for example when a model is saved by calling
// torch.save()) an RRef must not be pickled directly.
static thread_local bool allowJitRRefPickle{false};
11
12bool getAllowJitRRefPickle() {
13 return allowJitRRefPickle;
14}
15
16void enableJitRRefPickle() {
17 allowJitRRefPickle = true;
18}
19
20void disableJitRRefPickle() {
21 allowJitRRefPickle = false;
22}
23
24static_assert(
25 std::numeric_limits<local_id_t>::max() <=
26 std::numeric_limits<int64_t>::max(),
27 "The max value of local_id_t must be within the range of int64_t");
28static_assert(
29 std::numeric_limits<worker_id_t>::max() <=
30 std::numeric_limits<int64_t>::max(),
31 "The max value of worker_id_t must be within the range of int64_t");
32
33/////////////////////////// JitRRefPickleGuard ///////////////////////////
34JitRRefPickleGuard::JitRRefPickleGuard() {
35 allowJitRRefPickle = true;
36}
37JitRRefPickleGuard::~JitRRefPickleGuard() {
38 allowJitRRefPickle = false;
39}
40
41/////////////////////////// GloballyUniqueId ///////////////////////////
42
43GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId)
44 : createdOn_(createdOn), localId_(localId) {}
45
46bool GloballyUniqueId::operator==(const GloballyUniqueId& other) const {
47 return createdOn_ == other.createdOn_ && localId_ == other.localId_;
48}
49
50bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const {
51 return createdOn_ != other.createdOn_ || localId_ != other.localId_;
52}
53
54at::IValue GloballyUniqueId::toIValue() const {
55 return c10::ivalue::Tuple::create(
56 {static_cast<int64_t>(createdOn_), static_cast<int64_t>(localId_)});
57}
58
59GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
60 TORCH_INTERNAL_ASSERT(
61 ivalue.isTuple(),
62 "GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
63 const auto& ivalues = ivalue.toTupleRef().elements();
64 TORCH_CHECK(
65 ivalues.size() == 2,
66 "Constructing GloballyUniqueId from ivalue "
67 "expects a GenericList of two elements, but got ",
68 ivalues.size());
69
70 TORCH_CHECK(
71 ivalues[0].toInt() <= std::numeric_limits<worker_id_t>::max(),
72 "GloballyUniqueId createdOn out of range, got ",
73 ivalues[0].toInt());
74 worker_id_t createdOn = ivalues[0].toInt();
75
76 TORCH_CHECK(
77 ivalues[1].toInt() <= std::numeric_limits<local_id_t>::max(),
78 "GloballyUniqueId localId out of range, got ",
79 ivalues[1].toInt());
80 local_id_t localId = ivalues[1].toInt();
81
82 return GloballyUniqueId(createdOn, localId);
83}
84
85std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) {
86 return os << "GloballyUniqueId(created_on=" << globalId.createdOn_
87 << ", local_id=" << globalId.localId_ << ")";
88}
89
90/////////////////////////// SerializedPyObj ///////////////////////////
91
92std::vector<at::IValue> SerializedPyObj::toIValues() && {
93 std::vector<at::IValue> ivalues;
94 ivalues.reserve(tensors_.size() + 1);
95 for (auto& tensor : tensors_) {
96 ivalues.emplace_back(std::move(tensor));
97 }
98 ivalues.emplace_back(std::move(payload_));
99 return ivalues;
100}
101
102SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) {
103 std::string payload = values.back().toStringRef();
104 values.pop_back();
105 std::vector<at::Tensor> tensors;
106 tensors.reserve(values.size());
107 for (auto& value : values) {
108 tensors.emplace_back(value.toTensor());
109 }
110 return SerializedPyObj(std::move(payload), std::move(tensors));
111}
112
113} // namespace rpc
114} // namespace distributed
115} // namespace torch
116