1#pragma once
2
#include <ATen/ATen.h>

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
6
// Default `timeout` argument of Work::wait(). By its name, a zero-length
// duration acts as a "no timeout" sentinel — NOTE(review): confirm how the
// wait() implementations interpret it.
constexpr std::chrono::milliseconds kNoTimeout{0};
8
9namespace c10d {
10
11constexpr const char* const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
12
// Kind of operation a Work object refers to. Every enumerator carries an
// explicit value; the underlying type is a single byte.
enum class OpType : std::uint8_t {
  // Broadcast / reduction collectives.
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,

  // Gather-style collectives.
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,

  // Scatter / all-to-all collectives.
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,

  // Point-to-point operations.
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,

  // Synchronization.
  BARRIER = 15,

  // Reduce-scatter variant, numbered after BARRIER.
  _REDUCE_SCATTER_BASE = 16,

  // Unspecified operation; this is the default opType of the Work
  // constructor.
  UNKNOWN = 100,
};
33
34// Converts OpType to human readable string.
35TORCH_API std::string opTypeToString(OpType opType);
36
37// Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE)
38TORCH_API bool isP2POp(OpType opType, bool batchP2P = false);
39
40// Please do not use Work API, it is going away, to be
41// replaced by ivalue::Future.
42// Python binding for this class might change, please do not assume
43// this will be bound using pybind.
44class TORCH_API Work : public torch::CustomClassHolder {
45 public:
46 Work(
47 int rank = -1,
48 OpType opType = OpType::UNKNOWN,
49 const char* profilingTitle = nullptr,
50 const c10::optional<std::vector<at::Tensor>>& inputTensors =
51 c10::nullopt);
52
53 ~Work() override;
54
55 // Checks if request has completed. Non-blocking operation.
56 virtual bool isCompleted();
57
58 // Returns if the work completed successfully.
59 // If false, the exception function can be called to get details.
60 virtual bool isSuccess() const;
61
62 // Returns exception if isSuccess() returned false.
63 virtual std::exception_ptr exception() const;
64
65 // Returns source rank if this objects represents a recv-from-any.
66 virtual int sourceRank() const;
67
68 // Returns result tensors, if applicable.
69 // If work is not supposed to have result, we return empty list.
70 virtual std::vector<at::Tensor> result();
71
72 // Ensures that operations on the output tensors that are invoked
73 // after this function returns are correctly sequenced after the
74 // asynchronous completion of this work.
75 //
76 // For CUDA tensors, it inserts stream synchronization such that
77 // the streams of the caller wait for completion of the
78 // asynchronous operations on the destination tensors.
79 //
80 // For CPU tensors, it is currently a nop.
81 //
82 // This function should only be used if the caller polls for
83 // completion through the `isCompleted` function, it has returned
84 // true, and the `isSuccess` function also has returned true.
85 //
86 virtual void synchronize();
87
88 // Waits until request completes. Blocking operation.
89 // Throws if the work completed with an exception.
90 // Returns false if the work is aborted.
91 // Otherwise, it always returns true, indicating the work is completed.
92 //
93 // Functionally equivalent to:
94 //
95 // while (!isCompleted()) { /* nop */ }
96 // auto success = isSuccess();
97 // if (!success) { std::rethrow_exception(exception()); }
98 // return success;
99 //
100 virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);
101
102 virtual void abort();
103
104 // Returns a Future object that will be associated with the completion of
105 // work. Only NCCL backend is currently supported.
106 virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
107
108 OpType retrieveOpType();
109
110 static c10::intrusive_ptr<Work> create_from_future(
111 c10::intrusive_ptr<c10::ivalue::Future>);
112
113 protected:
114 // Completes the work object and optionally sets the exception in a
115 // thread-safe manner. Notifies all waiting condition variables as well.
116 void finish(std::exception_ptr exception = nullptr);
117
118 // Similar to finish, but throws an exception if one is already set or
119 // provided by the user.
120 void finishAndThrow(std::exception_ptr exception);
121
122 mutable std::mutex mutex_;
123 std::condition_variable cv_;
124 bool completed_ = false;
125 std::exception_ptr exception_;
126
127 // Current rank of the node.
128 const int rank_;
129
130 // Operation type that this work object refers to.
131 OpType opType_;
132
133 // When profiling, the callback to record end of operation event. This
134 // callback needs to be called when collective operation is complete.
135 std::function<void()> recordFunctionEndCallback_;
136};
137
138} // namespace c10d
139