1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/collective.h"
16
17#include "absl/strings/escaping.h"
18#include "tensorflow/core/framework/op_kernel.h"
19#include "tensorflow/core/lib/core/errors.h"
20#include "tensorflow/core/lib/hash/hash.h"
21#include "tensorflow/core/lib/strings/str_util.h"
22#include "tensorflow/core/lib/strings/strcat.h"
23
24namespace tensorflow {
25
26namespace {
27// A RegistrationInfo object stores a collective implementation registration
28// details. `factory` is used to create instances of the collective
29// implementation.
30struct RegistrationInfo {
31 // This constructor also creates, and stores in `param_resolver_instance`,
32 // what is effectively a static instance of the collective implementation.
33 // During param resolution of collective ops we return this static instance.
34 // The actual op execution gets a fresh instance using `factory`.
35 RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36 : name(n),
37 factory(std::move(f)),
38 param_resolver_instance(this->factory()) {}
39 string name;
40 CollectiveRegistry::Factory factory;
41 CollectiveImplementationInterface* param_resolver_instance;
42};
43
44std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45 static std::vector<RegistrationInfo>* registry =
46 new std::vector<RegistrationInfo>;
47 return registry;
48}
49} // namespace
50
51string CollGroupRuntimeDetails::ToString() const {
52 return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53 absl::CEscape(communicator_key), "}");
54}
55
56string CollGroupParams::ToString() const {
57 string v = strings::StrCat(
58 "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59 " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60 " runtime_details=", runtime_details.ToString(), " devices {");
61 for (const auto& m : members) {
62 strings::StrAppend(&v, m.device.name(), ",");
63 }
64 strings::StrAppend(&v, "} num_devices_per_task={");
65 for (const auto& dpt : num_devices_per_task) {
66 strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
67 }
68 strings::StrAppend(&v, "}");
69 return v;
70}
71
72CollInstanceParams& CollInstanceParams::operator=(
73 const CollInstanceParams& other) {
74 if (this != &other) {
75 instance_key = other.instance_key;
76 type = other.type;
77 data_type = other.data_type;
78 shape = other.shape;
79 impl_details.subdiv_offsets.assign(
80 other.impl_details.subdiv_offsets.begin(),
81 other.impl_details.subdiv_offsets.end());
82 impl_details.subdiv_permutations.clear();
83 for (auto p : other.impl_details.subdiv_permutations) {
84 impl_details.subdiv_permutations.push_back(
85 std::vector<int>(p.begin(), p.end()));
86 }
87 impl_details.subdiv_source_rank.assign(
88 other.impl_details.subdiv_source_rank.begin(),
89 other.impl_details.subdiv_source_rank.end());
90 impl_details.dependencies = other.impl_details.dependencies;
91 devices.assign(other.devices.begin(), other.devices.end());
92 permutation.assign(other.permutation.begin(), other.permutation.end());
93 }
94 return *this;
95}
96
97string CollInstanceParams::ToString() const {
98 string v =
99 strings::StrCat("CollInstanceParams { instance_key=", instance_key,
100 " type=", type, " data_type=", DataTypeString(data_type),
101 " shape=", shape.DebugString(), " devices {");
102 strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
103 ", subdiv_offsets={");
104 strings::StrAppend(&v, "}, subdiv_offsets={");
105 for (const auto& d : impl_details.subdiv_offsets) {
106 strings::StrAppend(&v, d, ",");
107 }
108 strings::StrAppend(&v, "}, subdiv_perms={");
109 for (const auto& p : impl_details.subdiv_permutations) {
110 strings::StrAppend(&v, "{");
111 for (const auto& i : p) {
112 strings::StrAppend(&v, i, ",");
113 }
114 strings::StrAppend(&v, "}"); // one subdiv
115 }
116 if (!impl_details.subdiv_source_rank.empty()) {
117 strings::StrAppend(&v, " subdiv_source_rank={");
118 for (const auto& r : impl_details.subdiv_source_rank) {
119 strings::StrAppend(&v, r, ",");
120 }
121 strings::StrAppend(&v, "}");
122 } // all subdivs
123 if (type == PERMUTE_COLLECTIVE) {
124 strings::StrAppend(&v, "}, permute_devices {");
125 for (const auto& d : devices) {
126 strings::StrAppend(&v, d, ",");
127 }
128 strings::StrAppend(&v, "}, permute_permutation {");
129 for (const auto& p : permutation) {
130 strings::StrAppend(&v, p, ",");
131 }
132 strings::StrAppend(&v, "}");
133 }
134 return v;
135}
136
137string CollectiveParams::ToString() const {
138 string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
139 strings::StrAppend(&v, " ", instance.ToString());
140 strings::StrAppend(&v, " default_rank=", default_rank,
141 " is_source=", is_source, " source_rank=", source_rank,
142 " subdiv_rank={");
143 for (const auto& r : subdiv_rank) {
144 strings::StrAppend(&v, r, ",");
145 }
146 strings::StrAppend(&v, "}}");
147 return v;
148}
149
150/*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
151 OpKernelContext* ctx) {
152 return ctx->params_;
153}
154
// Bundles the inputs a collective implementation works with for one
// execution: the collective executor, the NCCL communicator, the device
// manager, the op's kernel context and its params, the collective params, the
// execution key, the step id, and the input/output tensors.
//
// In the initializers below, `col_params` names the constructor argument (it
// shadows the member of the same name), so `device_name` is computed from the
// caller-supplied params.
CollectiveContext::CollectiveContext(
    CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
    OpKernelContext::Params* op_params, const CollectiveParams* col_params,
    const string& exec_key, int64_t step_id, const Tensor* input,
    Tensor* output)
    : col_exec(col_exec),
      nccl_communicator(nccl_communicator),
      dev_mgr(dev_mgr),
      op_ctx(ctx),
      op_params(op_params),
      // Presumably takes an extra reference on `col_params` (add_ref=true);
      // confirm against the member's declared type.
      col_params(col_params, /*add_ref=*/true),
      exec_key(exec_key),
      step_id(step_id),
      input(input),
      output(output),
      // Left null here; presumably assigned later by the caller.
      device(nullptr),
      // Name of the device of the group member at `default_rank`.
      // NOTE(review): `default_rank` is used as an index without a bounds
      // check; it must be a valid index into `group.members`.
      device_name(
          col_params->group.members[col_params->default_rank].device.name()) {}
