1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/collective.h" |
16 | |
17 | #include "absl/strings/escaping.h" |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/lib/core/errors.h" |
20 | #include "tensorflow/core/lib/hash/hash.h" |
21 | #include "tensorflow/core/lib/strings/str_util.h" |
22 | #include "tensorflow/core/lib/strings/strcat.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace { |
27 | // A RegistrationInfo object stores a collective implementation registration |
28 | // details. `factory` is used to create instances of the collective |
29 | // implementation. |
30 | struct RegistrationInfo { |
31 | // This constructor also creates, and stores in `param_resolver_instance`, |
32 | // what is effectively a static instance of the collective implementation. |
33 | // During param resolution of collective ops we return this static instance. |
34 | // The actual op execution gets a fresh instance using `factory`. |
35 | RegistrationInfo(const string& n, CollectiveRegistry::Factory f) |
36 | : name(n), |
37 | factory(std::move(f)), |
38 | param_resolver_instance(this->factory()) {} |
39 | string name; |
40 | CollectiveRegistry::Factory factory; |
41 | CollectiveImplementationInterface* param_resolver_instance; |
42 | }; |
43 | |
44 | std::vector<RegistrationInfo>* MutableCollectiveRegistry() { |
45 | static std::vector<RegistrationInfo>* registry = |
46 | new std::vector<RegistrationInfo>; |
47 | return registry; |
48 | } |
49 | } // namespace |
50 | |
51 | string CollGroupRuntimeDetails::ToString() const { |
52 | return strings::StrCat("CollGroupRuntimeDetails {communicator_key=" , |
53 | absl::CEscape(communicator_key), "}" ); |
54 | } |
55 | |
56 | string CollGroupParams::ToString() const { |
57 | string v = strings::StrCat( |
58 | "CollGroupParams {group_key=" , group_key, " group_size=" , group_size, |
59 | " device_type=" , device_type.type_string(), " num_tasks=" , num_tasks, |
60 | " runtime_details=" , runtime_details.ToString(), " devices {" ); |
61 | for (const auto& m : members) { |
62 | strings::StrAppend(&v, m.device.name(), "," ); |
63 | } |
64 | strings::StrAppend(&v, "} num_devices_per_task={" ); |
65 | for (const auto& dpt : num_devices_per_task) { |
66 | strings::StrAppend(&v, dpt.first, ": " , dpt.second, ", " ); |
67 | } |
68 | strings::StrAppend(&v, "}" ); |
69 | return v; |
70 | } |
71 | |
72 | CollInstanceParams& CollInstanceParams::operator=( |
73 | const CollInstanceParams& other) { |
74 | if (this != &other) { |
75 | instance_key = other.instance_key; |
76 | type = other.type; |
77 | data_type = other.data_type; |
78 | shape = other.shape; |
79 | impl_details.subdiv_offsets.assign( |
80 | other.impl_details.subdiv_offsets.begin(), |
81 | other.impl_details.subdiv_offsets.end()); |
82 | impl_details.subdiv_permutations.clear(); |
83 | for (auto p : other.impl_details.subdiv_permutations) { |
84 | impl_details.subdiv_permutations.push_back( |
85 | std::vector<int>(p.begin(), p.end())); |
86 | } |
87 | impl_details.subdiv_source_rank.assign( |
88 | other.impl_details.subdiv_source_rank.begin(), |
89 | other.impl_details.subdiv_source_rank.end()); |
90 | impl_details.dependencies = other.impl_details.dependencies; |
91 | devices.assign(other.devices.begin(), other.devices.end()); |
92 | permutation.assign(other.permutation.begin(), other.permutation.end()); |
93 | } |
94 | return *this; |
95 | } |
96 | |
97 | string CollInstanceParams::ToString() const { |
98 | string v = |
99 | strings::StrCat("CollInstanceParams { instance_key=" , instance_key, |
100 | " type=" , type, " data_type=" , DataTypeString(data_type), |
101 | " shape=" , shape.DebugString(), " devices {" ); |
102 | strings::StrAppend(&v, "}, collective_name=" , impl_details.collective_name, |
103 | ", subdiv_offsets={" ); |
104 | strings::StrAppend(&v, "}, subdiv_offsets={" ); |
105 | for (const auto& d : impl_details.subdiv_offsets) { |
106 | strings::StrAppend(&v, d, "," ); |
107 | } |
108 | strings::StrAppend(&v, "}, subdiv_perms={" ); |
109 | for (const auto& p : impl_details.subdiv_permutations) { |
110 | strings::StrAppend(&v, "{" ); |
111 | for (const auto& i : p) { |
112 | strings::StrAppend(&v, i, "," ); |
113 | } |
114 | strings::StrAppend(&v, "}" ); // one subdiv |
115 | } |
116 | if (!impl_details.subdiv_source_rank.empty()) { |
117 | strings::StrAppend(&v, " subdiv_source_rank={" ); |
118 | for (const auto& r : impl_details.subdiv_source_rank) { |
119 | strings::StrAppend(&v, r, "," ); |
120 | } |
121 | strings::StrAppend(&v, "}" ); |
122 | } // all subdivs |
123 | if (type == PERMUTE_COLLECTIVE) { |
124 | strings::StrAppend(&v, "}, permute_devices {" ); |
125 | for (const auto& d : devices) { |
126 | strings::StrAppend(&v, d, "," ); |
127 | } |
128 | strings::StrAppend(&v, "}, permute_permutation {" ); |
129 | for (const auto& p : permutation) { |
130 | strings::StrAppend(&v, p, "," ); |
131 | } |
132 | strings::StrAppend(&v, "}" ); |
133 | } |
134 | return v; |
135 | } |
136 | |
137 | string CollectiveParams::ToString() const { |
138 | string v = strings::StrCat("CollectiveParams " , name, " {" , group.ToString()); |
139 | strings::StrAppend(&v, " " , instance.ToString()); |
140 | strings::StrAppend(&v, " default_rank=" , default_rank, |
141 | " is_source=" , is_source, " source_rank=" , source_rank, |
142 | " subdiv_rank={" ); |
143 | for (const auto& r : subdiv_rank) { |
144 | strings::StrAppend(&v, r, "," ); |
145 | } |
146 | strings::StrAppend(&v, "}}" ); |
147 | return v; |
148 | } |
149 | |
150 | /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams( |
151 | OpKernelContext* ctx) { |
152 | return ctx->params_; |
153 | } |
154 | |
// Builds a CollectiveContext by storing each argument in the member of the
// same name.  `col_params` is handed to its member together with
// /*add_ref=*/true, `device` starts out as nullptr, and `device_name` is
// initialized from the device name of the group member at index
// `col_params->default_rank`.
//
// NOTE(review): `col_params` is dereferenced, and `default_rank` is used as an
// index into `group.members`, with no check in this constructor — presumably
// callers guarantee a non-null pointer and an in-range rank; confirm.
CollectiveContext::CollectiveContext(
    CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
    OpKernelContext::Params* op_params, const CollectiveParams* col_params,
    const string& exec_key, int64_t step_id, const Tensor* input,
    Tensor* output)
    : col_exec(col_exec),
      nccl_communicator(nccl_communicator),
      dev_mgr(dev_mgr),
      op_ctx(ctx),
      op_params(op_params),
      col_params(col_params, /*add_ref=*/true),
      exec_key(exec_key),
      step_id(step_id),
      input(input),
      output(output),
      device(nullptr),
      // Inside the initializer expression `col_params` names the constructor
      // parameter (the raw pointer), not the member being initialized above.
      device_name(
          col_params->group.members[col_params->default_rank].device.name()) {}
