1 | #include <ATen/ThreadLocalState.h> |
2 | |
3 | #include <torch/csrc/distributed/c10d/Work.hpp> |
4 | #include <utility> |
5 | |
6 | namespace c10d { |
7 | |
// Constructs a Work handle for a collective/point-to-point operation and,
// when a profiling title is supplied and the profiler is active, opens an
// async RecordFunction event so the operation appears in profiler traces.
//
// rank:           rank of this process (stored in rank_).
// opType:         the operation type this Work tracks (stored in opType_).
// profilingTitle: profiler event name; nullptr skips all profiling setup.
// inputTensors:   optional input tensors, forwarded to the profiler so shape
//                 information shows up in the profiling output.
Work::Work(
    int rank,
    OpType opType,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors)
    : rank_(rank), opType_(opType) {
  if (profilingTitle != nullptr) {
    auto recordingFunction =
        std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
    if (recordingFunction->isActive()) {
      // Work events follow a future like pattern and can potentially be marked
      // as complete by different threads, so explicitly set as async event.
      recordingFunction->_setAsync();
      // Passing input tensor to recordFunction allows for shape information in
      // profiling output.
      std::vector<c10::IValue> inputs;
      if (inputTensors) {
        inputs.reserve(inputTensors->size());
        for (const auto& tensor : *inputTensors) {
          inputs.emplace_back(tensor);
        }
      }
      recordingFunction->before(
          profilingTitle,
          c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
      // The shared_ptr capture keeps the RecordFunction alive until the end
      // callback runs (possibly on a different thread, see finish()).
      // wrapPropagateTLSState carries the current thread-local state into
      // that callback so the profiler event is attributed correctly.
      std::function<void()> end_handler = [recordingFunction]() {
        recordingFunction->end();
      };
      recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
    }
  }
}
40 | |
// Returns the operation type recorded at construction time.
OpType Work::retrieveOpType() {
  return opType_;
}
44 | |
// Defaulted out-of-line destructor.
Work::~Work() = default;
46 | |
// Non-blocking check: returns true once finish()/finishAndThrow() has marked
// this work as completed (successfully or with an exception).
bool Work::isCompleted() {
  std::lock_guard<std::mutex> lock(mutex_);
  return completed_;
}
51 | |
// Returns true when no exception has been recorded for this work.
// NOTE: does not consult completed_, so a not-yet-finished work also reports
// success; callers that care should check isCompleted() first.
bool Work::isSuccess() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return !exception_;
}
56 | |
// Returns the recorded exception, or a null exception_ptr if none was set.
std::exception_ptr Work::exception() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return exception_;
}
61 | |
// Base-class stub: source rank is only meaningful for recv-style operations,
// so the default implementation always throws. Backends override as needed.
int Work::sourceRank() const {
  TORCH_CHECK(
      false,
      "sourceRank() may only be called on work objects "
      "that correspond to a recv or recv-from-any call." );
}
68 | |
// Base-class stub: retrieving output tensors is backend-specific; the default
// implementation always throws.
std::vector<at::Tensor> Work::result() {
  TORCH_CHECK(false, "result() not implemented." );
}
72 | |
// No-op in the base class; backends override to e.g. block the current
// stream on the operation's completion.
void Work::synchronize() {}
74 | |
// Blocks until the work completes (or the timeout expires), then rethrows any
// recorded exception and runs synchronize().
//
// timeout: kNoTimeout waits indefinitely; any other value throws via
//          TORCH_CHECK if the work has not completed within that duration.
// Returns: always true in the base implementation (abort is not supported, so
//          a wait can only end in completion or an exception).
bool Work::wait(std::chrono::milliseconds timeout) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (timeout == kNoTimeout) {
    // This waits without a timeout.
    cv_.wait(lock, [&] { return completed_; });
  } else {
    // Waits for the user-provided timeout.
    cv_.wait_for(lock, timeout, [&] { return completed_; });
    if (!completed_) {
      // Throw exception if the wait operation timed out and the work was not
      // completed.
      TORCH_CHECK(false, "Operation timed out!" );
    }
  }
  // Rethrow before synchronize(): a failed work must surface its error rather
  // than attempt stream synchronization.
  if (exception_) {
    std::rethrow_exception(exception_);
  }
  synchronize();
  // Always return true, because abort API is not implemented.
  return true;
}
96 | |
// Base-class stub: aborting in-flight work is backend-specific; the default
// implementation always throws.
void Work::abort() {
  TORCH_CHECK(false, "Work::abort not implemented." );
}
100 | |
101 | c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() { |
102 | TORCH_CHECK(false, "Work::getFuture not implemented." ) |
103 | } |
104 | |
105 | void Work::finish(std::exception_ptr exception) { |
106 | std::unique_lock<std::mutex> lock(mutex_); |
107 | completed_ = true; |
108 | exception_ = exception; |
109 | if (recordFunctionEndCallback_) { |
110 | recordFunctionEndCallback_(); |
111 | recordFunctionEndCallback_ = nullptr; |
112 | } |
113 | lock.unlock(); |
114 | cv_.notify_all(); |
115 | } |
116 | |
117 | void Work::finishAndThrow(std::exception_ptr exception) { |
118 | std::unique_lock<std::mutex> lock(mutex_); |
119 | completed_ = true; |
120 | exception_ = exception; |
121 | if (recordFunctionEndCallback_) { |
122 | recordFunctionEndCallback_(); |
123 | recordFunctionEndCallback_ = nullptr; |
124 | } |
125 | if (exception_) { |
126 | std::rethrow_exception(exception_); |
127 | } |
128 | } |
129 | |
130 | class FutureWrappingWork : public Work { |
131 | public: |
132 | FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut) |
133 | : Work(), _fut(std::move(fut)) {} |
134 | |
135 | ~FutureWrappingWork() override = default; |
136 | |
137 | bool isCompleted() override { |
138 | return _fut->completed(); |
139 | } |
140 | |
141 | bool isSuccess() const override { |
142 | return _fut->hasValue(); |
143 | } |
144 | |
145 | std::exception_ptr exception() const override { |
146 | return _fut->exception_ptr(); |
147 | } |
148 | |
149 | int sourceRank() const override { |
150 | TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented" ); |
151 | } |
152 | |
153 | std::vector<at::Tensor> result() override { |
154 | return _fut->value().toPyObjectHolder()->extractTensors(); |
155 | } |
156 | |
157 | bool wait(std::chrono::milliseconds timeout) override { |
158 | // FIXME |
159 | TORCH_CHECK( |
160 | timeout == kNoTimeout, |
161 | "FutureWrappingWork::wait() with finite timeout not implemented" ); |
162 | _fut->wait(); |
163 | return true; |
164 | } |
165 | |
166 | void abort() override { |
167 | TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented" ); |
168 | } |
169 | |
170 | c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { |
171 | return _fut; |
172 | } |
173 | |
174 | private: |
175 | c10::intrusive_ptr<c10::ivalue::Future> _fut; |
176 | }; |
177 | |
178 | c10::intrusive_ptr<Work> Work::create_from_future( |
179 | c10::intrusive_ptr<c10::ivalue::Future> future) { |
180 | return c10::make_intrusive<FutureWrappingWork>(future); |
181 | } |
182 | |
183 | } // namespace c10d |
184 | |