1#include <torch/csrc/distributed/rpc/message.h>
2#include <torch/custom_class.h>
3
4namespace torch {
5namespace distributed {
6namespace rpc {
7
8Message::Message() = default;
9
10Message::Message(
11 std::vector<char>&& payload,
12 std::vector<torch::Tensor>&& tensors,
13 MessageType type)
14 : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}
15
16Message::Message(
17 std::vector<char>&& payload,
18 std::vector<torch::Tensor>&& tensors,
19 MessageType type,
20 int64_t id)
21 : payload_(std::move(payload)),
22 tensors_(std::move(tensors)),
23 type_(type),
24 id_(id) {}
25
26std::vector<char>&& Message::movePayload() && {
27 return std::move(payload_);
28}
29
30std::vector<char>& Message::payload() {
31 return payload_;
32}
33
34const std::vector<char>& Message::payload() const {
35 return payload_;
36}
37
38std::vector<torch::Tensor>&& Message::moveTensors() && {
39 return std::move(tensors_);
40}
41
42std::vector<torch::Tensor>& Message::tensors() {
43 return tensors_;
44}
45
46const std::vector<torch::Tensor>& Message::tensors() const {
47 return tensors_;
48}
49
50MessageType Message::type() const {
51 return type_;
52}
53
54bool Message::isRequest() const {
55 return MessageTypeFlags::REQUEST_TYPE & type_;
56}
57
58bool Message::isResponse() const {
59 return MessageTypeFlags::RESPONSE_TYPE & type_;
60}
61
62int64_t Message::id() const {
63 return id_;
64}
65
66void Message::setId(int64_t id) {
67 id_ = id;
68}
69
70std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages()
71 const {
72 // Sparse tensors do not have storage. Instead, a sparse tensor
73 // contains two tensors indices and values, and both contain storage.
74 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
75 storages.reserve(2 * tensors_.size());
76 for (const auto& tensor : tensors_) {
77 if (tensor.is_sparse()) {
78 storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
79 storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
80 } else {
81 storages.emplace_back(tensor.storage().getWeakStorageImpl());
82 }
83 }
84 return storages;
85}
86
87c10::intrusive_ptr<Message> createExceptionResponse(
88 const std::exception& e,
89 int64_t id) {
90 return createExceptionResponse(e.what(), id);
91}
92
93c10::intrusive_ptr<Message> createExceptionResponse(
94 const std::string& exceptionStr,
95 int64_t id) {
96 std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
97 return c10::make_intrusive<Message>(
98 std::move(payload),
99 std::vector<torch::Tensor>(),
100 MessageType::EXCEPTION,
101 id);
102}
103
104namespace {
105
106// NB: need to call torch::class_ to register Message in the map returned by
107// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within
108// an IValue.
109// NB: add this line here instead of in rpc/init.cpp because 1) we have C++
110// only tests that won't run rpc/init.cpp; 2) Message is not meant to be
111// visible from Python.
112static const auto message = torch::class_<Message>("rpc", "_Message");
113
114} // namespace
115
116} // namespace rpc
117} // namespace distributed
118} // namespace torch
119