1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/nccl/collective_communicator.h"
17
18#include "tensorflow/core/framework/cancellation.h"
19
20#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
21
22#include "absl/memory/memory.h"
23#include "tensorflow/core/nccl/nccl_manager.h"
24#include "tensorflow/core/platform/tracing.h"
25#include "tensorflow/core/profiler/lib/traceme.h"
26
27namespace tensorflow {
28
29class NcclCommunicator : public NcclCommunicatorInterface {
30 public:
31 string GenerateCommunicatorKey() override {
32 return nccl_manager_.GenerateCommunicatorKey();
33 }
34
35 void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
36 StatusCallback done) override;
37
38 void StartAbort(const Status& s) override;
39
40 private:
41 NcclManager nccl_manager_;
42};
43
44namespace {
45Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
46 if (merge_op == "Add") {
47 *reduction_op = ncclSum;
48 return OkStatus();
49 } else if (merge_op == "Mul") {
50 *reduction_op = ncclProd;
51 return OkStatus();
52 } else if (merge_op == "Maximum") {
53 *reduction_op = ncclMax;
54 return OkStatus();
55 } else if (merge_op == "Minimum") {
56 *reduction_op = ncclMin;
57 return OkStatus();
58 } else {
59 return errors::Internal(
60 "Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
61 merge_op);
62 }
63}
64
65string NcclCollectiveKey(const string& exec_key, int step_id) {
66 return strings::StrCat(exec_key, ":", step_id);
67}
68} // namespace
69
70std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator(
71 const ConfigProto& config) {
72 // Skip creating a NcclCommunicator if there are 0 GPUs configured.
73 const auto& device_count = config.device_count();
74 auto item = device_count.find("GPU");
75 if (item != device_count.end() && item->second == 0) {
76 return nullptr;
77 }
78 return absl::make_unique<NcclCommunicator>();
79}
80
81void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
82 StatusCallback done) {
83 const CollectiveParams* col_params = col_ctx->col_params.get();
84 const int num_global_devices = col_params->group.group_size;
85 const int num_local_devices = col_params->group.num_devices_per_task.at(
86 col_params->group.members[col_params->default_rank].task);
87 const string nccl_collective_key =
88 NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
89 auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
90 auto* gpu_info =
91 col_ctx->op_ctx->device()->tensorflow_accelerator_device_info();
92 auto participant = absl::make_unique<NcclManager::Participant>(
93 compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
94 col_ctx->output, col_ctx->col_params->default_rank,
95 /*done_callback=*/nullptr);
96 CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
97 if (cancel_mgr == nullptr) {
98 participant->done_callback = std::move(done);
99 } else {
100 CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
101 bool already_cancelled =
102 !cancel_mgr->RegisterCallback(cancel_token, [this]() {
103 nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
104 nccl_manager_.Reset();
105 });
106 if (already_cancelled) {
107 done(errors::Cancelled("op cancelled"));
108 return;
109 }
110 participant->done_callback = [cancel_mgr, cancel_token,
111 done = std::move(done)](const Status& s) {
112 // Do not block on deregistration since this can be invoked by
113 // NcclManager::StartAbort() in the cancellation callback.
114 cancel_mgr->TryDeregisterCallback(cancel_token);
115 done(s);
116 };
117 }
118 NcclManager::Context context(
119 nccl_collective_key, num_local_devices, num_global_devices,
120 col_params->group.runtime_details.communicator_key,
121 col_params->source_rank);
122 VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
123 << " num_tasks " << col_params->group.num_tasks << " current task "
124 << col_params->group.members[col_params->default_rank].task
125 << " num local devices " << num_local_devices
126 << " num global devices " << num_global_devices << " device "
127 << col_ctx->device_name << " instance "
128 << col_params->instance.instance_key;
129 // `AddTo*` performs consistency checks for the NCCL call and enqueues the
130 // `Participant` struct locally. When all local participants with this
131 // `nccl_collective_key` have called `AddToAllReduce` and
132 // `SignalMultiNodeReady`, all devices at this worker are ready to process
133 // this NCCL op.
134 //
135 // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
136 // point, it synchronizes the NCCL stream with the compute stream, and then
137 // enqueues the NCCL kernel on the NCCL stream.
138 switch (col_params->instance.type) {
139 case REDUCTION_COLLECTIVE: {
140 ncclRedOp_t reduction_op;
141 Status s =
142 ReductionOp(col_params->merge_op->type_string(), &reduction_op);
143 if (!s.ok()) {
144 participant->done_callback(s);
145 return;
146 }
147 nccl_manager_.AddToAllReduce(std::move(participant), context,
148 reduction_op);
149 break;
150 }
151 case GATHER_COLLECTIVE: {
152 nccl_manager_.AddToAllGather(std::move(participant), context);
153 break;
154 }
155 case BROADCAST_COLLECTIVE: {
156 if (col_params->is_source) {
157 nccl_manager_.AddBroadcastSend(std::move(participant), context);
158 } else {
159 nccl_manager_.AddBroadcastRecv(std::move(participant), context);
160 }
161 break;
162 }
163 default: {
164 participant->done_callback(errors::Internal("Unexpected CollectiveType ",
165 col_params->instance.type));
166 return;
167 }
168 }
169 // NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
170 // deadlocks. In the current implementation, we define a deterministic
171 // sequential launch order between potentially concurrent collective instances
172 // by introducing control information during static graph analysis in
173 // graph/collective_order.cc. This can be either in the form of explicit
174 // control edges or via `wait_for` attribute on the collective op.
175 //
176 // The other end of the design spectrum would have a distinguished node
177 // dynamically signal the next collective to launch to all other participants.
178 // This has higher degree of runtime coordination, but it may be able to
179 // achieve better performance if the (arbitrary) static execution order
180 // assigned in the first approach turns out to not be good from a scheduling
181 // perspective. e.g. consider a graph in which c1, c2, and c3 are three
182 // concurrent collective instances, and the static ordering assigns c1 -> c2
183 // -> c3. In practice, it could turn out that c3 is always ready to execute
184 // before c1 or c2.
185 {
186 // `WaitForDependencies` may block if the collective instances on which this
187 // op depends have not yet launched. When this function returns, this op is
188 // ready to go.
189 profiler::TraceMe activity("WaitForDependencies",
190 profiler::TraceMeLevel::kInfo);
191 col_ctx->col_exec->WaitForDependencies(*col_params);
192 nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
193 }
194 {
195 // When all devices at this worker have called `SignalMultiNodeReady`, the
196 // `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
197 // implementation of `UnblockDependencies` keeps track of the number of
198 // devices that have launched.
199 profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
200 col_ctx->col_exec->UnblockDependencies(*col_params);
201 }
202}
203
204void NcclCommunicator::StartAbort(const Status& s) {
205 nccl_manager_.StartAbort(s);
206}
207
208} // namespace tensorflow
209
210#else
211namespace tensorflow {
212std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator(
213 const ConfigProto& config) {
214 return nullptr;
215}
216} // namespace tensorflow
217#endif // TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
218