1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <atomic> |
5 | |
6 | namespace torch { |
7 | namespace distributed { |
8 | namespace rpc { |
9 | |
// Type of the `createdOn` component of a GloballyUniqueId (16-bit signed).
using worker_id_t = int16_t;
// Type of the `localId` component of a GloballyUniqueId (64-bit signed).
using local_id_t = int64_t;
12 | |
13 | bool getAllowJitRRefPickle(); |
14 | TORCH_API void enableJitRRefPickle(); |
15 | TORCH_API void disableJitRRefPickle(); |
16 | |
17 | struct TORCH_API JitRRefPickleGuard { |
18 | JitRRefPickleGuard(); |
19 | ~JitRRefPickleGuard(); |
20 | }; |
21 | |
22 | struct TORCH_API GloballyUniqueId final { |
23 | GloballyUniqueId(worker_id_t createdOn, local_id_t localId); |
24 | GloballyUniqueId(const GloballyUniqueId& other) = default; |
25 | GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete; |
26 | |
27 | bool operator==(const GloballyUniqueId& other) const; |
28 | bool operator!=(const GloballyUniqueId& other) const; |
29 | |
30 | at::IValue toIValue() const; |
31 | static GloballyUniqueId fromIValue(const at::IValue&); |
32 | |
33 | struct Hash { |
34 | size_t operator()(const GloballyUniqueId& key) const { |
35 | return (uint64_t(key.createdOn_) << kLocalIdBits) | key.localId_; |
36 | } |
37 | }; |
38 | |
39 | static constexpr int kLocalIdBits = 48; |
40 | |
41 | const worker_id_t createdOn_; |
42 | const local_id_t localId_; |
43 | }; |
44 | |
45 | TORCH_API std::ostream& operator<<( |
46 | std::ostream& os, |
47 | const GloballyUniqueId& globalId); |
48 | |
49 | using RRefId = GloballyUniqueId; |
50 | using ForkId = GloballyUniqueId; |
51 | using ProfilingId = GloballyUniqueId; |
52 | |
53 | struct TORCH_API SerializedPyObj final { |
54 | SerializedPyObj(std::string&& payload, std::vector<at::Tensor>&& tensors) |
55 | : payload_(std::move(payload)), tensors_(std::move(tensors)) {} |
56 | |
57 | std::vector<at::IValue> toIValues() &&; |
58 | static SerializedPyObj fromIValues(std::vector<at::IValue> value); |
59 | |
60 | std::string payload_; |
61 | std::vector<at::Tensor> tensors_; |
62 | }; |
63 | |
64 | } // namespace rpc |
65 | } // namespace distributed |
66 | } // namespace torch |
67 |