1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
17 | |
18 | #include <algorithm> |
19 | #include <cstdint> |
20 | #include <map> |
21 | #include <memory> |
22 | #include <numeric> |
23 | #include <set> |
24 | #include <string> |
25 | #include <string_view> |
26 | #include <utility> |
27 | #include <vector> |
28 | |
29 | #include "absl/container/inlined_vector.h" |
30 | #include "absl/strings/str_cat.h" |
31 | #include "absl/strings/str_join.h" |
32 | #include "absl/strings/str_split.h" |
33 | #include "absl/strings/string_view.h" |
34 | #include "absl/types/optional.h" |
35 | #include "tensorflow/core/framework/tensor_shape.h" |
36 | #include "tensorflow/core/lib/math/math_util.h" |
37 | #include "tensorflow/core/platform/errors.h" |
38 | #include "tensorflow/core/platform/fingerprint.h" |
39 | #include "tensorflow/core/platform/logging.h" |
40 | #include "tensorflow/core/platform/statusor.h" |
41 | #include "tensorflow/core/util/device_name_utils.h" |
42 | #include "tensorflow/dtensor/cc/dstatus.h" |
43 | #include "tensorflow/dtensor/proto/layout.pb.h" |
44 | |
45 | namespace tensorflow { |
46 | namespace dtensor { |
47 | |
// Out-of-line definitions for the static constexpr string constants declared
// in tensor_layout.h; required to satisfy ODR-uses prior to C++17 inline
// variables.
constexpr const char* Layout::kUnshardedDim;
constexpr const char* Layout::kAny;
constexpr const char* Layout::kEmptyLayoutString;
constexpr const char* Layout::kMatch;
constexpr const char* Mesh::kEmptyMeshString;
53 | |
54 | namespace { |
55 | // Obtain all possible forms of indexing a mesh. |
56 | // |
57 | // e.g. given a mesh with dimensions [x=2, y=3], returns { |
58 | // [0, 0], [0, 1], [0, 2], |
59 | // [1, 0], [1, 1], [1, 2] |
60 | // } |
61 | inline std::vector<DeviceLocation> ComputeDeviceLocations(const Mesh* mesh) { |
62 | std::vector<DeviceLocation> mesh_locs(mesh->size()); |
63 | for (size_t i = 0; i < mesh->size(); ++i) |
64 | mesh_locs[i] = *(mesh->device_location(i)); |
65 | return mesh_locs; |
66 | } |
67 | } // namespace |
68 | |
69 | namespace { |
70 | // Expands a ShardVector into the size defined in new_num_shards_per_dim. |
71 | // |
72 | // For example, the inputs: |
73 | // - shard_vec: shards = [(1,1)] num_shards_per_dim = [1,1] |
74 | // - new_num_shards_per_dim = [2,2] |
75 | // |
76 | // Would lead to: |
77 | // shard_vec: shards = [(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim = [2,2] |
78 | // |
79 | // This is used to check whether two ShardVectors contain the same information |
80 | // while having different number of shards per dimension. The two ShardVectors |
81 | // above are an example of this. |
ShardVector ExpandShardVector(const ShardVector& shard_vec,
                              const std::vector<int>& new_num_shards_per_dim) {
  // An empty shard vector has nothing to expand.
  if (shard_vec.shards.empty()) return shard_vec;

  // Takes a single shard and expands it into multiple shards along dimension
  // `dim_ind`. Each 1-based coordinate c maps to the contiguous range
  // [(c-1)*ratio + 1, c*ratio], where ratio = new_size / original_size.
  auto ExpandShard = [shard_vec, new_num_shards_per_dim](
                         const Shard& shard,
                         int dim_ind) -> std::vector<Shard> {
    int original_dim_size = shard_vec.num_shards_per_dim[dim_ind];
    int new_dim_size = new_num_shards_per_dim[dim_ind];
    // NOTE(review): assumes new_dim_size is a multiple of original_dim_size
    // (the caller passes a per-dimension LCM) — otherwise the integer
    // division truncates.
    int size_ratio = new_dim_size / original_dim_size;

    std::vector<Shard> expanded_shards;
    expanded_shards.reserve(size_ratio);
    for (int i = 0; i < size_ratio; ++i) {
      int original_coord = shard[dim_ind];
      // Shard coordinates are 1-based, hence the -1/+1 shifts.
      int shifted_coord = (original_coord - 1) * size_ratio + 1 + i;
      // Copy original shard, then modify it.
      Shard new_shard = shard;
      new_shard[dim_ind] = shifted_coord;
      expanded_shards.push_back(new_shard);
    }
    return expanded_shards;
  };
  // Iterates over the dimensions of the shard, expanding at each
  // dimension.
  std::vector<Shard> total_expanded_shards = shard_vec.shards;
  for (int dim_ind = 0; dim_ind < new_num_shards_per_dim.size(); ++dim_ind) {
    std::vector<Shard> dim_expanded_shards;
    for (const auto& shard : total_expanded_shards) {
      std::vector<Shard> expanded_shards = ExpandShard(shard, dim_ind);
      // Concatenate newly created shards.
      dim_expanded_shards.insert(dim_expanded_shards.end(),
                                 expanded_shards.begin(),
                                 expanded_shards.end());
    }
    // Copy newly created shards and delete old ones.
    total_expanded_shards = dim_expanded_shards;
  }
  // Keep the shards sorted so two expanded vectors compare element-wise.
  std::sort(total_expanded_shards.begin(), total_expanded_shards.end());
  ShardVector expanded_shard_vec;
  expanded_shard_vec.shards = total_expanded_shards;
  expanded_shard_vec.num_shards_per_dim = new_num_shards_per_dim;
  return expanded_shard_vec;
}
127 | } // namespace |
128 | |
bool ShardVector::operator==(const ShardVector& other) const {
  // Trivial cases: two empty vectors are equal; an empty vector can never
  // equal a non-empty one.
  if (this->shards.empty() && other.shards.empty()) return true;
  if (this->shards.empty() || other.shards.empty()) return false;

  // Check number of shard dimensions match.
  if (this->num_shards_per_dim.size() != other.num_shards_per_dim.size())
    return false;

  // Compute lowest common multiple for each of the shard dimensions, so both
  // vectors can be expanded to a common granularity before comparing.
  Shard first_shard_this = this->shards[0];
  Shard first_shard_other = other.shards[0];
  std::vector<int> new_sizes;
  for (size_t i = 0; i < first_shard_this.size(); ++i) {
    // lcm(a, b) = a * b / gcd(a, b).
    int lcm = this->num_shards_per_dim[i] * other.num_shards_per_dim[i] /
              MathUtil::GCD(static_cast<unsigned>(this->num_shards_per_dim[i]),
                            static_cast<unsigned>(other.num_shards_per_dim[i]));
    new_sizes.push_back(lcm);
  }

  // Expand both vectors to the common granularity and compare shard lists.
  return ExpandShardVector(*this, new_sizes).shards ==
         ExpandShardVector(other, new_sizes).shards;
}
153 | |
154 | std::string ShardVector::ToString() const { |
155 | std::string string = "shards:[" ; |
156 | // Convert each Shard into string. |
157 | std::vector<std::string> shard_strs; |
158 | shard_strs.reserve(shards.size()); |
159 | for (const Shard& shard : shards) |
160 | shard_strs.push_back("(" + absl::StrJoin(shard, "," ) + ")" ); |
161 | // Join shards, and append dimensions. |
162 | absl::StrAppend(&string, absl::StrJoin(shard_strs, "," )); |
163 | absl::StrAppend(&string, "] num_shards_per_dim:(" ); |
164 | absl::StrAppend(&string, absl::StrJoin(num_shards_per_dim, "," ) + ")" ); |
165 | return string; |
166 | } |
167 | |
168 | bool ShardVector::ContainsShard(const Shard& shard) const { |
169 | for (const auto& shard_in_vec : shards) |
170 | if (shard_in_vec == shard) return true; |
171 | return false; |
172 | } |
173 | |
174 | // static |
std::map<std::string, std::vector<int>>& Mesh::tpu_core_ids() {
  // Process-wide registry mapping mesh names to TPU core ids. Heap-allocated
  // and intentionally never deleted to avoid destruction-order issues at
  // process shutdown.
  static auto tpu_core_ids = new std::map<std::string, std::vector<int>>();
  return *tpu_core_ids;
}
179 | |
180 | // static |
std::string& Mesh::tpu_host_mesh() {
  // Process-wide string holding the TPU host mesh; intentionally leaked for
  // the same shutdown-ordering reason as tpu_core_ids().
  static auto tpu_host_mesh = new std::string;
  return *tpu_host_mesh;
}
185 | |
186 | // static |
StatusOr<Mesh> Mesh::ParseFromProto(const MeshProto& proto) {
  Mesh mesh;
  mesh.name_ = proto.name();

  // Copy local device names.
  for (const auto& device : proto.local_devices()) {
    mesh.local_devices_.push_back(device);
  }

  // Define local device ids.
  for (const auto& device_id : proto.local_device_ids()) {
    mesh.local_device_ids_.push_back(device_id);
  }

  // Copy global device ids.
  for (const auto& device_id : proto.global_device_ids()) {
    mesh.global_device_ids_.push_back(device_id);
  }

  // Copy global device names.
  for (const auto& device : proto.global_devices()) {
    mesh.global_devices_.push_back(device);
  }

  // Assign Mesh Dimensions.
  mesh.mesh_dims_.resize(proto.mesh_dimensions_size());
  for (int i = 0; i < proto.mesh_dimensions_size(); ++i) {
    const MeshDimensionProto& dim = proto.mesh_dimensions(i);
    mesh.mesh_dims_[i].name = dim.name();
    mesh.mesh_dims_[i].size = dim.size();
  }

  // Check invariants: the product of dimension sizes must equal the number of
  // global devices (unless the mesh is dimensionless, i.e. size 0).
  int64 mesh_size = mesh.size();
  int num_devices = proto.global_device_ids_size();
  if (mesh_size > 0 && mesh_size != num_devices) {
    TF_RETURN_WITH_CONTEXT(
        errors::InvalidArgument("Number of devices ", num_devices,
                                " not matching mesh size ", mesh_size));
  }
  return mesh;
}
226 | |
227 | // static |
228 | StatusOr<Mesh> Mesh::GetAbstractMesh( |
229 | const std::string& name, const std::vector<MeshDimension>& mesh_dims) { |
230 | Mesh mesh; |
231 | mesh.name_ = name; |
232 | mesh.mesh_dims_ = mesh_dims; |
233 | |
234 | // Check no repeated mesh dimension names. |
235 | std::set<std::string> dims_set; |
236 | for (const MeshDimension& dim : mesh.dims()) { |
237 | if (dims_set.find(dim.name) != dims_set.end()) |
238 | TF_RETURN_WITH_CONTEXT( |
239 | errors::InvalidArgument("repeated mesh dimension" )); |
240 | if (dim.name == Layout::kAny || dim.name == Layout::kMatch || |
241 | dim.name == Layout::kUnshardedDim) |
242 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh dimension name " , |
243 | dim.name, " is reserved" )); |
244 | dims_set.insert(dim.name); |
245 | } |
246 | |
247 | return mesh; |
248 | } |
249 | |
250 | // static |
251 | StatusOr<Mesh> Mesh::GetMesh(const std::string& name, |
252 | const std::vector<MeshDimension>& mesh_dims, |
253 | const std::vector<std::int64_t>& global_device_ids, |
254 | const std::vector<std::int64_t>& local_device_ids, |
255 | const std::vector<std::string>& local_devices, |
256 | const std::vector<std::string>& global_devices) { |
257 | TF_ASSIGN_OR_RETURN(Mesh mesh, GetAbstractMesh(name, mesh_dims)); |
258 | mesh.global_device_ids_ = global_device_ids; |
259 | mesh.local_device_ids_ = local_device_ids; |
260 | mesh.local_devices_ = local_devices; |
261 | mesh.global_devices_ = global_devices; |
262 | |
263 | // Check number of devices matches conditions. |
264 | size_t global_n = mesh.global_device_ids_.size(); |
265 | size_t local_n = mesh.local_device_ids_.size(); |
266 | size_t dev_n = mesh.local_devices_.size(); |
267 | |
268 | if (!(global_n >= local_n && dev_n == local_n)) |
269 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument( |
270 | "number of global_device_ids " , std::to_string(global_n), |
271 | " local_devices ids " , std::to_string(local_n), " and local devices " , |
272 | std::to_string(dev_n), "not meeting requirements" )); |
273 | |
274 | // If empty device list, return empty mesh. |
275 | if (global_n == 0) return Mesh::Empty(); |
276 | |
277 | if (local_n && !(global_n % local_n == 0)) |
278 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument( |
279 | "Uneven local clusters with global_ids " , std::to_string(global_n), |
280 | " and local_devices ids " , std::to_string(local_n))); |
281 | |
282 | // Check mesh size matches number of devices. |
283 | if (mesh.size() != global_n) |
284 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh size doesn't match" , |
285 | "number of devices" )); |
286 | |
287 | // Check local device invariants. |
288 | TF_ASSIGN_OR_RETURN(const auto& parsed_devs, mesh.ParsedDevices()); |
289 | std::set<std::string> types_set; |
290 | for (const DeviceNameUtils::ParsedName& dev : parsed_devs) { |
291 | if (!dev.has_job || !dev.has_task || !dev.has_type) |
292 | return errors::InvalidArgument( |
293 | "Failed to either identify host or device type" ); |
294 | types_set.insert(dev.type); |
295 | if (types_set.size() > 1) |
296 | return errors::InvalidArgument( |
297 | "More than one device type per mesh not supported. Found " , |
298 | types_set.size()); |
299 | } |
300 | |
301 | return mesh; |
302 | } |
303 | |
304 | StatusOr<int64_t> Mesh::dim_size(absl::string_view name) const { |
305 | for (const auto& mesh_dim : dims()) { |
306 | if (name == mesh_dim.name) { |
307 | return mesh_dim.size; |
308 | } |
309 | } |
310 | |
311 | std::vector<std::string> dim_names; |
312 | for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name); |
313 | |
314 | return errors::NotFound( |
315 | "Dimension " , name, " does not exist in mesh." , |
316 | "Available dimensions: " , absl::StrJoin(dim_names, "," )); |
317 | } |
318 | |
319 | std::vector<int64_t> Mesh::dim_sizes() const { |
320 | std::vector<int64_t> dim_sizes; |
321 | if (mesh_dims_.empty()) return dim_sizes; |
322 | for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size); |
323 | return dim_sizes; |
324 | } |
325 | |
// Two meshes are equal iff their serialized protos are field-wise equal.
bool Mesh::operator==(const Mesh& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}
329 | |
330 | bool Mesh::IsEmpty() const { return global_device_ids_.empty(); } |
331 | |
332 | StatusOr<const std::vector<DeviceNameUtils::ParsedName>> Mesh::ParsedDevices() |
333 | const { |
334 | std::vector<DeviceNameUtils::ParsedName> parsed_devices( |
335 | local_devices_.size()); |
336 | for (std::size_t i = 0; i < local_devices_.size(); ++i) |
337 | if (!DeviceNameUtils::ParseFullOrLocalName( |
338 | absl::string_view(local_devices_[i]), &parsed_devices[i])) |
339 | return errors::InvalidArgument("Failed to parse local_devices" ); |
340 | |
341 | return parsed_devices; |
342 | } |
343 | |
StatusOr<Mesh> Mesh::ToDeviceType(const std::string& device_type) const {
  // Rewrites every local device name to the given device type, keeping the
  // job/replica/task/id parts, and builds a new mesh from the result.
  std::vector<std::string> to_local_devices;
  DeviceNameUtils::ParsedName parsed_dev;
  for (const std::string& local_dev : local_devices_) {
    if (!DeviceNameUtils::ParseFullOrLocalName(absl::string_view(local_dev),
                                               &parsed_dev)) {
      return errors::InvalidArgument("Failed to parse local devices");
    }
    // Converted mesh using full task name with job, replica and task ids.
    to_local_devices.push_back(
        DeviceNameUtils::FullName(parsed_dev.job, parsed_dev.replica,
                                  parsed_dev.task, device_type, parsed_dev.id));
    parsed_dev.Clear();
  }
  // Note: the converted mesh drops the global device names.
  return GetMesh(name_, mesh_dims_, global_device_ids_, local_device_ids_,
                 to_local_devices, /*global_devices=*/{});
}
361 | |
362 | namespace { |
363 | std::string HostFromParsedDev(const DeviceNameUtils::ParsedName& dev) { |
364 | return "/job:" + dev.job + "/task:" + std::to_string(dev.task); |
365 | } |
366 | } // namespace |
367 | |
std::vector<std::string> Mesh::hosts() const {
  // Returns the distinct "/job:<job>/task:<task>" hosts among the local
  // devices, preserving first-seen order.
  std::vector<std::string> host_list;
  if (IsEmpty()) return host_list;

  const auto parsed_devices = ParsedDevices().value();
  for (const DeviceNameUtils::ParsedName& dev : parsed_devices) {
    std::string host = HostFromParsedDev(dev);
    // Linear de-duplication keeps insertion order; host counts are small.
    if (std::find(host_list.begin(), host_list.end(), host) == host_list.end())
      host_list.push_back(host);
  }
  return host_list;
}
380 | |
std::string Mesh::device_type() const {
  // Returns the device type (the "type" component of the first device name),
  // preferring global device names when present.
  if (IsEmpty()) return std::string();
  std::string device;
  if (!global_devices_.empty()) {
    device = global_devices_[0];
  } else {
    device = local_devices_[0];
  }
  DeviceNameUtils::ParsedName dev;
  // NOTE(review): the parse result is ignored; on a malformed name dev.type
  // stays default — confirm callers tolerate an empty return here.
  DeviceNameUtils::ParseFullOrLocalName(device, &dev);
  return dev.type;
}
393 | |
394 | bool Mesh::IsMeshDim(const std::string& dim_name) const { |
395 | for (const auto& mesh_dim : dims()) |
396 | if (dim_name == mesh_dim.name) return true; |
397 | return false; |
398 | } |
399 | |
400 | int Mesh::GetMeshDimIndexWithName(const std::string& mesh_name) const { |
401 | int mesh_index = -1; |
402 | for (int i = 0; i < dims().size(); ++i) { |
403 | const auto mesh_dim = dim(i); |
404 | if (mesh_dim.name == mesh_name) mesh_index = i; |
405 | } |
406 | assert(mesh_index >= 0); |
407 | return mesh_index; |
408 | } |
409 | |
410 | int64 Mesh::rank() const { return mesh_dims_.size(); } |
411 | |
412 | int64 Mesh::size() const { |
413 | if (mesh_dims_.empty()) return 0; |
414 | |
415 | int64 size = 1; |
416 | for (const MeshDimension& dim : mesh_dims_) size *= dim.size; |
417 | return size; |
418 | } |
419 | |
420 | Mesh Mesh::Empty() { return Mesh(); } |
421 | |
422 | MeshProto Mesh::ToProto() const { |
423 | MeshProto mesh_proto; |
424 | mesh_proto.set_name(name()); |
425 | |
426 | for (const auto& d : local_devices_) { |
427 | mesh_proto.add_local_devices(d); |
428 | } |
429 | |
430 | for (const auto& i : local_device_ids_) { |
431 | mesh_proto.add_local_device_ids(i); |
432 | } |
433 | |
434 | for (const auto& i : global_device_ids_) { |
435 | mesh_proto.add_global_device_ids(i); |
436 | } |
437 | |
438 | for (const auto& dim : mesh_dims_) { |
439 | MeshDimensionProto* mesh_dim_proto = mesh_proto.add_mesh_dimensions(); |
440 | mesh_dim_proto->set_name(dim.name); |
441 | mesh_dim_proto->set_size(dim.size); |
442 | } |
443 | |
444 | for (const auto& d : global_devices_) { |
445 | mesh_proto.add_global_devices(d); |
446 | } |
447 | return mesh_proto; |
448 | } |
449 | |
// Serializes the mesh into the string format parsed by Mesh::FromString():
//   name|dim1=size1,...|global ids|local ids|local devices[|global devices]
std::string Mesh::ToString() const {
  if (Mesh::IsEmpty()) return kEmptyMeshString;

  // We use "|" to separate name, mesh dimensions and devices.
  std::string mesh_str = absl::StrCat(Mesh::name(), "|");

  // Add mesh dimensions as "name=size" pairs.
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  // Add flattened list of local device ids
  mesh_str += absl::StrJoin(local_device_ids_, ",") + "|";

  // Add flattened list of local devices
  mesh_str += absl::StrJoin(local_devices_, ",");

  // Global devices are optional: the 6th section is only emitted when set.
  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  return mesh_str;
}
478 | |
// Fingerprints only the globally-consistent parts of the mesh (dimensions,
// global device ids, global devices), so all clients of the same mesh agree.
uint64 Mesh::GlobalFingerprint() const {
  if (Mesh::IsEmpty()) return Fingerprint64(kEmptyMeshString);

  std::string mesh_str;
  // Add mesh dimensions
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Ignore local_device_ids_, local_devices_ and the mesh name, which may
  // not be globally unique.
  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  // mesh dims | global device ids (| global devices)
  return Fingerprint64(mesh_str);
}
502 | |
503 | namespace { |
// Parses a "name=size" string into a MeshDimension; an empty string yields a
// default-constructed MeshDimension.
MeshDimension StrToMeshDimension(const std::string& str) {
  MeshDimension mesh_dim;
  if (str.empty()) return mesh_dim;

  std::vector<std::string> mesh_dim_parts = absl::StrSplit(str, '=');

  // NOTE(review): trusts the "name=size" form — a missing '=' or non-numeric
  // size makes the index access / std::stoi below misbehave; callers pass
  // well-formed dimension strings from Mesh::ToString().
  mesh_dim.name = mesh_dim_parts[0];
  mesh_dim.size = std::stoi(mesh_dim_parts[1]);
  return mesh_dim;
}
514 | |
// Builds a mesh whose devices are auto-generated on localhost from a
// "<prefix>*<device_type>" instruction, creating one device per mesh
// coordinate. Used by the 3-part format of Mesh::FromString() in tests.
StatusOr<Mesh> GenerateMeshDevicesForTests(
    const std::string& name, const std::vector<MeshDimension>& mesh_dims,
    const std::string& mesh_gen_instruction) {
  // Parse mesh generation instruction.
  std::vector<std::string> instruction_parts =
      absl::StrSplit(mesh_gen_instruction, '*');
  if (instruction_parts.size() != 2)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected a * in mesh_gen_instructions but found ",
        mesh_gen_instruction));
  std::string device_type = instruction_parts[1];

  // Get Mesh Size: the product of all dimension sizes (0 if dimensionless).
  int64 mesh_size = 0;
  if (!mesh_dims.empty()) {
    mesh_size = 1;
    for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size;
  }

  // Generate consecutive device ids plus matching localhost device names.
  std::vector<int64_t> global_device_ids;
  std::vector<int64_t> local_device_ids;
  std::vector<std::string> local_devices;
  for (std::size_t i = 0; i < mesh_size; ++i) {
    global_device_ids.push_back(i);
    local_device_ids.push_back(i);
    local_devices.push_back("/job:localhost/task:0/device:" + device_type +
                            ":" + std::to_string(i));
  }

  TF_ASSIGN_OR_RETURN(
      Mesh mesh,
      Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
                    local_devices, /*global_devices=*/{}));
  return mesh;
}
551 | } // namespace |
552 | |
553 | // static |
554 | StatusOr<Mesh> Mesh::FromString(const std::string& str) { |
555 | if (str == kEmptyMeshString) return Mesh::Empty(); |
556 | |
557 | std::vector<std::string> mesh_parts = absl::StrSplit(str, '|'); |
558 | |
559 | // Check formatting error. |
560 | if (mesh_parts.size() != 3 && mesh_parts.size() != 5 && |
561 | mesh_parts.size() != 6) |
562 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument( |
563 | "Expected either 5, 6 or 3 mesh parts but found" , mesh_parts.size())); |
564 | |
565 | // Populate mesh. |
566 | std::string name = mesh_parts[0]; |
567 | |
568 | // Add mesh dimensions. |
569 | std::vector<MeshDimension> mesh_dims; |
570 | if (!mesh_parts[1].empty()) { |
571 | std::vector<std::string> mesh_dim_strs = absl::StrSplit(mesh_parts[1], ','); |
572 | mesh_dims.reserve(mesh_dim_strs.size()); |
573 | for (const std::string& mesh_dim_str : mesh_dim_strs) |
574 | mesh_dims.push_back(StrToMeshDimension(mesh_dim_str)); |
575 | } |
576 | |
577 | // Check if mesh is set to be autogenerated. |
578 | if (mesh_parts.size() == 3) |
579 | return GenerateMeshDevicesForTests(name, mesh_dims, mesh_parts[2]); |
580 | |
581 | // Add global device ids list. |
582 | std::vector<int64_t> global_device_ids; |
583 | if (!mesh_parts[2].empty()) { |
584 | std::vector<std::string> global_device_ids_strs = |
585 | absl::StrSplit(mesh_parts[2], ','); |
586 | |
587 | global_device_ids.reserve(global_device_ids_strs.size()); |
588 | for (const std::string& id : global_device_ids_strs) |
589 | global_device_ids.push_back(std::stoi(id)); |
590 | } |
591 | |
592 | // Add local device ids list. |
593 | std::vector<int64_t> local_device_ids; |
594 | if (!mesh_parts[3].empty()) { |
595 | std::vector<std::string> local_device_ids_strs = |
596 | absl::StrSplit(mesh_parts[3], ','); |
597 | |
598 | local_device_ids.reserve(local_device_ids_strs.size()); |
599 | for (const std::string& id : local_device_ids_strs) |
600 | local_device_ids.push_back(std::stoi(id)); |
601 | } |
602 | // Add local devices. |
603 | std::vector<std::string> local_devices; |
604 | if (!mesh_parts[4].empty()) |
605 | local_devices = absl::StrSplit(mesh_parts[4], ','); |
606 | |
607 | std::vector<std::string> global_devices; |
608 | if (mesh_parts.size() == 6) { |
609 | // Add global devices. |
610 | if (!mesh_parts[5].empty()) |
611 | global_devices = absl::StrSplit(mesh_parts[5], ','); |
612 | } |
613 | |
614 | TF_ASSIGN_OR_RETURN( |
615 | Mesh mesh, |
616 | Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids, |
617 | local_devices, global_devices)); |
618 | return mesh; |
619 | } |
620 | |
621 | int64 Mesh::num_devices() const { return global_device_ids_.size(); } |
622 | |
623 | StatusOr<const DeviceLocation> Mesh::device_location(int offset) const { |
624 | if (offset < 0 || offset > size() - 1) |
625 | return errors::InvalidArgument( |
626 | "Mesh offset cannot be negative or exceed Mesh's size. Offset size:" , |
627 | offset, " and Mesh size:" , size()); |
628 | |
629 | DeviceLocation dev_loc; |
630 | std::vector<int64> mesh_dim_lengths = dim_sizes(); |
631 | int64 i = mesh_dim_lengths.size() - 1; |
632 | while (i >= 0) { |
633 | dev_loc.insert(dev_loc.begin(), offset % mesh_dim_lengths[i]); |
634 | offset /= mesh_dim_lengths[i]; |
635 | --i; |
636 | } |
637 | return dev_loc; |
638 | } |
639 | |
640 | int64 Mesh::GetFlattenedCoordinate(const DeviceLocation& loc) const { |
641 | const std::vector<int64> mesh_dim_sizes = dim_sizes(); |
642 | int64 i = mesh_dim_sizes.size() - 1; |
643 | int64 acc = 1; |
644 | int64 device_pos = 0; |
645 | while (i >= 0) { |
646 | device_pos += loc[i] * acc; |
647 | acc *= mesh_dim_sizes[i]; |
648 | --i; |
649 | } |
650 | return device_pos; |
651 | } |
652 | |
653 | StatusOr<int32> Mesh::idx_for_dim(absl::string_view dim_name) const { |
654 | for (int i = 0; i < mesh_dims_.size(); ++i) { |
655 | if (mesh_dims_[i].name == dim_name) return i; |
656 | } |
657 | return errors::InvalidArgument("dim name :" , dim_name, |
658 | " does not exist on mesh : " , ToString()); |
659 | } |
660 | |
661 | StatusOr<Layout> Layout::GetLayout( |
662 | const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh) { |
663 | // Re-format sharding specs. |
664 | std::vector<ShardingSpec> sharding_specs; |
665 | sharding_specs.reserve(sharding_spec_strs.size()); |
666 | for (const std::string& spec_str : sharding_spec_strs) { |
667 | ShardingSpec spec; |
668 | spec.set_sharding_spec(spec_str); |
669 | sharding_specs.push_back(spec); |
670 | } |
671 | return GetLayout(sharding_specs, mesh); |
672 | } |
673 | |
StatusOr<Layout> Layout::GetLayout(
    const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh) {
  Layout layout;
  // Append mesh, then check sharding_specs are legal.
  layout.mesh_ = mesh;

  // Check sharding_specs are either mesh dimension or special value.
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny ||
          sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) ||
          sharding_spec == "scalar"))
      TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
          "sharding spec (", sharding_spec,
          ") refers to mesh dimension not contained in mesh ",
          mesh.ToString()));
  }
  // Check same tensor dimensions not sharded over same mesh dimension twice.
  std::set<std::string> dims_set;
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    // Unsharded/any specs may repeat freely; skip them.
    if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue;
    // If scalar, delete all sharding specs.
    if (sharding_spec == "scalar") {
      if (sharding_specs.size() > 1)
        TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
            "A scalar sharding_spec can only be used as a single sharding_spec "
            "instruction, not as part of list of sharding_specs as attempted "
            "here with ",
            sharding_specs.size(), " sharding_specs"))
      // Return layout with empty spec to represent scalar behavior.
      return layout;
    }
    if (dims_set.find(sharding_spec) != dims_set.end())
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("Attempted to shard two or more tensor "
                                  "dimensions over mesh dimension ",
                                  sharding_spec))
    dims_set.insert(sharding_spec);
  }
  // After checking sharding_specs are legal, append and return layout.
  layout.sharding_specs_ = sharding_specs;
  return layout;
}
718 | |
719 | Layout Layout::Empty() { |
720 | Layout result; |
721 | return result; |
722 | } |
723 | |
724 | bool Layout::IsEmpty() const { return mesh_.IsEmpty(); } |
725 | |
726 | namespace { |
// Builds an abstract (device-less, unnamed) mesh in which every mesh
// dimension not referenced by the layout's sharding specs is collapsed to
// size 1, while sharded dimensions keep their original size.
Mesh ReducedAbstractMesh(const Layout* layout) {
  const std::vector<std::string>& shard_spec_strs =
      layout->sharding_spec_strs();
  std::vector<MeshDimension> reduced_mesh_dims;
  reduced_mesh_dims.reserve(layout->mesh().dims().size());
  for (const MeshDimension& mesh_dim : layout->mesh().dims()) {
    bool IsMeshDimInShardingSpecs =
        std::find(shard_spec_strs.begin(), shard_spec_strs.end(),
                  mesh_dim.name) != shard_spec_strs.end();
    // If dimension not in sharding_spec, flip size to 1.
    MeshDimension reduced_dim =
        IsMeshDimInShardingSpecs ? mesh_dim : MeshDimension(mesh_dim.name, 1);
    reduced_mesh_dims.push_back(reduced_dim);
  }
  return Mesh::GetAbstractMesh("", reduced_mesh_dims).value();
}
743 | |
744 | } // namespace |
745 | |
Mesh Layout::ReducedMesh() const {
  // Set replicated mesh dimensions to size 1, and create reduced abstract mesh.
  Mesh reduced_mesh = ReducedAbstractMesh(this);

  // Populate reduced mesh with global devices from original mesh. Each
  // reduced-mesh location maps back to a flattened coordinate in the original
  // mesh (collapsed dimensions contribute coordinate 0).
  std::vector<int64_t> reduced_global_device_ids;
  std::vector<std::string> reduced_global_devs;
  for (const DeviceLocation& loc : ComputeDeviceLocations(&reduced_mesh)) {
    int64 pos = mesh().GetFlattenedCoordinate(loc);
    reduced_global_device_ids.push_back(mesh().global_device_ids().at(pos));
    if (!mesh().global_devices().empty()) {
      reduced_global_devs.push_back(mesh().global_devices().at(pos));
    }
  }

  // Track the set of global device IDs in the abstract mesh.
  std::set<int64_t> reduced_global_device_ids_set(
      reduced_global_device_ids.begin(), reduced_global_device_ids.end());

  // Populate reduced mesh with local devices in the same order as the original
  // mesh, keeping only those whose id survived the reduction.
  std::vector<int64_t> reduced_local_device_ids;
  std::vector<std::string> reduced_local_devs;
  for (size_t i = 0; i < mesh().local_device_ids().size(); ++i) {
    int64_t device_id = mesh().local_device_ids().at(i);
    if (reduced_global_device_ids_set.find(device_id) !=
        reduced_global_device_ids_set.end()) {
      reduced_local_device_ids.push_back(device_id);
      reduced_local_devs.push_back(mesh().local_devices().at(i));
    }
  }

  return Mesh::GetMesh(reduced_mesh.name(), reduced_mesh.dims(),
                       reduced_global_device_ids, reduced_local_device_ids,
                       reduced_local_devs, reduced_global_devs)
      .value();
}
783 | |
784 | namespace { |
785 | Layout ReducedLayout(const Layout* layout) { |
786 | // Change format sharding specs. |
787 | std::vector<ShardingSpec> shard_specs(layout->sharding_specs().size()); |
788 | for (size_t i = 0; i < shard_specs.size(); ++i) |
789 | shard_specs[i] = layout->dim(i); |
790 | // Retrieve layout. |
791 | return Layout::GetLayout(shard_specs, layout->ReducedMesh()).value(); |
792 | } |
793 | |
// Returns the index of the given mesh dimension, or an InvalidArgument error
// if the dimension is not found.
795 | StatusOr<int> IndexOfMeshDimension(const Mesh& mesh, |
796 | const std::string& dim_name) { |
797 | for (size_t i = 0; i < mesh.dims().size(); ++i) |
798 | if (dim_name == mesh.dims()[i].name) return i; |
799 | return errors::InvalidArgument("Mesh dimension not found" ); |
800 | } |
801 | } // namespace |
802 | |
ShardVector Layout::GetShardVector() const {
  // Change format sharding specs.
  std::vector<ShardingSpec> shard_specs(sharding_specs().size());
  for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i);
  // Obtain a shard position (i.e. sharded section of a tensor) from a mesh
  // location, using the sharding specs. Shard coordinates are 1-based.
  auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard {
    Shard shard;
    for (size_t i = 0; i < shard_specs.size(); ++i) {
      // If unsharded, there is only one shard, that is 1.
      std::string spec = shard_specs[i].sharding_spec();
      if (spec == Layout::kUnshardedDim) {
        shard.push_back(1);
      } else {
        // Sharded: the shard index follows the device's coordinate along the
        // mesh dimension named by the spec.
        int mesh_index = IndexOfMeshDimension(mesh(), sharding_spec(i)).value();
        int shard_number = loc[mesh_index] + 1;
        shard.push_back(shard_number);
      }
    }
    return shard;
  };
  // Obtain dims of shard vector: the number of shards along each tensor
  // dimension (1 for unsharded dimensions, mesh dimension size otherwise).
  auto ShardVectorDims = [&]() -> std::vector<int> {
    std::vector<int> num_shards_per_dim(shard_specs.size());
    for (size_t i = 0; i < sharding_specs().size(); ++i) {
      ShardingSpec spec = sharding_specs()[i];
      if (Layout::IsShardedSpec(spec)) {
        StatusOr<int64> dim_size = mesh().dim_size(spec.sharding_spec());
        num_shards_per_dim[i] = dim_size.value();
      } else {
        num_shards_per_dim[i] = 1;
      }
    }
    return num_shards_per_dim;
  };
  // Compute mesh locations and obtain shards from them.
  ShardVector shard_vec;
  for (const DeviceLocation& mesh_loc : ComputeDeviceLocations(&mesh()))
    shard_vec.shards.push_back(GetShardFromDeviceLocation(mesh_loc));
  // Calculate dims.
  shard_vec.num_shards_per_dim = ShardVectorDims();
  return shard_vec;
}
846 | |
847 | std::map<std::string, ShardVector> Layout::HostShardMap() const { |
848 | Layout reduced_layout = ReducedLayout(this); |
849 | Mesh reduced_mesh = reduced_layout.mesh(); |
850 | using HostName = std::string; |
851 | |
852 | // Build a map: {Host : Shards} |
853 | std::map<HostName, ShardVector> host_shards_map; |
854 | ShardVector shard_vec_in_red_layout = reduced_layout.GetShardVector(); |
855 | |
856 | const auto parsed_devs = reduced_mesh.ParsedDevices().value(); |
857 | for (size_t i = 0; i < parsed_devs.size(); ++i) { |
858 | HostName host = HostFromParsedDev(parsed_devs[i]); |
859 | Shard shard_in_device = shard_vec_in_red_layout.shards[i]; |
860 | |
861 | // Check if host in hashtable and append shard. |
862 | auto it = host_shards_map.find(host); |
863 | if (it == host_shards_map.end()) { |
864 | ShardVector shard_vec_in_host; |
865 | shard_vec_in_host.shards.push_back(shard_in_device); |
866 | shard_vec_in_host.num_shards_per_dim = |
867 | shard_vec_in_red_layout.num_shards_per_dim; |
868 | host_shards_map.insert( |
869 | std::pair<HostName, ShardVector>(host, shard_vec_in_host)); |
870 | } else { |
871 | bool isShardInShardVector = it->second.ContainsShard(shard_in_device); |
872 | if (!isShardInShardVector) { |
873 | it->second.shards.push_back(shard_in_device); |
874 | } |
875 | } |
876 | } |
877 | // Sort shards inside each host. |
878 | for (auto it = host_shards_map.begin(); it != host_shards_map.end(); ++it) { |
879 | std::sort(it->second.shards.begin(), it->second.shards.end()); |
880 | } |
881 | return host_shards_map; |
882 | } |
883 | |
// Returns the sharding spec string (mesh dim name, or a special token such
// as kUnshardedDim) for tensor dimension `idx`. No bounds checking.
const std::string& Layout::sharding_spec(int idx) const {
  return sharding_specs_[idx].sharding_spec();
}
887 | |
888 | std::vector<int32> Layout::num_shards() const { |
889 | std::vector<int32> num_shards; |
890 | num_shards.reserve(sharding_specs_.size()); |
891 | for (const auto& sharding_spec : sharding_specs_) { |
892 | num_shards.push_back(num_shards_for_dim(sharding_spec)); |
893 | } |
894 | return num_shards; |
895 | } |
896 | |
// Returns the number of shards for one sharding spec: 1 for an unsharded
// dim, otherwise the size of the named mesh dimension.
// NOTE(review): the return type is size_t, so `return -1` for kMatch wraps
// to SIZE_MAX rather than producing -1 — confirm callers treat any value
// other than 1 as "sharded", which comparisons like `!= 1` and `> 1` do.
size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const {
  absl::string_view name = dim.sharding_spec();
  if (name == Layout::kUnshardedDim) return 1;
  if (name == Layout::kMatch) return -1;

  return mesh().dim_size(name).value();
}
904 | |
905 | bool Layout::IsFullyReplicated() const { |
906 | for (const auto& sharding_spec : sharding_specs_) { |
907 | if (num_shards_for_dim(sharding_spec) > 1) { |
908 | return false; |
909 | } |
910 | } |
911 | return true; |
912 | } |
913 | |
914 | bool Layout::IsLastDimReplicated() const { |
915 | return (sharding_specs_.empty()) || |
916 | (num_shards_for_dim(sharding_specs_.back()) == 1); |
917 | } |
918 | |
919 | bool Layout::IsBatchParallel() const { |
920 | if (sharding_specs_.empty()) { |
921 | return true; |
922 | } |
923 | |
924 | for (int i = 1; i < sharding_specs_.size(); ++i) { |
925 | const auto& dim = sharding_specs_[i]; |
926 | if (num_shards_for_dim(dim) != 1) { |
927 | return false; |
928 | } |
929 | } |
930 | return true; |
931 | } |
932 | |
// TODO(samuelslee) Replace this with the IsBatchParallel() everywhere
// Batch-parallel check where the trailing `non_batch_rank` dimensions are
// the non-batch dims: all of them must be unsharded (one shard each).
// NOTE(review): assumes 0 <= non_batch_rank <= rank(); a larger value makes
// the loop index sharding_specs_ out of bounds — confirm callers guarantee
// this precondition.
bool Layout::IsBatchParallel(int non_batch_rank) const {
  if (sharding_specs_.empty()) return true;
  for (int i = rank() - non_batch_rank; i < rank(); ++i) {
    if (num_shards_for_dim(sharding_specs_[i]) != 1) return false;
  }
  return true;
}
941 | |
942 | LayoutProto Layout::ToProto() const { |
943 | LayoutProto proto; |
944 | *proto.mutable_mesh_config() = mesh_.ToProto(); |
945 | for (const auto& dim : sharding_specs_) { |
946 | *proto.add_sharding_specs() = dim; |
947 | } |
948 | return proto; |
949 | } |
950 | |
951 | bool Layout::IsEquivalent(const Layout& b) const { |
952 | if (this->rank() != b.rank()) return false; |
953 | if (this->mesh() != b.mesh()) return false; |
954 | for (int i = 0; i < this->rank(); ++i) { |
955 | if (this->sharding_specs_[i].sharding_spec() != |
956 | b.sharding_specs_[i].sharding_spec()) { |
957 | if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) || |
958 | (b.num_shards_for_dim(b.sharding_specs_[i]) != 1)) |
959 | return false; |
960 | } |
961 | } |
962 | return true; |
963 | } |
964 | |
// Strict equality: the two layouts' serialized protos are field-wise equal
// (same mesh config and identical sharding specs). Compare with
// IsEquivalent(), which also treats differing-but-unsharded specs as equal.
bool Layout::operator==(const Layout& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}
968 | |
969 | std::vector<int64_t> Layout::GlobalShapeFromLocalShape( |
970 | const std::vector<int64_t>& local_shape) const { |
971 | if (IsFullyReplicated()) { |
972 | return local_shape; |
973 | } |
974 | std::vector<int64_t> global_shape; |
975 | global_shape.reserve(sharding_specs().size()); |
976 | for (int i = 0; i < sharding_specs().size(); ++i) { |
977 | int64_t l_shape = local_shape.empty() ? 1 : local_shape[i]; |
978 | int64_t dim_shards = num_shards()[i]; |
979 | global_shape.emplace_back(l_shape * dim_shards); |
980 | } |
981 | return global_shape; |
982 | } |
983 | |
984 | std::vector<int64_t> Layout::LocalShapeFromGlobalShape( |
985 | absl::Span<const int64_t> global_shape) const { |
986 | if (IsFullyReplicated()) { |
987 | return std::vector<int64_t>(global_shape.begin(), global_shape.end()); |
988 | } |
989 | std::vector<int32> shards = num_shards(); |
990 | std::vector<int64_t> local_shape; |
991 | for (int i = 0; i < sharding_specs().size(); ++i) { |
992 | int64_t dim_shards = shards[i]; |
993 | // TODO(hthu): Shape might not be always divisible. |
994 | local_shape.emplace_back(global_shape[i] / dim_shards); |
995 | } |
996 | return local_shape; |
997 | } |
998 | |
999 | PartialTensorShape Layout::LocalShapeFromGlobalShape( |
1000 | const PartialTensorShape& global_shape) const { |
1001 | if (IsFullyReplicated() || global_shape.dims() == -1) { |
1002 | return global_shape; |
1003 | } |
1004 | std::vector<int32> shards = num_shards(); |
1005 | PartialTensorShape local_shape({}); |
1006 | for (int spec_index = 0; spec_index < sharding_specs().size(); ++spec_index) { |
1007 | int64_t dim_size = global_shape.dim_size(spec_index); |
1008 | local_shape.AddDim(dim_size == -1 ? -1 : dim_size / shards[spec_index]); |
1009 | } |
1010 | return local_shape; |
1011 | } |
1012 | |
1013 | StatusOr<Layout> Layout::FromProto(const LayoutProto& proto) { |
1014 | Layout layout; |
1015 | for (const auto& spec : proto.sharding_specs()) |
1016 | layout.sharding_specs_.push_back(spec); |
1017 | |
1018 | TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config())); |
1019 | layout.mesh_ = std::move(mesh); |
1020 | |
1021 | return GetLayout(layout.sharding_specs_, layout.mesh_); |
1022 | } |
1023 | |
1024 | Layout Layout::ReplicatedOnMesh(const Mesh& mesh, int rank) { |
1025 | std::vector<std::string> specs(rank, kUnshardedDim); |
1026 | return Layout::GetLayout(specs, mesh).value(); |
1027 | } |
1028 | |
1029 | Layout Layout::AnyOnMesh(const Mesh& mesh, int rank) { |
1030 | std::vector<std::string> specs(rank, kAny); |
1031 | return Layout::GetLayout(specs, mesh).value(); |
1032 | } |
1033 | |
1034 | StatusOr<Layout> Layout::Transposed2D(const Layout& layout) { |
1035 | if (layout.rank() < 2) { |
1036 | return errors::InvalidArgument("Transposed2D requires rank to be >= 2" ); |
1037 | } |
1038 | std::vector<std::string> transposed_specs = layout.sharding_spec_strs(); |
1039 | std::iter_swap(transposed_specs.end() - 2, transposed_specs.end() - 1); |
1040 | return Layout::GetLayout(transposed_specs, layout.mesh()).value(); |
1041 | } |
1042 | |
1043 | // static |
1044 | StatusOr<Layout> Layout::FromString(std::string layout_str) { |
1045 | if (layout_str == kEmptyLayoutString) return Layout::Empty(); |
1046 | |
1047 | // Print sharding specs. |
1048 | std::vector<absl::string_view> layout_parts = absl::StrSplit(layout_str, ' '); |
1049 | // Check formatting error. |
1050 | if (layout_parts.size() != 2) { |
1051 | TF_RETURN_WITH_CONTEXT(errors::InvalidArgument( |
1052 | "Expected 2 items but found " , layout_parts.size(), layout_parts[0])); |
1053 | } |
1054 | // Substract prefixes. |
1055 | absl::string_view sharding_spec_str = layout_parts[0]; |
1056 | absl::ConsumePrefix(&sharding_spec_str, "sharding_specs:" ); |
1057 | |
1058 | absl::string_view mesh_str = layout_parts[1]; |
1059 | absl::ConsumePrefix(&mesh_str, "mesh:" ); |
1060 | |
1061 | // Add sharding specs. |
1062 | std::vector<std::string> sharding_spec_strs = |
1063 | absl::StrSplit(sharding_spec_str, ','); |
1064 | sharding_spec_strs.pop_back(); |
1065 | |
1066 | // Add mesh. |
1067 | TF_ASSIGN_OR_RETURN(Mesh mesh, Mesh::FromString(string(mesh_str))); |
1068 | // Try to create layout. |
1069 | TF_ASSIGN_OR_RETURN(Layout layout, |
1070 | Layout::GetLayout(sharding_spec_strs, mesh)); |
1071 | return layout; |
1072 | } |
1073 | |
1074 | std::vector<std::string> Layout::sharding_spec_strs() const { |
1075 | std::vector<std::string> sharding_spec_strs(sharding_specs().size()); |
1076 | for (size_t i = 0; i < sharding_specs().size(); ++i) |
1077 | sharding_spec_strs[i] = sharding_spec(i); |
1078 | return sharding_spec_strs; |
1079 | } |
1080 | |
1081 | std::string Layout::ToString() const { |
1082 | if (Layout::IsEmpty()) return kEmptyLayoutString; |
1083 | |
1084 | std::string layout_str = "sharding_specs:" ; |
1085 | // Print sharding specs. |
1086 | for (const ShardingSpec& dim : sharding_specs_) { |
1087 | std::string dim_name = dim.sharding_spec(); |
1088 | absl::StrAppend(&layout_str, dim_name + "," ); |
1089 | } |
1090 | // Append mesh. |
1091 | absl::StrAppend(&layout_str, " mesh:" , mesh_.ToString()); |
1092 | return layout_str; |
1093 | } |
1094 | |
1095 | Layout Layout::GetLayoutWithReducedDims( |
1096 | const absl::flat_hash_set<int>& reduced_dims, bool keep_dims) const { |
1097 | dtensor::LayoutProto output_layout; |
1098 | *output_layout.mutable_mesh_config() = mesh().ToProto(); |
1099 | |
1100 | for (int i = 0; i < rank(); ++i) { |
1101 | // reduced_dims may contain negative values. |
1102 | if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) { |
1103 | *output_layout.add_sharding_specs() = dim(i); |
1104 | } else if (keep_dims) { |
1105 | auto* replicated_dim = output_layout.add_sharding_specs(); |
1106 | replicated_dim->set_sharding_spec(kUnshardedDim); |
1107 | } |
1108 | } |
1109 | return Layout::FromProto(output_layout).value(); |
1110 | } |
1111 | |
1112 | Layout Layout::Truncate(int64 split_point, bool end) const { |
1113 | if ((split_point == 0 && end) || (split_point == rank() && !end)) |
1114 | return *this; |
1115 | |
1116 | dtensor::LayoutProto output_layout; |
1117 | *output_layout.mutable_mesh_config() = mesh().ToProto(); |
1118 | |
1119 | if (end) { |
1120 | for (int i = split_point; i < rank(); ++i) |
1121 | *output_layout.add_sharding_specs() = dim(i); |
1122 | } else { |
1123 | for (int i = 0; i < split_point; ++i) |
1124 | *output_layout.add_sharding_specs() = dim(i); |
1125 | } |
1126 | return Layout::FromProto(output_layout).value(); |
1127 | } |
1128 | |
1129 | namespace { |
1130 | // Adds unsharded sharding specs to layout. |
1131 | Layout PadLayout(const int64 rank, const bool is_padding_before, |
1132 | const Layout& layout) { |
1133 | if (rank <= layout.rank()) return layout; |
1134 | |
1135 | // Create list of padding sharding specs. |
1136 | const int n = rank - layout.rank(); |
1137 | std::vector<ShardingSpec> new_specs(n); |
1138 | for (int i = 0; i < n; ++i) |
1139 | new_specs[i].set_sharding_spec(Layout::kUnshardedDim); |
1140 | |
1141 | // Define concatenation point of layout specs. |
1142 | auto concat_point = is_padding_before ? new_specs.end() : new_specs.begin(); |
1143 | |
1144 | // Concatenate old layout specs and new unsharded specs. |
1145 | new_specs.insert(concat_point, layout.sharding_specs().begin(), |
1146 | layout.sharding_specs().end()); |
1147 | return Layout::GetLayout(new_specs, layout.mesh()).value(); |
1148 | } |
1149 | } // namespace |
1150 | |
1151 | Layout Layout::LeftPad(int64 rank) const { |
1152 | bool is_padding_before = true; |
1153 | return PadLayout(rank, is_padding_before, *this); |
1154 | } |
1155 | |
// Concatenates the sharding specs of layout_a followed by layout_b into a
// single layout on their common mesh. Fails if the layouts live on
// different meshes, or if any mesh dimension is sharded over in both
// layouts (a mesh dim may shard at most one tensor dim in the result).
StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "unable to concatenate layouts as they are on different meshes." );

  // Mesh dims already consumed by layout_a's sharded tensor dims.
  absl::flat_hash_set<std::string> layout_a_mesh_dims;
  for (int i = 0; i < layout_a.rank(); ++i)
    if (layout_a.sharding_spec(i) != Layout::kUnshardedDim)
      layout_a_mesh_dims.emplace(layout_a.sharding_spec(i));

  // Reject any mesh dim that both layouts shard over.
  for (int i = 0; i < layout_b.rank(); ++i)
    if (layout_b.sharding_spec(i) != Layout::kUnshardedDim &&
        layout_a_mesh_dims.contains(layout_b.sharding_spec(i)))
      return errors::InvalidArgument(
          "unable to concatenate layouts as they use the same meshes "
          "dimension: " ,
          layout_b.sharding_spec(i), " is used in both layouts." );

  // Build the output proto: layout_a's mesh config, then a's specs followed
  // by b's specs, in order.
  LayoutProto layout_proto_a = layout_a.ToProto();
  LayoutProto layout_proto_b = layout_b.ToProto();
  LayoutProto output_layout_proto;

  *output_layout_proto.mutable_mesh_config() = layout_proto_a.mesh_config();
  for (int i = 0; i < layout_proto_a.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_a.sharding_specs(i);
  for (int i = 0; i < layout_proto_b.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_b.sharding_specs(i);
  // FromProto re-validates the concatenated specs against the mesh.
  return Layout::FromProto(output_layout_proto);
}
1188 | |
1189 | } // namespace dtensor |
1190 | } // namespace tensorflow |
1191 | |