1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/collective_util.h" |
16 | |
17 | #include <memory> |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/common_runtime/device.h" |
21 | #include "tensorflow/core/common_runtime/device_mgr.h" |
22 | #include "tensorflow/core/framework/collective.h" |
23 | #include "tensorflow/core/framework/device_attributes.pb.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/strings/strcat.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace collective_util { |
30 | |
31 | /*static*/ |
32 | Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, |
33 | const string& device_name, Device** device, |
34 | DeviceLocality* device_locality) { |
35 | if (!dev_mgr) { |
36 | return errors::Internal("Required non-null dev_mgr " , dev_mgr, |
37 | " for InitializeDeviceAndLocality" ); |
38 | } |
39 | |
40 | Status status = dev_mgr->LookupDevice(device_name, device); |
41 | if (status.ok()) { |
42 | CHECK(*device); |
43 | *device_locality = (*device)->attributes().locality(); |
44 | } else { |
45 | LOG(ERROR) << "Failed to find device " << device_name; |
46 | for (auto d : dev_mgr->ListDevices()) { |
47 | LOG(ERROR) << "Available devices " << d->name(); |
48 | } |
49 | } |
50 | return status; |
51 | } |
52 | |
53 | /*static*/ |
54 | string SubdivPermDebugString(const CollectiveParams& col_params) { |
55 | const auto& subdiv_perms = |
56 | col_params.instance.impl_details.subdiv_permutations; |
57 | string buf; |
58 | for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) { |
59 | strings::StrAppend(&buf, "Subdiv " , sdi, " device order:\n" ); |
60 | for (int di = 0; di < subdiv_perms[sdi].size(); ++di) { |
61 | int idx = subdiv_perms[sdi][di]; |
62 | if (idx >= 0) { |
63 | CHECK_GT(col_params.group.members.size(), idx); |
64 | strings::StrAppend(&buf, col_params.group.members[idx].device.name(), |
65 | "\n" ); |
66 | } |
67 | } |
68 | strings::StrAppend(&buf, " subdiv_offsets: " ); |
69 | for (auto o : col_params.instance.impl_details.subdiv_offsets) |
70 | strings::StrAppend(&buf, o, " " ); |
71 | strings::StrAppend(&buf, " SubdivRank: " ); |
72 | for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " " ); |
73 | if (col_params.instance.type == BROADCAST_COLLECTIVE) { |
74 | strings::StrAppend(&buf, " subdiv_source_rank: " ); |
75 | for (auto src : col_params.instance.impl_details.subdiv_source_rank) |
76 | strings::StrAppend(&buf, src, " " ); |
77 | } |
78 | strings::StrAppend(&buf, "\n" ); |
79 | } |
80 | return buf; |
81 | } |
82 | |
// Builds `sub_ctx_`, an OpKernelContext used to run the kernel `op` as a
// sub-operation of the op whose context is `ctx`.
//
// The sub-context's parameters start as a copy of `*params`. The following
// fields are then overridden:
//   - the kernel, which becomes `op`;
//   - the inputs, which become {output, input}, in that order;
//   - the input allocator attributes;
//   - the op device context, which is taken from `ctx`;
//   - the Eigen GPU device;
//   - the forward-from array.
//
// Several fields refer into this object. `sub_params_.forward_from_array` holds
// the address of the member `forward_from_`, and `sub_ctx_` is built from
// `&sub_params_`. `inputs` / `input_alloc_attrs` are assigned from the members
// `sub_inputs_` / `sub_input_attr_`, possibly as non-owning views. So this
// object should not be copied or moved while `sub_ctx_` is in use.
SubContext::SubContext(OpKernelContext* ctx, OpKernelContext::Params* params,
                       OpKernel* op, Tensor* output, Tensor* input)
    : sub_params_(*params),
      sub_inputs_({TensorValue(output), TensorValue(input)}),
      // Both operands use the allocator attributes of input 0 of `ctx`.
      // NOTE(review): presumably because the enclosing collective op has a
      // single input, so input_alloc_attr(1) would not exist — confirm.
      sub_input_attr_({ctx->input_alloc_attr(0), ctx->input_alloc_attr(0)}) {
  sub_params_.op_kernel = op;
  sub_params_.inputs = sub_inputs_;
  sub_params_.input_alloc_attrs = sub_input_attr_;
  sub_params_.op_device_context = ctx->op_device_context();
  // Clear the Eigen GPU device pointer copied from `*params`, then call
  // ensure_eigen_gpu_device(). NOTE(review): presumably this makes the
  // sub-context set up its own device instead of sharing the parent's —
  // confirm the ownership rules in OpKernelContext::Params.
  sub_params_.eigen_gpu_device = nullptr;
  sub_params_.ensure_eigen_gpu_device();
  sub_params_.forward_from_array = &forward_from_;
  // NOTE(review): the second argument is presumably the number of outputs of
  // `op` (one for a binary op) — confirm against the OpKernelContext ctor.
  sub_ctx_.reset(new OpKernelContext(&sub_params_, 1));
}
97 | |
98 | Status ComputeBinOp(OpKernelContext* op_ctx, OpKernelContext::Params* params, |
99 | Device* device, OpKernel* op, Tensor* output, |
100 | Tensor* input) { |
101 | // Prepare an OpKernelContext that is identical to that of the original Op |
102 | // (i.e. the collective), except for the input output sizes and identities and |
103 | // the Op itself. |
104 | // TODO(ayushd, tucker): Is it possible to cache and reuse these objects? |
105 | // They're mostly identical inside one device execution. |
106 | std::unique_ptr<SubContext> sub_ctx( |
107 | new SubContext(op_ctx, params, op, output, input)); |
108 | device->Compute(op, sub_ctx->sub_ctx_.get()); |
109 | return sub_ctx->sub_ctx_->status(); |
110 | } |
111 | |
112 | } // namespace collective_util |
113 | } // namespace tensorflow |
114 | |