174
175/*static*/
176int64_t CollectiveExecutor::kInvalidId = -1;
177
178/*static*/
179Status CollectiveRegistry::Lookup(
180 const string& collective_name,
181 CollectiveImplementationInterface** implementation) {
182 return LookupHelper(collective_name, implementation, false);
183}
184
185/*static*/
186Status CollectiveRegistry::LookupParamResolverInstance(
187 const string& collective_name,
188 CollectiveImplementationInterface** implementation) {
189 return LookupHelper(collective_name, implementation, true);
190}
191
192/*static*/
193void CollectiveRegistry::GetAll(
194 std::vector<CollectiveImplementationInterface*>* implementations) {
195 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
196 for (const RegistrationInfo& reg_info : *registry)
197 implementations->emplace_back(reg_info.factory());
198}
199
200/*static*/
201Status CollectiveRegistry::Register(const string& collective_name,
202 Factory factory) {
203 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
204 for (const RegistrationInfo& reg_info : *registry) {
205 if (reg_info.name == collective_name)
206 return errors::Internal("Already registered collective ",
207 collective_name);
208 }
209 registry->emplace_back(collective_name, std::move(factory));
210 return OkStatus();
211}
212
213/*static*/
214Status CollectiveRegistry::LookupHelper(
215 const string& collective_name,
216 CollectiveImplementationInterface** implementation, bool param_resolver) {
217 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
218 for (const RegistrationInfo& reg_info : *registry) {
219 if (reg_info.name == collective_name) {
220 if (param_resolver) {
221 *implementation = reg_info.param_resolver_instance;
222 } else {
223 *implementation = reg_info.factory();
224 }
225 return OkStatus();
226 }
227 }
228 return errors::Internal(
229 "CollectiveRegistry::Lookup did not find collective implementation ",
230 collective_name);
231}
232
233} // namespace tensorflow
234