1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/collective_util.h"
16
17#include <memory>
18#include <vector>
19
20#include "tensorflow/core/common_runtime/device.h"
21#include "tensorflow/core/common_runtime/device_mgr.h"
22#include "tensorflow/core/framework/collective.h"
23#include "tensorflow/core/framework/device_attributes.pb.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/lib/strings/strcat.h"
26#include "tensorflow/core/platform/types.h"
27
28namespace tensorflow {
29namespace collective_util {
30
31/*static*/
32Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
33 const string& device_name, Device** device,
34 DeviceLocality* device_locality) {
35 if (!dev_mgr) {
36 return errors::Internal("Required non-null dev_mgr ", dev_mgr,
37 " for InitializeDeviceAndLocality");
38 }
39
40 Status status = dev_mgr->LookupDevice(device_name, device);
41 if (status.ok()) {
42 CHECK(*device);
43 *device_locality = (*device)->attributes().locality();
44 } else {
45 LOG(ERROR) << "Failed to find device " << device_name;
46 for (auto d : dev_mgr->ListDevices()) {
47 LOG(ERROR) << "Available devices " << d->name();
48 }
49 }
50 return status;
51}
52
53/*static*/
54string SubdivPermDebugString(const CollectiveParams& col_params) {
55 const auto& subdiv_perms =
56 col_params.instance.impl_details.subdiv_permutations;
57 string buf;
58 for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) {
59 strings::StrAppend(&buf, "Subdiv ", sdi, " device order:\n");
60 for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
61 int idx = subdiv_perms[sdi][di];
62 if (idx >= 0) {
63 CHECK_GT(col_params.group.members.size(), idx);
64 strings::StrAppend(&buf, col_params.group.members[idx].device.name(),
65 "\n");
66 }
67 }
68 strings::StrAppend(&buf, " subdiv_offsets: ");
69 for (auto o : col_params.instance.impl_details.subdiv_offsets)
70 strings::StrAppend(&buf, o, " ");
71 strings::StrAppend(&buf, " SubdivRank: ");
72 for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " ");
73 if (col_params.instance.type == BROADCAST_COLLECTIVE) {
74 strings::StrAppend(&buf, " subdiv_source_rank: ");
75 for (auto src : col_params.instance.impl_details.subdiv_source_rank)
76 strings::StrAppend(&buf, src, " ");
77 }
78 strings::StrAppend(&buf, "\n");
79 }
80 return buf;
81}
82
// Builds a nested OpKernelContext used to run `op` on behalf of the kernel
// that owns `ctx`.
//
// The new context's parameters start as a copy of `*params` and are then
// overridden so that:
//   * the kernel is `op`;
//   * the inputs are {`output`, `input`}, in that order (`output` is wired in
//     as input 0);
//   * both inputs use the allocator attributes of input 0 of `ctx`;
//   * the device context is the one `ctx` is using.
//
// NOTE(review): the statement order below is deliberate-looking (the Eigen GPU
// device is cleared before ensure_eigen_gpu_device() is called, and the
// context is constructed last, after all overrides) -- do not reorder without
// checking OpKernelContext::Params.
SubContext::SubContext(OpKernelContext* ctx, OpKernelContext::Params* params,
                       OpKernel* op, Tensor* output, Tensor* input)
    : sub_params_(*params),
      sub_inputs_({TensorValue(output), TensorValue(input)}),
      sub_input_attr_({ctx->input_alloc_attr(0), ctx->input_alloc_attr(0)}) {
  sub_params_.op_kernel = op;
  // Point the copied params at this object's own input/attr storage.
  sub_params_.inputs = sub_inputs_;
  sub_params_.input_alloc_attrs = sub_input_attr_;
  sub_params_.op_device_context = ctx->op_device_context();
  // Drop the Eigen GPU device inherited from `*params` before asking for one;
  // presumably this forces a fresh device for the sub-context -- TODO confirm
  // against ensure_eigen_gpu_device().
  sub_params_.eigen_gpu_device = nullptr;
  sub_params_.ensure_eigen_gpu_device();
  sub_params_.forward_from_array = &forward_from_;
  // The second argument is presumably the number of outputs (1) -- verify
  // against the OpKernelContext constructor.
  sub_ctx_.reset(new OpKernelContext(&sub_params_, 1));
}
97
98Status ComputeBinOp(OpKernelContext* op_ctx, OpKernelContext::Params* params,
99 Device* device, OpKernel* op, Tensor* output,
100 Tensor* input) {
101 // Prepare an OpKernelContext that is identical to that of the original Op
102 // (i.e. the collective), except for the input output sizes and identities and
103 // the Op itself.
104 // TODO(ayushd, tucker): Is it possible to cache and reuse these objects?
105 // They're mostly identical inside one device execution.
106 std::unique_ptr<SubContext> sub_ctx(
107 new SubContext(op_ctx, params, op, output, input));
108 device->Compute(op, sub_ctx->sub_ctx_.get());
109 return sub_ctx->sub_ctx_->status();
110}
111
112} // namespace collective_util
113} // namespace tensorflow
114