1#include <ATen/ThreadLocalState.h>
2
3#include <torch/csrc/distributed/c10d/Work.hpp>
4#include <utility>
5
6namespace c10d {
7
8Work::Work(
9 int rank,
10 OpType opType,
11 const char* profilingTitle,
12 const c10::optional<std::vector<at::Tensor>>& inputTensors)
13 : rank_(rank), opType_(opType) {
14 if (profilingTitle != nullptr) {
15 auto recordingFunction =
16 std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
17 if (recordingFunction->isActive()) {
18 // Work events follow a future like pattern and can potentially be marked
19 // as complete by different threads, so explicitly set as async event.
20 recordingFunction->_setAsync();
21 // Passing input tensor to recordFunction allows for shape information in
22 // profiling output.
23 std::vector<c10::IValue> inputs;
24 if (inputTensors) {
25 inputs.reserve(inputTensors->size());
26 for (const auto& tensor : *inputTensors) {
27 inputs.emplace_back(tensor);
28 }
29 }
30 recordingFunction->before(
31 profilingTitle,
32 c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
33 std::function<void()> end_handler = [recordingFunction]() {
34 recordingFunction->end();
35 };
36 recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
37 }
38 }
39}
40
41OpType Work::retrieveOpType() {
42 return opType_;
43}
44
45Work::~Work() = default;
46
47bool Work::isCompleted() {
48 std::lock_guard<std::mutex> lock(mutex_);
49 return completed_;
50}
51
52bool Work::isSuccess() const {
53 std::lock_guard<std::mutex> lock(mutex_);
54 return !exception_;
55}
56
57std::exception_ptr Work::exception() const {
58 std::lock_guard<std::mutex> lock(mutex_);
59 return exception_;
60}
61
62int Work::sourceRank() const {
63 TORCH_CHECK(
64 false,
65 "sourceRank() may only be called on work objects "
66 "that correspond to a recv or recv-from-any call.");
67}
68
69std::vector<at::Tensor> Work::result() {
70 TORCH_CHECK(false, "result() not implemented.");
71}
72
73void Work::synchronize() {}
74
75bool Work::wait(std::chrono::milliseconds timeout) {
76 std::unique_lock<std::mutex> lock(mutex_);
77 if (timeout == kNoTimeout) {
78 // This waits without a timeout.
79 cv_.wait(lock, [&] { return completed_; });
80 } else {
81 // Waits for the user-provided timeout.
82 cv_.wait_for(lock, timeout, [&] { return completed_; });
83 if (!completed_) {
84 // Throw exception if the wait operation timed out and the work was not
85 // completed.
86 TORCH_CHECK(false, "Operation timed out!");
87 }
88 }
89 if (exception_) {
90 std::rethrow_exception(exception_);
91 }
92 synchronize();
93 // Always return true, because abort API is not implemented.
94 return true;
95}
96
97void Work::abort() {
98 TORCH_CHECK(false, "Work::abort not implemented.");
99}
100
101c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
102 TORCH_CHECK(false, "Work::getFuture not implemented.")
103}
104
105void Work::finish(std::exception_ptr exception) {
106 std::unique_lock<std::mutex> lock(mutex_);
107 completed_ = true;
108 exception_ = exception;
109 if (recordFunctionEndCallback_) {
110 recordFunctionEndCallback_();
111 recordFunctionEndCallback_ = nullptr;
112 }
113 lock.unlock();
114 cv_.notify_all();
115}
116
117void Work::finishAndThrow(std::exception_ptr exception) {
118 std::unique_lock<std::mutex> lock(mutex_);
119 completed_ = true;
120 exception_ = exception;
121 if (recordFunctionEndCallback_) {
122 recordFunctionEndCallback_();
123 recordFunctionEndCallback_ = nullptr;
124 }
125 if (exception_) {
126 std::rethrow_exception(exception_);
127 }
128}
129
130class FutureWrappingWork : public Work {
131 public:
132 FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
133 : Work(), _fut(std::move(fut)) {}
134
135 ~FutureWrappingWork() override = default;
136
137 bool isCompleted() override {
138 return _fut->completed();
139 }
140
141 bool isSuccess() const override {
142 return _fut->hasValue();
143 }
144
145 std::exception_ptr exception() const override {
146 return _fut->exception_ptr();
147 }
148
149 int sourceRank() const override {
150 TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
151 }
152
153 std::vector<at::Tensor> result() override {
154 return _fut->value().toPyObjectHolder()->extractTensors();
155 }
156
157 bool wait(std::chrono::milliseconds timeout) override {
158 // FIXME
159 TORCH_CHECK(
160 timeout == kNoTimeout,
161 "FutureWrappingWork::wait() with finite timeout not implemented");
162 _fut->wait();
163 return true;
164 }
165
166 void abort() override {
167 TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
168 }
169
170 c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
171 return _fut;
172 }
173
174 private:
175 c10::intrusive_ptr<c10::ivalue::Future> _fut;
176};
177
178c10::intrusive_ptr<Work> Work::create_from_future(
179 c10::intrusive_ptr<c10::ivalue::Future> future) {
180 return c10::make_intrusive<FutureWrappingWork>(future);
181}
182
183} // namespace c10d
184