1#pragma once
2
3#include <ATen/core/ivalue.h>
4#include <atomic>
5
6namespace torch {
7namespace distributed {
8namespace rpc {
9
// Id generated locally on a single worker (signed 64-bit).
using local_id_t = int64_t;
// Id of an RPC worker (signed 16-bit).
using worker_id_t = int16_t;
12
13bool getAllowJitRRefPickle();
14TORCH_API void enableJitRRefPickle();
15TORCH_API void disableJitRRefPickle();
16
17struct TORCH_API JitRRefPickleGuard {
18 JitRRefPickleGuard();
19 ~JitRRefPickleGuard();
20};
21
22struct TORCH_API GloballyUniqueId final {
23 GloballyUniqueId(worker_id_t createdOn, local_id_t localId);
24 GloballyUniqueId(const GloballyUniqueId& other) = default;
25 GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete;
26
27 bool operator==(const GloballyUniqueId& other) const;
28 bool operator!=(const GloballyUniqueId& other) const;
29
30 at::IValue toIValue() const;
31 static GloballyUniqueId fromIValue(const at::IValue&);
32
33 struct Hash {
34 size_t operator()(const GloballyUniqueId& key) const {
35 return (uint64_t(key.createdOn_) << kLocalIdBits) | key.localId_;
36 }
37 };
38
39 static constexpr int kLocalIdBits = 48;
40
41 const worker_id_t createdOn_;
42 const local_id_t localId_;
43};
44
45TORCH_API std::ostream& operator<<(
46 std::ostream& os,
47 const GloballyUniqueId& globalId);
48
49using RRefId = GloballyUniqueId;
50using ForkId = GloballyUniqueId;
51using ProfilingId = GloballyUniqueId;
52
53struct TORCH_API SerializedPyObj final {
54 SerializedPyObj(std::string&& payload, std::vector<at::Tensor>&& tensors)
55 : payload_(std::move(payload)), tensors_(std::move(tensors)) {}
56
57 std::vector<at::IValue> toIValues() &&;
58 static SerializedPyObj fromIValues(std::vector<at::IValue> value);
59
60 std::string payload_;
61 std::vector<at::Tensor> tensors_;
62};
63
64} // namespace rpc
65} // namespace distributed
66} // namespace torch
67