1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/permuter.h" |
16 | |
17 | #include "tensorflow/core/common_runtime/collective_rma_local.h" |
18 | #include "tensorflow/core/common_runtime/collective_util.h" |
19 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
20 | #include "tensorflow/core/common_runtime/device.h" |
21 | #include "tensorflow/core/common_runtime/device_mgr.h" |
22 | #include "tensorflow/core/common_runtime/dma_helper.h" |
23 | #include "tensorflow/core/common_runtime/process_util.h" |
24 | #include "tensorflow/core/framework/allocator.h" |
25 | #include "tensorflow/core/framework/device_base.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/lib/core/notification.h" |
31 | #include "tensorflow/core/lib/core/status.h" |
32 | #include "tensorflow/core/lib/strings/str_util.h" |
33 | #include "tensorflow/core/lib/strings/strcat.h" |
34 | #include "tensorflow/core/platform/env.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | |
37 | namespace tensorflow { |
38 | |
39 | Permuter::Permuter() |
40 | : col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {} |
41 | |
42 | StatusCallback Permuter::CheckCounterAndCallDone() { |
43 | return [this](const Status& s) { |
44 | mu_.lock(); |
45 | status_.Update(s); |
46 | int counter = ++counter_; |
47 | Status status = status_; |
48 | mu_.unlock(); |
49 | if (counter == 2) done_(status); |
50 | }; |
51 | } |
52 | |
53 | Status Permuter::InitializeCollectiveContext( |
54 | std::shared_ptr<CollectiveContext> col_ctx) { |
55 | DCHECK(col_ctx->dev_mgr); |
56 | col_ctx_ = col_ctx; |
57 | col_params_ = col_ctx->col_params.get(); |
58 | return collective_util::InitializeDeviceAndLocality( |
59 | col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, |
60 | &col_ctx->device_locality); |
61 | } |
62 | |
63 | void Permuter::Run(StatusCallback done) { |
64 | if (col_params_->instance.permutation.size() != |
65 | col_params_->instance.devices.size()) { |
66 | done(errors::Internal("Permutation must be the same size as devices" )); |
67 | } |
68 | done_ = std::move(done); |
69 | DispatchSend(col_params_->default_rank, |
70 | col_params_->instance.permutation[col_params_->default_rank], |
71 | col_ctx_->input, CheckCounterAndCallDone()); |
72 | for (int i = 0; i < col_params_->instance.permutation.size(); ++i) { |
73 | if (col_params_->default_rank == col_params_->instance.permutation[i]) { |
74 | DispatchRecv(i, col_params_->instance.permutation[i], col_ctx_->output, |
75 | CheckCounterAndCallDone()); |
76 | } |
77 | } |
78 | } |
79 | |
80 | void Permuter::DispatchSend(int src_rank, int target_rank, const Tensor* tensor, |
81 | const StatusCallback& done) { |
82 | string send_buf_key = |
83 | strings::StrCat(col_ctx_->exec_key, src_rank, target_rank); |
84 | VLOG(1) << "DispatchSend " << send_buf_key << " from_device " |
85 | << col_ctx_->device_name << " to_device " |
86 | << col_params_->instance.devices[target_rank] |
87 | << " target_rank=" << target_rank << " src_rank=" << src_rank; |
88 | col_ctx_->col_exec->remote_access()->PostToPeer( |
89 | col_params_->instance.devices[target_rank], |
90 | col_params_->group.members[target_rank].task, send_buf_key, |
91 | col_ctx_->device, col_ctx_->op_ctx->op_device_context(), |
92 | col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, |
93 | col_ctx_->op_ctx->cancellation_manager(), done); |
94 | } |
95 | |
96 | void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor, |
97 | const StatusCallback& done) { |
98 | string recv_buf_key = |
99 | strings::StrCat(col_ctx_->exec_key, src_rank, target_rank); |
100 | VLOG(1) << "DispatchRecv " << recv_buf_key << " to_device " |
101 | << col_ctx_->device_name << " from_device " |
102 | << col_params_->instance.devices[src_rank] |
103 | << " target_rank=" << target_rank << " src_rank=" << src_rank; |
104 | col_ctx_->col_exec->remote_access()->RecvFromPeer( |
105 | col_params_->instance.devices[src_rank], |
106 | col_params_->group.members[src_rank].task, |
107 | col_params_->group.members[src_rank].is_local, recv_buf_key, |
108 | col_ctx_->device, col_ctx_->op_ctx->op_device_context(), |
109 | col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, |
110 | 0, col_ctx_->op_ctx->cancellation_manager(), done); |
111 | } |
namespace {
// Invokes the REGISTER_COLLECTIVE macro with the name `Permute` and the
// Permuter class. NOTE(review): presumably this makes Permuter discoverable
// under that name in the collective registry — confirm against the macro's
// definition.
REGISTER_COLLECTIVE(Permute, Permuter);
}  // namespace
115 | |
116 | } // namespace tensorflow |
117 | |