174 | |
// Out-of-line definition of the static member `kInvalidId`, initialized to -1.
// NOTE(review): the name suggests a sentinel meaning "no valid id"; confirm
// against the declaration and its users.
/*static*/
int64_t CollectiveExecutor::kInvalidId = -1;
177 | |
178 | /*static*/ |
179 | Status CollectiveRegistry::Lookup( |
180 | const string& collective_name, |
181 | CollectiveImplementationInterface** implementation) { |
182 | return LookupHelper(collective_name, implementation, false); |
183 | } |
184 | |
185 | /*static*/ |
186 | Status CollectiveRegistry::LookupParamResolverInstance( |
187 | const string& collective_name, |
188 | CollectiveImplementationInterface** implementation) { |
189 | return LookupHelper(collective_name, implementation, true); |
190 | } |
191 | |
192 | /*static*/ |
193 | void CollectiveRegistry::GetAll( |
194 | std::vector<CollectiveImplementationInterface*>* implementations) { |
195 | std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); |
196 | for (const RegistrationInfo& reg_info : *registry) |
197 | implementations->emplace_back(reg_info.factory()); |
198 | } |
199 | |
200 | /*static*/ |
201 | Status CollectiveRegistry::Register(const string& collective_name, |
202 | Factory factory) { |
203 | std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); |
204 | for (const RegistrationInfo& reg_info : *registry) { |
205 | if (reg_info.name == collective_name) |
206 | return errors::Internal("Already registered collective " , |
207 | collective_name); |
208 | } |
209 | registry->emplace_back(collective_name, std::move(factory)); |
210 | return OkStatus(); |
211 | } |
212 | |
213 | /*static*/ |
214 | Status CollectiveRegistry::LookupHelper( |
215 | const string& collective_name, |
216 | CollectiveImplementationInterface** implementation, bool param_resolver) { |
217 | std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); |
218 | for (const RegistrationInfo& reg_info : *registry) { |
219 | if (reg_info.name == collective_name) { |
220 | if (param_resolver) { |
221 | *implementation = reg_info.param_resolver_instance; |
222 | } else { |
223 | *implementation = reg_info.factory(); |
224 | } |
225 | return OkStatus(); |
226 | } |
227 | } |
228 | return errors::Internal( |
229 | "CollectiveRegistry::Lookup did not find collective implementation " , |
230 | collective_name); |
231 | } |
232 | |
233 | } // namespace tensorflow |
234 | |