/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/nccl/collective_communicator.h"

#include "tensorflow/core/framework/cancellation.h"

#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)

#include "absl/memory/memory.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

class NcclCommunicator : public NcclCommunicatorInterface {
 public:
  string GenerateCommunicatorKey() override {
    return nccl_manager_.GenerateCommunicatorKey();
  }

  void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
               StatusCallback done) override;

  void StartAbort(const Status& s) override;

 private:
  NcclManager nccl_manager_;
};
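
// A minimal usage sketch of the interface above (illustrative only: `config`,
// `col_ctx`, and `done` are assumed to come from the surrounding collective
// executor machinery, and the Aborted status is just a placeholder):
//
//   std::unique_ptr<NcclCommunicatorInterface> comm =
//       MaybeCreateNcclCommunicator(config);
//   if (comm != nullptr) {
//     comm->Enqueue(col_ctx, std::move(done));  // launch one collective op
//     // On teardown or error: comm->StartAbort(errors::Aborted("..."));
//   }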
43 | |
44 | namespace { |
45 | Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) { |
46 | if (merge_op == "Add" ) { |
47 | *reduction_op = ncclSum; |
48 | return OkStatus(); |
49 | } else if (merge_op == "Mul" ) { |
50 | *reduction_op = ncclProd; |
51 | return OkStatus(); |
52 | } else if (merge_op == "Maximum" ) { |
53 | *reduction_op = ncclMax; |
54 | return OkStatus(); |
55 | } else if (merge_op == "Minimum" ) { |
56 | *reduction_op = ncclMin; |
57 | return OkStatus(); |
58 | } else { |
59 | return errors::Internal( |
60 | "Expected merge_op to be in [Add, Mul, Maximum, Minimum], found " , |
61 | merge_op); |
62 | } |
63 | } |
64 | |
65 | string NcclCollectiveKey(const string& exec_key, int step_id) { |
66 | return strings::StrCat(exec_key, ":" , step_id); |
67 | } |
68 | } // namespace |
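
// Illustrative values for the helpers above (assumed, not taken from a real
// run): ReductionOp("Add", &op) sets `op` to ncclSum, and
// NcclCollectiveKey("CollectiveReduce_1", 7) returns "CollectiveReduce_1:7",
// which names a single collective launch within a single step.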
69 | |
70 | std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator( |
71 | const ConfigProto& config) { |
72 | // Skip creating a NcclCommunicator if there are 0 GPUs configured. |
73 | const auto& device_count = config.device_count(); |
74 | auto item = device_count.find("GPU" ); |
75 | if (item != device_count.end() && item->second == 0) { |
76 | return nullptr; |
77 | } |
78 | return absl::make_unique<NcclCommunicator>(); |
79 | } |
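
// A hedged example of the zero-GPU case guarded above; the ConfigProto
// device_count field is real, but the surrounding session setup is assumed:
//
//   ConfigProto config;
//   (*config.mutable_device_count())["GPU"] = 0;
//   auto comm = MaybeCreateNcclCommunicator(config);  // comm == nullptr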

void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                               StatusCallback done) {
  const CollectiveParams* col_params = col_ctx->col_params.get();
  const int num_global_devices = col_params->group.group_size;
  const int num_local_devices = col_params->group.num_devices_per_task.at(
      col_params->group.members[col_params->default_rank].task);
  const string nccl_collective_key =
      NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
  auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
  auto* gpu_info =
      col_ctx->op_ctx->device()->tensorflow_accelerator_device_info();
  auto participant = absl::make_unique<NcclManager::Participant>(
      compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
      col_ctx->output, col_ctx->col_params->default_rank,
      /*done_callback=*/nullptr);
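  // Wire cancellation through to the NCCL launch: if the op's
  // CancellationManager fires, abort all pending NCCL calls so that `done`
  // is still invoked rather than hanging.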
  CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
  if (cancel_mgr == nullptr) {
    participant->done_callback = std::move(done);
  } else {
    CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
    bool already_cancelled =
        !cancel_mgr->RegisterCallback(cancel_token, [this]() {
          nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
          nccl_manager_.Reset();
        });
    if (already_cancelled) {
      done(errors::Cancelled("op cancelled"));
      return;
    }
    participant->done_callback = [cancel_mgr, cancel_token,
                                  done = std::move(done)](const Status& s) {
      // Do not block on deregistration since this can be invoked by
      // NcclManager::StartAbort() in the cancellation callback.
      cancel_mgr->TryDeregisterCallback(cancel_token);
      done(s);
    };
  }
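  // Describe the shape of this collective to the NcclManager: its key, the
  // number of participating devices on this task and across all tasks, the
  // shared communicator key, and the source rank (used for broadcasts).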
  NcclManager::Context context(
      nccl_collective_key, num_local_devices, num_global_devices,
      col_params->group.runtime_details.communicator_key,
      col_params->source_rank);
  VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
          << " num_tasks " << col_params->group.num_tasks << " current task "
          << col_params->group.members[col_params->default_rank].task
          << " num local devices " << num_local_devices
          << " num global devices " << num_global_devices << " device "
          << col_ctx->device_name << " instance "
          << col_params->instance.instance_key;
  // `AddTo*` performs consistency checks for the NCCL call and enqueues the
  // `Participant` struct locally. When all local participants with this
  // `nccl_collective_key` have called `AddToAllReduce` and
  // `SignalMultiNodeReady`, all devices at this worker are ready to process
  // this NCCL op.
  //
  // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
  // point, it synchronizes the NCCL stream with the compute stream, and then
  // enqueues the NCCL kernel on the NCCL stream.
  switch (col_params->instance.type) {
    case REDUCTION_COLLECTIVE: {
      ncclRedOp_t reduction_op;
      Status s =
          ReductionOp(col_params->merge_op->type_string(), &reduction_op);
      if (!s.ok()) {
        participant->done_callback(s);
        return;
      }
      nccl_manager_.AddToAllReduce(std::move(participant), context,
                                   reduction_op);
      break;
    }
    case GATHER_COLLECTIVE: {
      nccl_manager_.AddToAllGather(std::move(participant), context);
      break;
    }
    case BROADCAST_COLLECTIVE: {
      if (col_params->is_source) {
        nccl_manager_.AddBroadcastSend(std::move(participant), context);
      } else {
        nccl_manager_.AddBroadcastRecv(std::move(participant), context);
      }
      break;
    }
    default: {
      participant->done_callback(errors::Internal(
          "Unexpected CollectiveType ", col_params->instance.type));
      return;
    }
  }
  // NOTE(ayushd): We need to synchronize NCCL launches across nodes to
  // prevent deadlocks. In the current implementation, we define a
  // deterministic sequential launch order between potentially concurrent
  // collective instances by introducing control information during static
  // graph analysis in graph/collective_order.cc. This can take the form of
  // either explicit control edges or a `wait_for` attribute on the
  // collective op.
  //
  // At the other end of the design spectrum, a distinguished node would
  // dynamically signal the next collective to launch to all other
  // participants. This requires a higher degree of runtime coordination, but
  // it may achieve better performance if the (arbitrary) static execution
  // order assigned in the first approach turns out to be poor from a
  // scheduling perspective. E.g., consider a graph in which c1, c2, and c3
  // are three concurrent collective instances, and the static ordering
  // assigns c1 -> c2 -> c3. In practice, it could turn out that c3 is always
  // ready to execute before c1 or c2.
  {
    // `WaitForDependencies` may block if the collective instances on which
    // this op depends have not yet launched. When this function returns,
    // this op is ready to go.
    profiler::TraceMe activity("WaitForDependencies",
                               profiler::TraceMeLevel::kInfo);
    col_ctx->col_exec->WaitForDependencies(*col_params);
    nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
  }
  {
    // When all devices at this worker have called `SignalMultiNodeReady`,
    // the `NcclManager` will enqueue the NCCL kernel on the NCCL stream.
    // Thus the implementation of `UnblockDependencies` keeps track of the
    // number of devices that have launched.
    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
    col_ctx->col_exec->UnblockDependencies(*col_params);
  }
}

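// Delegates the abort to the NcclManager, which fails pending and subsequent
// NCCL launches with `s` so that blocked participants finish with an error
// instead of hanging.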
void NcclCommunicator::StartAbort(const Status& s) {
  nccl_manager_.StartAbort(s);
}

}  // namespace tensorflow

#else
namespace tensorflow {
std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator(
    const ConfigProto& config) {
  return nullptr;
}
}  // namespace tensorflow
#endif  // TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)