1 | #pragma once |
2 | |
3 | #ifdef USE_TENSORPIPE |
4 | |
5 | #include <torch/csrc/distributed/rpc/utils.h> |
6 | |
7 | namespace tensorpipe { |
8 | class Message; |
9 | class Allocation; |
10 | class Descriptor; |
11 | } // namespace tensorpipe |
12 | |
13 | namespace torch { |
14 | namespace distributed { |
15 | namespace rpc { |
16 | |
// Returns the stream in `streams` associated with `device` (declaration only;
// defined in the .cpp). NOTE(review): presumably expects exactly one matching
// stream to be present — confirm against the definition.
TORCH_API const c10::Stream& getStreamForDevice(
    const std::vector<c10::Stream>& streams,
    const c10::Device& device);
20 | |
// Inspired by c10/core/impl/DeviceGuardImplInterface.h.

// Abstract interface that teaches the TensorPipe RPC layer how to serialize
// and allocate tensor memory for one particular device type. Implementations
// are registered (one per DeviceType) via the
// C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER macro and looked up with
// getDeviceTypeConverter().
class TensorpipeDeviceTypeConverter {
 public:
  // Ideally we'd want this to also return a tensorpipe::Message::Tensor object
  // but we cannot forward-declare that class (because it's nested), and we
  // cannot include the TensorPipe headers because it's a private dependency.
  // Thus we bend over backwards and entrust this method with appending that
  // object to the `tensors` field of the tensorpipe::Message object we pass.
  //
  // NOTE(review): the returned buffer presumably holds an out-of-place copy of
  // the storage bytes when one is needed, with c10::nullopt meaning the
  // storage's memory is referenced directly — confirm against implementations.
  virtual c10::optional<std::vector<char>> prepareTensorForSending(
      const c10::Storage& storage,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Message& message) const = 0;

  // Same as above: this method cannot return a tensorpipe::Allocation::Tensor,
  // thus it appends it to the `tensors` field of the tensorpipe::Allocation.
  // Allocates `length` bytes on device index `deviceIndex` and returns an
  // owning DataPtr for the buffer TensorPipe will receive into.
  virtual at::DataPtr allocateTensorForReceiving(
      int deviceIndex,
      size_t length,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Allocation& allocation) const = 0;

  // Instances are deleted through base-class pointers (see the registration
  // macro, which news a concrete converter), hence the virtual destructor.
  virtual ~TensorpipeDeviceTypeConverter() = default;
};
45 | |
// Global registry mapping each DeviceType (by its numeric value) to its
// converter, or nullptr when none is registered. Slots are atomic pointers so
// lookups need no lock; entries are presumably stored by the registrar's
// constructor (defined in the .cpp).
extern TORCH_API std::array<
    std::atomic<const TensorpipeDeviceTypeConverter*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_type_converter_registry;
50 | |
// Helper whose constructor registers a converter for a device type. Used by
// the C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER macro so that
// registration runs as a static-initialization side effect; the pointer is
// presumably stored into device_type_converter_registry (see the .cpp).
class TORCH_API TensorpipeDeviceTypeConverterRegistrar {
 public:
  TensorpipeDeviceTypeConverterRegistrar(
      DeviceType,
      const TensorpipeDeviceTypeConverter*);
};
57 | |
// Registers a TensorpipeDeviceTypeConverter subclass for the given device
// type, e.g. at namespace scope in a .cpp file:
//   C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(CUDA, CudaConverter);
// Expands to a file-local registrar whose constructor stores the (leaked,
// process-lifetime) converter into device_type_converter_registry.
// Fix: paste the macro parameter `DevType` into the variable name — the
// previous `g_##DeviceType` pasted the literal token "DeviceType", giving
// every registration the same base name and relying solely on
// C10_ANONYMOUS_VARIABLE's unique suffix to disambiguate.
#define C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(                     \
    DevType, TensorpipeDeviceTypeConverter)                                \
  static ::torch::distributed::rpc::TensorpipeDeviceTypeConverterRegistrar \
      C10_ANONYMOUS_VARIABLE(g_##DevType)(                                 \
          ::c10::DeviceType::DevType, new TensorpipeDeviceTypeConverter());
63 | |
64 | inline const TensorpipeDeviceTypeConverter* getDeviceTypeConverter( |
65 | DeviceType type) { |
66 | return device_type_converter_registry[static_cast<size_t>(type)].load(); |
67 | } |
68 | |
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a write operation. The tensorpipe::Message
// holds raw pointers into these buffers, so this holder must outlive the
// asynchronous write (see tensorpipeSerialize below).
struct TensorpipeWriteBuffers {
  // Allocate on heap so pointers stay valid as we move the holder.
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // Raw payload bytes of the RPC message.
  std::vector<char> payload;
  // Pickled tensor metadata (name suggests torch pickler output — see .cpp).
  std::vector<char> pickle;
  // This contains the original tensors and the clones of the sparse tensors.
  std::vector<torch::Tensor> tensors;
  // This contains the copies of the data of the tensors that didn't own their
  // memory, e.g., the ones created from torch::from_blob() with no deleter.
  std::vector<std::vector<char>> copiedTensors;
};
83 | |
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a read operation. The tensorpipe::Allocation
// points into these buffers, so the holder must stay alive until the
// asynchronous read completes (see tensorpipeAllocate below).
struct TensorpipeReadBuffers {
  // Allocate on heap so pointers stay valid as we move the holder.
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // Raw payload bytes of the incoming RPC message.
  std::vector<char> payload;
  // Buffer for the pickled tensor metadata.
  std::vector<char> pickle;
  // Device buffers for incoming tensor data, one DataPtr per tensor
  // (allocated via TensorpipeDeviceTypeConverter::allocateTensorForReceiving).
  std::vector<c10::DataPtr> tensors;
};
94 | |
// Convert an RPC message into a TensorPipe message, plus a holder to all the
// data that must be kept alive while the write is performed asynchronously.
// NOTE(review): `devices` presumably lists the target device for each tensor
// in the message, and `streams` the streams to synchronize the copies on —
// confirm against the definition in the .cpp.
TORCH_API std::tuple<tensorpipe::Message, TensorpipeWriteBuffers>
tensorpipeSerialize(
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device> devices,
    const std::vector<c10::Stream>& streams);
102 | |
// Allocate the buffers that will hold the incoming data. They will be managed
// by the returned holder, which must be kept alive until the asynchronous read
// has finished. Pointers to these buffers will be stored in the returned
// tensorpipe::Allocation struct. Sizes/devices are presumably taken from
// `tpDescriptor`, which describes the incoming message.
TORCH_API std::pair<tensorpipe::Allocation, TensorpipeReadBuffers>
tensorpipeAllocate(
    const tensorpipe::Descriptor& tpDescriptor,
    const std::vector<c10::Stream>& streams);
111 | |
// Convert a TensorPipe message back into an RPC message. This requires the data
// to be available and can thus only be performed once the asynchronous read has
// completed. The holder can be destroyed once this function returns. Both
// arguments are taken by rvalue reference: their contents are consumed.
TORCH_API c10::intrusive_ptr<Message> tensorpipeDeserialize(
    tensorpipe::Descriptor&& tpDescriptor,
    TensorpipeReadBuffers&& holder);
118 | |
119 | } // namespace rpc |
120 | } // namespace distributed |
121 | } // namespace torch |
122 | |
123 | #endif // USE_TENSORPIPE |
124 | |