/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

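// Builds the per-step key "<group_key>:<instance_key>:<frame_id>:<iter_id>"
// that identifies one execution of a collective instance; the implementation
// uses it when generating RecvBuf keys.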
static string CollectiveKey(OpKernelContext* ctx, int32_t group_key,
                            int32_t instance_key) {
  return strings::StrCat(group_key, ":", instance_key, ":",
                         ctx->frame_iter().frame_id, ":",
                         ctx->frame_iter().iter_id);
}

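// Instantiates the OpKernel named `name` (e.g. "Add" or "Maximum") on the same
// device as the collective op so it can serve as a merge_op or final_op.
// Returns nullptr when `name` is empty or "Id", i.e. when no sub-kernel is
// needed.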
static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
                                               const string& name,
                                               NodeDef* sub_node) {
  std::unique_ptr<OpKernel> k;
  if (name.empty() || name == "Id") return k;
  sub_node->set_name(name);
  sub_node->set_op(name);
  Status status;
  k = CreateOpKernel(c->device_type(), c->device(),
                     c->device()->GetAllocator(AllocatorAttributes()),
                     *sub_node, c->graph_def_version(), &status);
  if (!status.ok()) {
    c->CtxFailureWithWarning(errors::Internal(
        "Failed to build OpKernel for ", name, " : ", status.error_message()));
  }
  return k;
}

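// Base class for the V1 collective kernels (CollectiveGather, CollectiveReduce,
// CollectiveBcastSend/Recv). It registers a cancellation callback that aborts
// the CollectiveExecutor, completes col_params_ on the first invocation, and
// then delegates to ComputeAsyncImpl().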
class CollectiveOpV1Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}

  ~CollectiveOpV1Kernel() override { col_params_->Unref(); }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    const CancellationToken token =
        c->cancellation_manager()->get_cancellation_token();
    const bool already_cancelled =
        !c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
          // We must call StartAbort() within the callback. StartAbort() relies
          // on resources that may be deallocated if all execution of a graph is
          // finished.
          col_exec->StartAbort(errors::Cancelled("op cancelled"));
        });
    OP_REQUIRES_ASYNC(c, !already_cancelled,
                      errors::Cancelled("op cancelled ", name_), done);

    auto deregister_and_done = [c, token, done = std::move(done)]() {
      // Once done() is called, StartAbort() won't have any effect, so we
      // don't need to block on the deregistration. Also StartAbort() may call
      // done() and DeregisterCallback may deadlock.
      c->cancellation_manager()->TryDeregisterCallback(token);
      done();
    };
    ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
  }

  // A string encoding group, instance, frame and iter to be handed off to
  // the implementation for use in generating RecvBuf keys.
  string GetCollectiveKey(OpKernelContext* c) {
    return CollectiveKey(c, col_params_->group.group_key,
                         col_params_->instance.instance_key);
  }

  // Returns false if calling invocation of ComputeAsync should return
  // immediately.
  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
                             const DoneCallback& done) {
    if (col_params_->group.group_size > col_params_->group.members.size()) {
      // This is the first invocation: Finish initializing col_params_.
      // Schedule the `CompleteParamsAsync` call on a work queue that can handle
      // blocking work because it's not guaranteed that this call cannot block.
      c->collective_executor()->RunClosure([this, c, col_exec, done]() {
        VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
                << col_params_->name << " device " << c->device()->name()
                << " group " << col_params_->group.group_key << " instance "
                << col_params_->instance.instance_key;
        col_exec->CompleteParamsAsync(
            c->device()->attributes(), col_params_, c->cancellation_manager(),
            [this, c, done](const Status& s) {
              if (s.ok()) {
                col_params_->instance.impl_details.dependencies = dependencies_;
                ComputeAsync(c, done);
              } else {
                c->SetStatus(s);
                done();
              }
            });
      });
      return false;
    }
    return true;
  }

 protected:
  virtual void ComputeAsyncImpl(OpKernelContext* c,
                                CollectiveExecutor* col_exec,
                                DoneCallback done) = 0;

  string name_;
  CollectiveParams* col_params_;
  std::vector<int32> dependencies_;
};

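// V1 all-gather kernel: every group member contributes its input and receives
// the inputs of all members concatenated along dimension 0.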
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = GATHER_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    const NodeDef& real_node = c->def();
    col_params_->name = strings::StrCat(real_node.name(), ": Gather");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    auto output_shape = c->input(0).shape();
    OP_REQUIRES_ASYNC(c, output_shape.dims() > 0,
                      errors::InvalidArgument("input should have rank > 0, ",
                                              "received ", output_shape.dims()),
                      done);
    output_shape.set_dim(
        0, output_shape.dim_size(0) * col_params_->group.group_size);
    col_params_->instance.shape = output_shape;

    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_->instance.shape, &output), done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveGatherOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
                        CollectiveGatherOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
                        CollectiveGatherOpKernel);

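// V1 all-reduce kernel: combines the inputs of all group members with
// `merge_op` (e.g. Add, Maximum) and applies `final_op` ("Id" or "Div") to the
// result on every member.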
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = REDUCTION_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("subdiv_offsets",
                      &col_params_->instance.impl_details.subdiv_offsets));
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }
    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
                errors::InvalidArgument(
                    "final_op must be one of {\"Id\", \"Div\"} but got ",
                    final_op_name));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    VLOG(2) << "CollectiveReduce instance "
            << col_params_->instance.instance_key << " merge_op "
            << merge_op_name << " final_op " << final_op_name
            << " communication_hint "
            << col_params_->instance.impl_details.communication_hint
            << " timeout "
            << col_params_->instance.impl_details.timeout_seconds;

    const NodeDef& real_node = c->def();
    col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
                                        merge_op_name, ",", final_op_name, ")");
    col_params_->group.device_type = c->device_type();

    // Find the OpKernels by name, type and device type.
    NodeDef sub_node;
    // The merge_op takes two inputs
    sub_node.add_input(real_node.input(0));
    sub_node.add_input(real_node.input(0));
    sub_node.set_device(real_node.device());
    SetAttrValue(col_params_->instance.data_type,
                 &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    col_params_->merge_op = merge_op_.get();
    col_params_->final_op = final_op_.get();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor, trying to reuse the input.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, c->input(0).shape(), &output),
                           done);
      col_params_->instance.shape = c->input(0).shape();
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
                        CollectiveReduceOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
                        CollectiveReduceOpKernel);

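// V1 broadcast-send kernel: this member is the source and sends its input,
// which must match the declared `shape` attribute, to the rest of the group.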
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    col_params_->is_source = true;
    col_params_->instance.impl_details.subdiv_offsets = {0};

    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor, trying to reuse the input.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, col_params_->instance.shape, &output),
                           done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;
    OP_REQUIRES_ASYNC(
        c, col_params_->instance.shape.IsSameSize(c->input(0).shape()),
        errors::Internal("Declared shape of op ", col_params_->name,
                         " does not match shape of input"),
        done);

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
                        CollectiveBcastSendOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_DEFAULT),
                        CollectiveBcastSendOpKernel);

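// V1 broadcast-receive kernel: this member is not the source; it allocates an
// output of the declared `shape` and receives the broadcast value into it.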
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    col_params_->is_source = false;
    col_params_->instance.impl_details.subdiv_offsets = {0};

    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // No input, so must allocate output.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_->instance.shape, &output), done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance_key "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
                        CollectiveBcastRecvOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_DEFAULT),
                        CollectiveBcastRecvOpKernel);

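// Computes (group_size, group_key) for this device from a
// [num_groups, group_size] group_assignment matrix: the row containing
// device_index determines the group, and its row index offset by base_key
// becomes the group key (group_key 0 is reserved). For example,
// group_assignment = [[0, 1], [2, 3]] with base_key = 1000 and
// device_index = 2 yields group_size = 2 and group_key = 1001.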
class CollectiveAssignGroupV2OpKernel : public OpKernel {
 public:
  explicit CollectiveAssignGroupV2OpKernel(OpKernelConstruction* c)
      : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& group_assignment = context->input(0);
    const Tensor& device_index = context->input(1);
    const Tensor& base_key = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(device_index.shape()),
        errors::InvalidArgument(
            "device_index must be a scalar, but received tensor of shape: ",
            device_index.shape().DebugString()));

    OP_REQUIRES(
        context, TensorShapeUtils::IsMatrix(group_assignment.shape()),
        errors::InvalidArgument("group_assignment must be a 2-d Tensor, but "
                                "received tensor of shape: ",
                                group_assignment.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(base_key.shape()),
                errors::InvalidArgument(
                    "base_key must be a scalar, but received tensor of shape: ",
                    base_key.shape().DebugString()));

    Tensor* group_key = nullptr;
    Tensor* group_size = nullptr;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                     &group_size, attr));

    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
                                                     &group_key, attr));

    OP_REQUIRES_OK(
        context,
        ComputeGroupKey(group_assignment, device_index.scalar<int32_t>()(),
                        base_key.scalar<int32_t>()(), group_size, group_key));
  }

 private:
  static Status ComputeGroupKey(const Tensor& group_assignment,
                                const int32_t device_index,
                                const int32_t base_key, Tensor* group_size,
                                Tensor* group_key) {
    group_size->flat<int32_t>()(0) = group_assignment.dim_size(1);

    for (int group_id = 0; group_id < group_assignment.dim_size(0);
         group_id++) {
      int32_t key = static_cast<int32_t>(static_cast<uint32_t>(base_key) +
                                         static_cast<uint32_t>(group_id));
      if (key == 0) {
        return errors::InvalidArgument(
            "Using the reserved group_key = 0 is not allowed: group_id = ",
            group_id, ", base_key = ", base_key);
      }
      for (int color = 0; color < group_assignment.dim_size(1); color++) {
        const auto index = group_assignment.matrix<int32>()(group_id, color);
        if (index < 0 || index >= group_assignment.shape().num_elements()) {
          return errors::InvalidArgument("Not all items in group_assignment ",
                                         group_assignment.DebugString(),
                                         " are within [0, number of devices)");
        }
        if (index == device_index) {
          group_key->flat<int32_t>()(0) = key;
          VLOG(2) << " group_assignment = " << group_assignment.DebugString()
                  << " device_index = " << index
                  << " group_key = " << group_key->DebugString()
                  << " group_size = " << group_size->DebugString();
          return OkStatus();
        }
      }
    }
    return errors::InvalidArgument("device_index ", device_index,
                                   " is not found in group_assignment ",
                                   group_assignment.DebugString());
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2").Device(DEVICE_CPU),
                        CollectiveAssignGroupV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("device_index")
                            .HostMemory("group_assignment")
                            .HostMemory("base_key")
                            .HostMemory("group_size")
                            .HostMemory("group_key"),
                        CollectiveAssignGroupV2OpKernel);

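// Base class for the V2 collective kernels, which take group_size, group_key
// and instance_key as runtime tensors instead of attributes.
// FillCollectiveParams() validates those inputs and fills the common fields;
// Run() completes the params and executes the collective.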
class CollectiveOpV2Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV2Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    device_type_ = c->device_type();
  }

 protected:
  // Fills common parts of CollectiveParams according to the Op, *excluding
  // output_shape*. Kernels should further work on the CollectiveParams if they
  // need to set additional fields.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              CollectiveType collective_type,
                              const Tensor& group_size, const Tensor& group_key,
                              const Tensor& instance_key) {
    if (group_size.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_size, got ",
          group_size.shape().DebugString());
    }
    if (group_key.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_key, got ",
          group_key.shape().DebugString());
    }
    if (instance_key.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input instance_key, got ",
          instance_key.shape().DebugString());
    }
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
    if (col_params->group.group_size <= 0) {
      return errors::InvalidArgument(
          "group_size must be positive integer but got ",
          col_params->group.group_size);
    }
    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
    col_params->instance.type = collective_type;
    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint = communication_hint_;
    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
    return OkStatus();
  }

  // Runs a collective. The output tensor must be allocated before calling this
  // method. col_params must live until done is called.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
                                          col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, col_params,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "Collective ExecuteAsync done for "
                        << col_params->name << " device " << c->device()->name()
                        << " group " << col_params->group.group_key
                        << " instance " << col_params->instance.instance_key
                        << " status " << s;
                if (!s.ok()) {
                  c->SetStatus(s);
                }
                done();
              };
              VLOG(1) << "Collective ExecuteAsync start for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 protected:
  string name_;
  DataType data_type_ = DT_INVALID;
  string communication_hint_;
  float timeout_seconds_ = 0;
  DeviceType device_type_;
};

class CollectiveReduceV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }
    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES_OK(
        c, c->GetAttr("max_subdivs_per_device", &max_subdivs_per_device_));
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
                            final_op_name, ")");
    VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
            << " communication_hint " << communication_hint_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, REDUCTION_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  int max_subdivs_per_device_;
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
                        CollectiveReduceV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveReduceV2OpKernel);

class CollectiveGatherV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    name_ = strings::StrCat(c->def().name(), ": GatherV2");
    VLOG(2) << "CollectiveGatherV2 " << this << " name " << name_
            << " communication_hint " << communication_hint_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, GATHER_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/
                                              c->input(3)),
                         done_with_cleanup);
    auto output_shape = c->input(0).shape();
    output_shape.set_dim(
        0, output_shape.dim_size(0) * col_params->group.group_size);
    col_params->instance.shape = output_shape;
    VLOG(1) << "CollectiveGatherV2 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params->instance.shape, &output),
        done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
                        CollectiveGatherV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveGatherV2OpKernel);

class CollectiveBcastSendV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    const bool is_source = true;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    col_params->is_source = true;
    col_params->instance.shape = c->input(0).shape();
    // Add a default value for subdiv offsets, which is the same as the default
    // value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastSendV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
                        CollectiveBcastSendV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveBcastSendV2OpKernel);

class CollectiveBcastRecvV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    const bool is_source = false;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
                                              /*group_size*/ c->input(0),
                                              /*group_key*/ c->input(1),
                                              /*instance_key*/ c->input(2)),
                         done_with_cleanup);
    col_params->is_source = false;
    TensorShape output_shape;
    OP_REQUIRES_OK_ASYNC(c, tensor::MakeShape(c->input(3), &output_shape),
                         done_with_cleanup);
    col_params->instance.shape = output_shape;
    // Add a default value for subdiv offsets, which is the same as the default
    // value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastRecvV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params->instance.shape, &output),
        done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
                        CollectiveBcastRecvV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key")
                            .HostMemory("shape"),
                        CollectiveBcastRecvV2OpKernel);

/*
 * Resource for holding the group for CollectiveOps.
 * This resource is returned from CollectiveInitializeCommunicatorOpKernel.
 * It generates the next instance key for the group for each collective
 * operation.
 */
class CollectiveGroupResource : public ResourceBase {
 public:
  CollectiveGroupResource(int32 group_key, int32 rank, int32 group_size,
                          string communication_hint, float timeout_seconds)
      : group_key_(group_key),
        rank_(rank),
        group_size_(group_size),
        communication_hint_(communication_hint),
        timeout_seconds_(timeout_seconds) {}

  std::string DebugString() const override {
    return absl::StrFormat(
        "Collective Group with group_key = %d, group_size = %d, rank = %d",
        group_key_, group_size_, rank_);
  }

  int get_next_instance_key() {
    return instance_key_.fetch_add(1, std::memory_order_relaxed);
  }

  int32 group_key() const { return group_key_; }

  int32 rank() const { return rank_; }

  int32 group_size() const { return group_size_; }

  string communication_hint() const { return communication_hint_; }

  float timeout_seconds() const { return timeout_seconds_; }

 private:
  int32 group_key_, rank_, group_size_;
  string communication_hint_;
  std::atomic<int> instance_key_{0};
  float timeout_seconds_ = 0;
};

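// Creates a CollectiveGroupResource for (group_key, rank, group_size), outputs
// a handle to it, and runs CompleteGroupAsync so the group is initialized
// before any V3 collective uses the communicator.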
class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel {
 public:
  explicit CollectiveInitializeCommunicatorOpKernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    device_type_ = c->device_type();
  }

  Status CheckInputs(Tensor group_size_t, Tensor group_key_t) {
    if (group_size_t.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_size. "
          "It should be a scalar, got tensor with shape ",
          group_size_t.shape().DebugString());
    }
    if (group_key_t.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_key, got ",
          group_key_t.shape().DebugString());
    }

    auto group_size = group_size_t.unaligned_flat<int32>()(0);
    if (group_size <= 0) {
      return errors::InvalidArgument(
          "group_size must be positive integer but got ", group_size);
    }
    return OkStatus();
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto group_key_t = c->input(0);
    auto rank_t = c->input(1);
    auto group_size_t = c->input(2);

    OP_REQUIRES_OK_ASYNC(c, CheckInputs(group_size_t, group_key_t), done);

    auto group_size = group_size_t.unaligned_flat<int32>()(0);
    auto group_key = group_key_t.unaligned_flat<int32>()(0);
    auto rank = rank_t.unaligned_flat<int32>()(0);

    ResourceHandle resource_handle =
        MakeResourceHandle<CollectiveGroupResource>(
            c, "collective_op_group",
            absl::StrFormat("%d:r%04d", group_key, rank));

    Tensor* output_handle = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, TensorShape({}), &output_handle), done);
    output_handle->scalar<ResourceHandle>()() = resource_handle;

    CollectiveGroupResource* resource = new CollectiveGroupResource(
        group_key, rank, group_size, this->communication_hint_,
        this->timeout_seconds_);
    OP_REQUIRES_OK_ASYNC(
        c,
        CreateResource<CollectiveGroupResource>(c, resource_handle, resource),
        done);
    auto group_params = new CollGroupParams();
    group_params->device_type = device_type_;
    group_params->group_size = resource->group_size();
    group_params->group_key = resource->group_key();
    group_params->user_specified_rank = resource->rank();

    auto* col_exec = c->collective_executor();

    c->collective_executor()->RunClosure([c, done = std::move(done),
                                          group_params, col_exec]() {
      VLOG(1) << "Collective Group initialization for device "
              << c->device()->name() << " group " << group_params->group_key;
      col_exec->CompleteGroupAsync(
          c->device()->attributes(), group_params, c->cancellation_manager(),
          [c, done = std::move(done), group_params](const Status& s) {
            if (s.ok()) {
              VLOG(1) << "Collective Group initialization done for device "
                      << c->device()->name() << " group "
                      << group_params->group_key << " status " << s;
            } else {
              c->SetStatus(s);
            }
            delete group_params;
            done();
          });
    });
  }

 private:
  string communication_hint_;
  DeviceType device_type_;
  float timeout_seconds_ = 0;
};

REGISTER_KERNEL_BUILDER(
    Name("CollectiveInitializeCommunicator").Device(DEVICE_CPU),
    CollectiveInitializeCommunicatorOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveInitializeCommunicator")
                            .Device(DEVICE_GPU)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("rank"),
                        CollectiveInitializeCommunicatorOpKernel);

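// Base class for the V3 collective kernels, which obtain their group from a
// CollectiveGroupResource handle produced by CollectiveInitializeCommunicator
// and derive per-op instance keys from that resource.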
class CollectiveOpV3Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV3Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    if (c->HasAttr("timeout_seconds")) {
      OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    } else {
      timeout_seconds_ = -1;
    }
    device_type_ = c->device_type();
  }

 protected:
  // Fills common parts of CollectiveParams according to the Op, *excluding
  // output_shape*. Kernels should further work on the CollectiveParams if they
  // need to set additional fields.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              const Tensor& group_assignment,
                              CollectiveType collective_type,
                              CollectiveGroupResource* resource) {
    int64 group_id;
    int64 group_size;
    if (group_assignment.NumElements() == 0) {
      // No group assignments, perform collective as a single group.
      group_id = 0;
      group_size = resource->group_size();
    } else {
      return errors::Unimplemented("Group assignments are not supported yet.");
    }

    // Construct instance key with format:
    // <11 bits for group><21 bits for atomic incremented instance key>
    int32 instance_key = group_id << 21 | resource->get_next_instance_key();
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size;
    col_params->group.group_key = resource->group_key();
    col_params->group.user_specified_rank = resource->rank();
    col_params->instance.type = collective_type;
    col_params->instance.instance_key = instance_key;
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint =
        resource->communication_hint();
    col_params->instance.impl_details.timeout_seconds =
        timeout_seconds_ > 0 ? resource->timeout_seconds() : timeout_seconds_;
    col_params->run_group_initialization = false;
    return OkStatus();
  }

  // Runs a collective. The output tensor must be allocated before calling this
  // method. col_params must live until done is called.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    col_exec->RunClosure([c, done = std::move(done), col_params, col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, col_params,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "Collective ExecuteAsync done for "
                        << col_params->name << " device " << c->device()->name()
                        << " group " << col_params->group.group_key
                        << " instance " << col_params->instance.instance_key
                        << " status " << s;
                if (!s.ok()) {
                  c->SetStatus(s);
                }
                done();
              };
              VLOG(1) << "Collective ExecuteAsync start for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 protected:
  string name_;
  DataType data_type_ = DT_INVALID;
  DeviceType device_type_;
  float timeout_seconds_ = 0;
};

class CollectiveReduceV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveReduceV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    string reduction;
    OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
    if (reduction == "Max") {
      reduction = "Maximum";
    } else if (reduction == "Min") {
      reduction = "Minimum";
    }
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, reduction, &sub_node);
    final_op_ = BuildOpKernel(c, "Id", &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV3(", reduction, ")");
    VLOG(2) << "CollectiveReduceV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);

    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment, REDUCTION_COLLECTIVE,
                             resource.get()),
        done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV3 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_CPU),
                        CollectiveReduceV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_GPU),
                        CollectiveReduceV3OpKernel);

class CollectiveAllToAllV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveAllToAllV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    name_ = strings::StrCat(c->def().name(), ": AllToAllV3");
    VLOG(2) << "CollectiveAllToAllV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);

    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment,
                             ALL_TO_ALL_COLLECTIVE, resource.get()),
        done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    VLOG(1) << "CollectiveAllToAll group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_CPU),
                        CollectiveAllToAllV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_GPU),
                        CollectiveAllToAllV3OpKernel);
}  // namespace
}  // namespace tensorflow