1#pragma once
2
3#ifdef USE_TENSORPIPE
4
5#include <torch/csrc/distributed/rpc/utils.h>
6
// Forward declarations of the TensorPipe types referenced in this header.
// TensorPipe is a private dependency, so its headers cannot be included here;
// code that needs the full definitions must include the TensorPipe headers
// itself.
namespace tensorpipe {
class Message;
class Allocation;
class Descriptor;
} // namespace tensorpipe
12
13namespace torch {
14namespace distributed {
15namespace rpc {
16
// Looks up the stream to use for `device` among `streams`.
// NOTE(review): presumably returns the element of `streams` whose device equals
// `device`. What happens when no element matches is decided by the definition
// in the .cpp and is not visible here — confirm. The result is a const
// reference, so if it refers into `streams`, that vector must outlive it (do
// not pass a temporary).
TORCH_API const c10::Stream& getStreamForDevice(
    const std::vector<c10::Stream>& streams,
    const c10::Device& device);
20
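// The device-type converter interface, its registry, and the registration
// macro below follow the pattern used in c10/core/impl/DeviceGuardImplInterface.h.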
22
// Per-device-type hooks used by the TensorPipe serialization code to describe
// a tensor's storage when sending it and to obtain memory for a tensor being
// received. Each c10::DeviceType has one slot in
// device_type_converter_registry. That slot is normally filled via
// C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER and read with
// getDeviceTypeConverter().
class TensorpipeDeviceTypeConverter {
 public:
  // Prepares `storage` for sending by appending a tensorpipe::Message::Tensor
  // describing it to the `tensors` field of `message`.
  //
  // Ideally we'd want this to also return a tensorpipe::Message::Tensor
  // object, but we cannot forward-declare that class (because it's nested),
  // and we cannot include the TensorPipe headers because they are a private
  // dependency. Thus we bend over backwards and entrust this method with
  // appending that object to the `tensors` field of the tensorpipe::Message
  // object we pass.
  //
  // NOTE(review): `streams` presumably holds the caller's current streams (one
  // per device). The optional return value presumably carries a copy of the
  // storage's bytes when the implementation had to copy them (compare
  // TensorpipeWriteBuffers::copiedTensors), and nullopt otherwise. Confirm
  // both against the implementations.
  virtual c10::optional<std::vector<char>> prepareTensorForSending(
      const c10::Storage& storage,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Message& message) const = 0;

  // Obtains the destination memory for an incoming tensor. As above, this
  // method cannot return a tensorpipe::Allocation::Tensor, so it appends it to
  // the `tensors` field of `allocation` instead.
  //
  // NOTE(review): presumably allocates `length` bytes on device `deviceIndex`
  // of this converter's device type and returns the owning DataPtr. The caller
  // must then keep that DataPtr alive until the read completes (cf.
  // TensorpipeReadBuffers::tensors) — confirm.
  virtual at::DataPtr allocateTensorForReceiving(
      int deviceIndex,
      size_t length,
      const std::vector<c10::Stream>& streams,
      tensorpipe::Allocation& allocation) const = 0;

  // Virtual because converters are handled through base-class pointers (see
  // the registry below).
  virtual ~TensorpipeDeviceTypeConverter() = default;
};
45
// Global table with one slot per c10::DeviceType, indexed by the enum's
// integer value, holding the converter registered for that device type.
// Slots are std::atomic, so individual slots can be read and written without
// external locking. The slots hold raw pointers and do not own the
// converters.
// NOTE(review): an unregistered slot presumably holds nullptr; this depends on
// the definition in the .cpp — confirm.
extern TORCH_API std::array<
    std::atomic<const TensorpipeDeviceTypeConverter*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_type_converter_registry;
50
51class TORCH_API TensorpipeDeviceTypeConverterRegistrar {
52 public:
53 TensorpipeDeviceTypeConverterRegistrar(
54 DeviceType,
55 const TensorpipeDeviceTypeConverter*);
56};
57
// Registers `ConverterImpl` as the TensorPipe converter for the device type
// `::c10::DeviceType::DevType`. Typical use, in a .cpp file:
//
//   C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(CUDA, MyCudaConverter);
//
// `DevType` must be a bare enumerator name of c10::DeviceType. `ConverterImpl`
// must be a default-constructible subclass of TensorpipeDeviceTypeConverter.
// The macro defines a uniquely named static registrar object, whose
// constructor runs when the translation unit's static objects are
// initialized. The converter is created with `new` and is never deleted by
// this macro, so it presumably stays alive for the rest of the process.
//
// The generated variable name is derived from `DevType` (g_<DevType><counter>)
// to aid debugging; uniqueness comes from C10_ANONYMOUS_VARIABLE.
#define C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(                     \
    DevType, ConverterImpl)                                                \
  static ::torch::distributed::rpc::TensorpipeDeviceTypeConverterRegistrar \
      C10_ANONYMOUS_VARIABLE(g_##DevType)(                                 \
          ::c10::DeviceType::DevType, new ConverterImpl());
63
64inline const TensorpipeDeviceTypeConverter* getDeviceTypeConverter(
65 DeviceType type) {
66 return device_type_converter_registry[static_cast<size_t>(type)].load();
67}
68
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a write operation. tensorpipeSerialize returns
// one of these, and it must stay alive until the asynchronous write finishes.
struct TensorpipeWriteBuffers {
  // Allocate on heap so pointers stay valid as we move the holder. (The vector
  // members below need no such indirection: moving a std::vector keeps its
  // heap buffer in place.)
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // NOTE(review): presumably the RPC message's raw payload bytes and the
  // pickled tensor metadata, respectively — confirm against tensorpipeSerialize.
  std::vector<char> payload;
  std::vector<char> pickle;
  // This contains the original tensors and the clones of the sparse tensors.
  std::vector<torch::Tensor> tensors;
  // This contains the copies of the data of the tensors that didn't own their
  // memory, e.g., the ones created from torch::from_blob() with no deleter.
  std::vector<std::vector<char>> copiedTensors;
};
83
// A struct that holds pointers that keep alive all the memory that will be
// accessed by TensorPipe during a read operation. tensorpipeAllocate returns
// one of these, and it must stay alive until the asynchronous read has
// finished. It is then handed to tensorpipeDeserialize.
struct TensorpipeReadBuffers {
  // Allocate on heap so pointers stay valid as we move the holder. (The vector
  // members below need no such indirection: moving a std::vector keeps its
  // heap buffer in place.)
  std::unique_ptr<MessageType> type;
  std::unique_ptr<int64_t> id;
  // NOTE(review): presumably the destination buffers for the incoming payload
  // and pickle, mirroring TensorpipeWriteBuffers — confirm.
  std::vector<char> payload;
  std::vector<char> pickle;
  // Owning pointers to the memory that receives the incoming tensors.
  // NOTE(review): presumably produced by
  // TensorpipeDeviceTypeConverter::allocateTensorForReceiving — confirm.
  std::vector<c10::DataPtr> tensors;
};
94
// Convert an RPC message into a TensorPipe message, plus a holder to all the
// data that must be kept alive while the write is performed asynchronously.
// `rpcMessage` is taken by value, so callers may move their reference in.
// NOTE(review): `devices` presumably lists the target device for each of the
// message's tensors, and `streams` the streams to use on those devices. The
// exact roles are defined in the .cpp — confirm.
TORCH_API std::tuple<tensorpipe::Message, TensorpipeWriteBuffers>
tensorpipeSerialize(
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device> devices,
    const std::vector<c10::Stream>& streams);
102
// Allocate the buffers that will hold the incoming data. They will be managed
// by the returned holder, which must be kept alive until the asynchronous read
// has finished. Pointers to these buffers will be stored in the returned
// tensorpipe::Allocation struct.
// `tpDescriptor` describes the incoming message. NOTE(review): `streams`
// presumably supplies the streams on which device memory is allocated —
// confirm.
TORCH_API std::pair<tensorpipe::Allocation, TensorpipeReadBuffers>
tensorpipeAllocate(
    const tensorpipe::Descriptor& tpDescriptor,
    const std::vector<c10::Stream>& streams);
111
// Convert a TensorPipe message back into an RPC message. This requires the data
// to be available and can thus only be performed once the asynchronous read has
// completed. The holder can be destroyed once this function returns.
// Both arguments are taken by rvalue reference, so callers pass them with
// std::move (typically the descriptor from the read and the holder returned by
// tensorpipeAllocate).
TORCH_API c10::intrusive_ptr<Message> tensorpipeDeserialize(
    tensorpipe::Descriptor&& tpDescriptor,
    TensorpipeReadBuffers&& holder);
118
119} // namespace rpc
120} // namespace distributed
121} // namespace torch
122
123#endif // USE_TENSORPIPE
124