/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

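// Returns a key for RecvBuf exchanges that is unique per group, instance,
// frame, and iteration; e.g. group_key 1 and instance_key 5 in the root
// frame would produce "1:5:0:0" (the frame and iteration ids shown are
// illustrative).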
static string CollectiveKey(OpKernelContext* ctx, int32_t group_key,
                            int32_t instance_key) {
  return strings::StrCat(group_key, ":", instance_key, ":",
                         ctx->frame_iter().frame_id, ":",
                         ctx->frame_iter().iter_id);
}

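// Builds the OpKernel named `name` (e.g. "Maximum") on the same device as
// the enclosing collective op, reusing `sub_node` as a scratch NodeDef.
// Returns nullptr for an empty name or for "Id", which needs no kernel.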
static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
                                               const string& name,
                                               NodeDef* sub_node) {
  std::unique_ptr<OpKernel> k;
  if (name.empty() || name == "Id") return k;
  sub_node->set_name(name);
  sub_node->set_op(name);
  Status status;
  k = CreateOpKernel(c->device_type(), c->device(),
                     c->device()->GetAllocator(AllocatorAttributes()),
                     *sub_node, c->graph_def_version(), &status);
  if (!status.ok()) {
    c->CtxFailureWithWarning(errors::Internal(
        "Failed to build OpKernel for ", name, " : ", status.error_message()));
  }
  return k;
}

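// Base class for the V1 collective kernels (reduce, gather, broadcast),
// which read their group and instance configuration from attributes. It
// registers a cancellation callback that aborts the collective executor,
// and completes CollectiveParams on the first invocation before
// re-entering ComputeAsync.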
class CollectiveOpV1Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}

  ~CollectiveOpV1Kernel() override { col_params_->Unref(); }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    const CancellationToken token =
        c->cancellation_manager()->get_cancellation_token();
    const bool already_cancelled =
        !c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
          // We must call StartAbort() within the callback. StartAbort()
          // relies on resources that may be deallocated once all execution
          // of a graph is finished.
          col_exec->StartAbort(errors::Cancelled("op cancelled"));
        });
    OP_REQUIRES_ASYNC(c, !already_cancelled,
                      errors::Cancelled("op cancelled ", name_), done);

    auto deregister_and_done = [c, token, done = std::move(done)]() {
      // Once done() is called, StartAbort() no longer has any effect, so we
      // don't need to block on the deregistration. Moreover, StartAbort()
      // may itself call done(), in which case a blocking DeregisterCallback
      // could deadlock.
      c->cancellation_manager()->TryDeregisterCallback(token);
      done();
    };
    ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
  }

  // A string encoding instance, frame and iter to be handed off to
  // the implementation for use in generating RecvBuf keys.
  string GetCollectiveKey(OpKernelContext* c) {
    return CollectiveKey(c, col_params_->group.group_key,
                         col_params_->instance.instance_key);
  }

  // Returns false if the calling invocation of ComputeAsync should return
  // immediately.
  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
                             const DoneCallback& done) {
    if (col_params_->group.group_size > col_params_->group.members.size()) {
      // This is the first invocation: finish initializing col_params_.
      // Schedule the `CompleteParamsAsync` call on a work queue that can
      // handle blocking work, because the call may block.
      c->collective_executor()->RunClosure([this, c, col_exec, done]() {
        VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
                << col_params_->name << " device " << c->device()->name()
                << " group " << col_params_->group.group_key << " instance "
                << col_params_->instance.instance_key;
        col_exec->CompleteParamsAsync(
            c->device()->attributes(), col_params_, c->cancellation_manager(),
            [this, c, done](const Status& s) {
              if (s.ok()) {
                col_params_->instance.impl_details.dependencies = dependencies_;
                ComputeAsync(c, done);
              } else {
                c->SetStatus(s);
                done();
              }
            });
      });
      return false;
    }
    return true;
  }

 protected:
  virtual void ComputeAsyncImpl(OpKernelContext* c,
                                CollectiveExecutor* col_exec,
                                DoneCallback done) = 0;

  string name_;
  CollectiveParams* col_params_;
  std::vector<int32> dependencies_;
};

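// V1 all-gather kernel: every group member contributes its input tensor,
// and each member receives the inputs of all members concatenated along
// dimension 0.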
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = GATHER_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument(
            "group_size must be a positive integer but got ",
            col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    const NodeDef& real_node = c->def();
    col_params_->name = strings::StrCat(real_node.name(), ": Gather");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    auto output_shape = c->input(0).shape();
    OP_REQUIRES_ASYNC(c, output_shape.dims() > 0,
                      errors::InvalidArgument("input should have rank > 0, ",
                                              "received ", output_shape.dims()),
                      done);
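    // Gather concatenates one input-sized slice from each of the group_size
    // members, so the output's leading dimension is group_size times the
    // input's.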
    output_shape.set_dim(
        0, output_shape.dim_size(0) * col_params_->group.group_size);
    col_params_->instance.shape = output_shape;

    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_->instance.shape, &output),
          done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveGatherOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
                        CollectiveGatherOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
                        CollectiveGatherOpKernel);

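// V1 all-reduce kernel. Illustrative attribute set (hypothetical values):
// group_size = 2, group_key = 1, instance_key = 5, merge_op = "Add",
// final_op = "Div", which (dividing by group_size) yields the group mean.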
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = REDUCTION_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument(
            "group_size must be a positive integer but got ",
            col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("subdiv_offsets",
                      &col_params_->instance.impl_details.subdiv_offsets));
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }
    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
                errors::InvalidArgument(
                    "final_op must be one of {\"Id\", \"Div\"} but got ",
                    final_op_name));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    VLOG(2) << "CollectiveReduce instance "
            << col_params_->instance.instance_key << " merge_op "
            << merge_op_name << " final_op " << final_op_name
            << " communication_hint "
            << col_params_->instance.impl_details.communication_hint
            << " timeout "
            << col_params_->instance.impl_details.timeout_seconds;

    const NodeDef& real_node = c->def();
    col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
                                        merge_op_name, ",", final_op_name, ")");
    col_params_->group.device_type = c->device_type();

    // Find the OpKernels by name, type and device type.
    NodeDef sub_node;
    // The merge_op takes two inputs.
    sub_node.add_input(real_node.input(0));
    sub_node.add_input(real_node.input(0));
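    // Both inputs alias the collective's own input 0; they exist only to
    // satisfy the two-operand arity of the merge op when building its kernel.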
    sub_node.set_device(real_node.device());
    SetAttrValue(col_params_->instance.data_type,
                 &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    col_params_->merge_op = merge_op_.get();
    col_params_->final_op = final_op_.get();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor, trying to reuse the input.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, c->input(0).shape(), &output),
                           done);
      col_params_->instance.shape = c->input(0).shape();
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
                        CollectiveReduceOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
                        CollectiveReduceOpKernel);

class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument(
            "group_size must be a positive integer but got ",
            col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    col_params_->is_source = true;
    col_params_->instance.impl_details.subdiv_offsets = {0};

    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor, trying to reuse the input.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, col_params_->instance.shape, &output),
                           done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;
    OP_REQUIRES_ASYNC(
        c, col_params_->instance.shape.IsSameSize(c->input(0).shape()),
        errors::Internal("Declared shape of op ", col_params_->name,
                         " does not match shape of input"),
        done);

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
                        CollectiveBcastSendOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_DEFAULT),
                        CollectiveBcastSendOpKernel);

class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    col_params_->instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
    OP_REQUIRES(
        c, col_params_->group.group_size > 0,
        errors::InvalidArgument(
            "group_size must be a positive integer but got ",
            col_params_->group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_->instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_->instance.impl_details.timeout_seconds));
    col_params_->is_source = false;
    col_params_->instance.impl_details.subdiv_offsets = {0};

    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    col_params_->group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // No input, so must allocate output.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_->instance.shape, &output),
          done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
      VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance_key "
              << col_params->instance.instance_key << " status " << s;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
                        CollectiveBcastRecvOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_DEFAULT),
                        CollectiveBcastRecvOpKernel);

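// Computes (group_size, group_key) for the calling device from a
// [num_groups, group_size] group_assignment matrix whose entry (g, i) is
// the global index of member i of group g. Illustrative example with
// hypothetical values: group_assignment = [[0, 1], [2, 3]] and
// base_key = 1000 assigns devices 0 and 1 group_key 1000 and devices 2 and
// 3 group_key 1001, each with group_size 2.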
class CollectiveAssignGroupV2OpKernel : public OpKernel {
 public:
  explicit CollectiveAssignGroupV2OpKernel(OpKernelConstruction* c)
      : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& group_assignment = context->input(0);
    const Tensor& device_index = context->input(1);
    const Tensor& base_key = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(device_index.shape()),
        errors::InvalidArgument(
            "device_index must be a scalar, but received tensor of shape: ",
            device_index.shape().DebugString()));

    OP_REQUIRES(
        context, TensorShapeUtils::IsMatrix(group_assignment.shape()),
        errors::InvalidArgument("group_assignment must be a 2-d Tensor, but "
                                "received tensor of shape: ",
                                group_assignment.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(base_key.shape()),
                errors::InvalidArgument(
                    "base_key must be a scalar, but received tensor of shape: ",
                    base_key.shape().DebugString()));

    Tensor* group_key = nullptr;
    Tensor* group_size = nullptr;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                     &group_size, attr));

    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
                                                     &group_key, attr));

    OP_REQUIRES_OK(
        context,
        ComputeGroupKey(group_assignment, device_index.scalar<int32_t>()(),
                        base_key.scalar<int32_t>()(), group_size, group_key));
  }

 private:
  static Status ComputeGroupKey(const Tensor& group_assignment,
                                const int32_t device_index,
                                const int32_t base_key, Tensor* group_size,
                                Tensor* group_key) {
    group_size->flat<int32_t>()(0) = group_assignment.dim_size(1);

    for (int group_id = 0; group_id < group_assignment.dim_size(0);
         group_id++) {
      int32_t key = static_cast<int32_t>(static_cast<uint32_t>(base_key) +
                                         static_cast<uint32_t>(group_id));
      if (key == 0) {
        return errors::InvalidArgument(
            "Using the reserved group_key = 0 is not allowed: group_id = ",
            group_id, ", base_key = ", base_key);
      }
      for (int color = 0; color < group_assignment.dim_size(1); color++) {
        const auto index = group_assignment.matrix<int32>()(group_id, color);
        if (index < 0 || index >= group_assignment.shape().num_elements()) {
          return errors::InvalidArgument("Not all items in group_assignment ",
                                         group_assignment.DebugString(),
                                         " are within [0, number of devices)");
        }
        if (index == device_index) {
          group_key->flat<int32_t>()(0) = key;
          VLOG(2) << " group_assignment = " << group_assignment.DebugString()
                  << " device_index = " << index
                  << " group_key = " << group_key->DebugString()
                  << " group_size = " << group_size->DebugString();
          return OkStatus();
        }
      }
    }
    return errors::InvalidArgument("device_index ", device_index,
                                   " is not found in group_assignment ",
                                   group_assignment.DebugString());
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2").Device(DEVICE_CPU),
                        CollectiveAssignGroupV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("device_index")
                            .HostMemory("group_assignment")
                            .HostMemory("base_key")
                            .HostMemory("group_size")
                            .HostMemory("group_key"),
                        CollectiveAssignGroupV2OpKernel);

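// Base class for the V2 collective kernels, which take group_size,
// group_key, and instance_key as runtime scalar inputs rather than as
// attributes.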
class CollectiveOpV2Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV2Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    device_type_ = c->device_type();
  }

 protected:
  // Fills common parts of CollectiveParams according to the Op, *excluding
  // output_shape*. Kernels should further work on the CollectiveParams if
  // they need to set additional fields.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              CollectiveType collective_type,
                              const Tensor& group_size, const Tensor& group_key,
                              const Tensor& instance_key) {
    if (group_size.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_size, got ",
          group_size.shape().DebugString());
    }
    if (group_key.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_key, got ",
          group_key.shape().DebugString());
    }
    if (instance_key.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input instance_key, got ",
          instance_key.shape().DebugString());
    }
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
    if (col_params->group.group_size <= 0) {
      return errors::InvalidArgument(
          "group_size must be a positive integer but got ",
          col_params->group.group_size);
    }
    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
    col_params->instance.type = collective_type;
    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint = communication_hint_;
    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
    return OkStatus();
  }

  // Runs a collective. The output tensor must be allocated before calling
  // this method. col_params must live until done is called.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can
    // handle blocking work, because the call may block.
    c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
                                          col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, col_params,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "Collective ExecuteAsync done for "
                        << col_params->name << " device " << c->device()->name()
                        << " group " << col_params->group.group_key
                        << " instance " << col_params->instance.instance_key
                        << " status " << s;
                if (!s.ok()) {
                  c->SetStatus(s);
                }
                done();
              };
              VLOG(1) << "Collective ExecuteAsync start for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 protected:
  string name_;
  DataType data_type_ = DT_INVALID;
  string communication_hint_;
  float timeout_seconds_ = 0;
  DeviceType device_type_;
};

class CollectiveReduceV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }
    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES_OK(
        c, c->GetAttr("max_subdivs_per_device", &max_subdivs_per_device_));
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs.
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
                            final_op_name, ")");
    VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
            << " communication_hint " << communication_hint_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, REDUCTION_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  int max_subdivs_per_device_;
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
                        CollectiveReduceV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveReduceV2OpKernel);

class CollectiveGatherV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    name_ = strings::StrCat(c->def().name(), ": GatherV2");
    VLOG(2) << "CollectiveGatherV2 " << this << " name " << name_
            << " communication_hint " << communication_hint_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, GATHER_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    auto output_shape = c->input(0).shape();
    output_shape.set_dim(
        0, output_shape.dim_size(0) * col_params->group.group_size);
    col_params->instance.shape = output_shape;
    VLOG(1) << "CollectiveGatherV2 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params->instance.shape, &output),
        done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
                        CollectiveGatherV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveGatherV2OpKernel);

class CollectiveBcastSendV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    const bool is_source = true;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    col_params->is_source = true;
    col_params->instance.shape = c->input(0).shape();
    // Add a default value for subdiv offsets, which is the same as the
    // default value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastSendV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
                        CollectiveBcastSendV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveBcastSendV2OpKernel);

class CollectiveBcastRecvV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    const bool is_source = false;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
                                              /*group_size*/ c->input(0),
                                              /*group_key*/ c->input(1),
                                              /*instance_key*/ c->input(2)),
                         done_with_cleanup);
    col_params->is_source = false;
    TensorShape output_shape;
    OP_REQUIRES_OK_ASYNC(c, tensor::MakeShape(c->input(3), &output_shape),
                         done_with_cleanup);
    col_params->instance.shape = output_shape;
    // Add a default value for subdiv offsets, which is the same as the
    // default value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastRecvV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params->instance.shape, &output),
        done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
                        CollectiveBcastRecvV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key")
                            .HostMemory("shape"),
                        CollectiveBcastRecvV2OpKernel);

/*
 * Resource for holding the group configuration for collective ops.
 * This resource is returned from CollectiveInitializeCommunicatorOpKernel.
 * It generates the next instance key for the group for each collective
 * operation.
 */
class CollectiveGroupResource : public ResourceBase {
 public:
  CollectiveGroupResource(int32 group_key, int32 rank, int32 group_size,
                          string communication_hint, float timeout_seconds)
      : group_key_(group_key),
        rank_(rank),
        group_size_(group_size),
        communication_hint_(communication_hint),
        timeout_seconds_(timeout_seconds) {}

  std::string DebugString() const override {
    return absl::StrFormat(
        "Collective Group with group_key = %d, group_size = %d, rank = %d",
        group_key_, group_size_, rank_);
  }

  int get_next_instance_key() {
    return instance_key_.fetch_add(1, std::memory_order_relaxed);
  }

  int32 group_key() const { return group_key_; }

  int32 rank() const { return rank_; }

  int32 group_size() const { return group_size_; }

  string communication_hint() const { return communication_hint_; }

  float timeout_seconds() const { return timeout_seconds_; }

 private:
  int32 group_key_, rank_, group_size_;
  string communication_hint_;
  std::atomic<int> instance_key_{0};
  float timeout_seconds_ = 0;
};

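// Creates a CollectiveGroupResource from the runtime inputs group_key,
// rank, and group_size, publishes its handle as the op's output, and kicks
// off group resolution via CompleteGroupAsync. The handle is later consumed
// by the V3 collective kernels below.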
class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel {
 public:
  explicit CollectiveInitializeCommunicatorOpKernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    device_type_ = c->device_type();
  }

  Status CheckInputs(Tensor group_size_t, Tensor group_key_t) {
    if (group_size_t.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_size. "
          "It should be a scalar, but got a tensor with shape ",
          group_size_t.shape().DebugString());
    }
    if (group_key_t.dims() > 0) {
      return errors::InvalidArgument(
          "Unexpected dimensions on input group_key, got ",
          group_key_t.shape().DebugString());
    }

    auto group_size = group_size_t.unaligned_flat<int32>()(0);
    if (group_size <= 0) {
      return errors::InvalidArgument(
          "group_size must be a positive integer but got ", group_size);
    }
    return OkStatus();
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto group_key_t = c->input(0);
    auto rank_t = c->input(1);
    auto group_size_t = c->input(2);

    OP_REQUIRES_OK_ASYNC(c, CheckInputs(group_size_t, group_key_t), done);

    auto group_size = group_size_t.unaligned_flat<int32>()(0);
    auto group_key = group_key_t.unaligned_flat<int32>()(0);
    auto rank = rank_t.unaligned_flat<int32>()(0);

    ResourceHandle resource_handle =
        MakeResourceHandle<CollectiveGroupResource>(
            c, "collective_op_group",
            absl::StrFormat("%d:r%04d", group_key, rank));

    Tensor* output_handle = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, TensorShape({}), &output_handle), done);
    output_handle->scalar<ResourceHandle>()() = resource_handle;

    CollectiveGroupResource* resource = new CollectiveGroupResource(
        group_key, rank, group_size, this->communication_hint_,
        this->timeout_seconds_);
    OP_REQUIRES_OK_ASYNC(
        c,
        CreateResource<CollectiveGroupResource>(c, resource_handle, resource),
        done);
    auto group_params = new CollGroupParams();
    group_params->device_type = device_type_;
    group_params->group_size = resource->group_size();
    group_params->group_key = resource->group_key();
    group_params->user_specified_rank = resource->rank();

    auto* col_exec = c->collective_executor();

    c->collective_executor()->RunClosure([c, done = std::move(done),
                                          group_params, col_exec]() {
      VLOG(1) << "Collective Group initialization for device "
              << c->device()->name() << " group " << group_params->group_key;
      col_exec->CompleteGroupAsync(
          c->device()->attributes(), group_params, c->cancellation_manager(),
          [c, done = std::move(done), group_params](const Status& s) {
            if (s.ok()) {
              VLOG(1) << "Collective Group initialization done for device "
                      << c->device()->name() << " group "
                      << group_params->group_key << " status " << s;
            } else {
              c->SetStatus(s);
            }
            delete group_params;
            done();
          });
    });
  }

 private:
  string communication_hint_;
  DeviceType device_type_;
  float timeout_seconds_ = 0;
};

REGISTER_KERNEL_BUILDER(
    Name("CollectiveInitializeCommunicator").Device(DEVICE_CPU),
    CollectiveInitializeCommunicatorOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveInitializeCommunicator")
                            .Device(DEVICE_GPU)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("rank"),
                        CollectiveInitializeCommunicatorOpKernel);

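// Base class for the V3 collective kernels, which obtain their group
// configuration from a CollectiveGroupResource handle produced by
// CollectiveInitializeCommunicator instead of attrs or scalar inputs.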
class CollectiveOpV3Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV3Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    if (c->HasAttr("timeout_seconds")) {
      OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    } else {
      timeout_seconds_ = -1;
    }
    device_type_ = c->device_type();
  }

 protected:
  // Fills common parts of CollectiveParams according to the Op, *excluding
  // output_shape*. Kernels should further work on the CollectiveParams if
  // they need to set additional fields.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              const Tensor& group_assignment,
                              CollectiveType collective_type,
                              CollectiveGroupResource* resource) {
    int64 group_id;
    int64 group_size;
    if (group_assignment.NumElements() == 0) {
      // No group assignments, so perform the collective as a single group.
      group_id = 0;
      group_size = resource->group_size();
    } else {
      return errors::Unimplemented("Group assignments are not supported yet.");
    }

    // Construct the instance key with the format:
    // <11 bits for group><21 bits for atomically incremented instance key>
    int32 instance_key = group_id << 21 | resource->get_next_instance_key();
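    // E.g. the k-th collective issued on group_id g yields instance_key
    // (g << 21) + k, so instances from different groups cannot collide as
    // long as fewer than 2^21 instances are issued per group.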
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size;
    col_params->group.group_key = resource->group_key();
    col_params->group.user_specified_rank = resource->rank();
    col_params->instance.type = collective_type;
    col_params->instance.instance_key = instance_key;
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint =
        resource->communication_hint();
    col_params->instance.impl_details.timeout_seconds =
        timeout_seconds_ > 0 ? resource->timeout_seconds() : timeout_seconds_;
    col_params->run_group_initialization = false;
    return OkStatus();
  }

  // Runs a collective. The output tensor must be allocated before calling
  // this method. col_params must live until done is called.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can
    // handle blocking work, because the call may block.
    col_exec->RunClosure([c, done = std::move(done), col_params, col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, col_params,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "Collective ExecuteAsync done for "
                        << col_params->name << " device " << c->device()->name()
                        << " group " << col_params->group.group_key
                        << " instance " << col_params->instance.instance_key
                        << " status " << s;
                if (!s.ok()) {
                  c->SetStatus(s);
                }
                done();
              };
              VLOG(1) << "Collective ExecuteAsync start for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 protected:
  string name_;
  DataType data_type_ = DT_INVALID;
  DeviceType device_type_;
  float timeout_seconds_ = 0;
};

class CollectiveReduceV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveReduceV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    string reduction;
    OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
    if (reduction == "Max") {
      reduction = "Maximum";
    } else if (reduction == "Min") {
      reduction = "Minimum";
    }
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs.
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, reduction, &sub_node);
    final_op_ = BuildOpKernel(c, "Id", &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV3(", reduction, ")");
    VLOG(2) << "CollectiveReduceV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);

    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment, REDUCTION_COLLECTIVE,
                             resource.get()),
        done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV3 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_CPU),
                        CollectiveReduceV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_GPU),
                        CollectiveReduceV3OpKernel);

class CollectiveAllToAllV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveAllToAllV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    name_ = strings::StrCat(c->def().name(), ": AllToAllV3");
    VLOG(2) << "CollectiveAllToAllV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);

    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment,
                             ALL_TO_ALL_COLLECTIVE, resource.get()),
        done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    VLOG(1) << "CollectiveAllToAll group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};

REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_CPU),
                        CollectiveAllToAllV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_GPU),
                        CollectiveAllToAllV3OpKernel);
}  // namespace
}  // namespace tensorflow