/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/tensor_layout.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

constexpr const char* Layout::kUnshardedDim;
constexpr const char* Layout::kAny;
constexpr const char* Layout::kEmptyLayoutString;
constexpr const char* Layout::kMatch;
constexpr const char* Mesh::kEmptyMeshString;

namespace {
// Obtain all possible forms of indexing a mesh.
//
// e.g. given a mesh with dimensions [x=2, y=3], returns {
//   [0, 0], [0, 1], [0, 2],
//   [1, 0], [1, 1], [1, 2]
// }
inline std::vector<DeviceLocation> ComputeDeviceLocations(const Mesh* mesh) {
  std::vector<DeviceLocation> mesh_locs(mesh->size());
  for (size_t i = 0; i < mesh->size(); ++i)
    mesh_locs[i] = *(mesh->device_location(i));
  return mesh_locs;
}
}  // namespace

namespace {
// Expands a ShardVector into the size defined in new_num_shards_per_dim.
//
// For example, the inputs:
//  - shard_vec: shards = [(1,1)] num_shards_per_dim = [1,1]
//  - new_num_shards_per_dim = [2,2]
//
// Would lead to:
//  shard_vec: shards = [(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim = [2,2]
//
// This is used to check whether two ShardVectors contain the same information
// while having a different number of shards per dimension. The two
// ShardVectors above are an example of this.
ShardVector ExpandShardVector(const ShardVector& shard_vec,
                              const std::vector<int>& new_num_shards_per_dim) {
  if (shard_vec.shards.empty()) return shard_vec;

  // Takes a single shard and expands it into multiple shards.
  auto ExpandShard = [shard_vec, new_num_shards_per_dim](
                         const Shard& shard,
                         int dim_ind) -> std::vector<Shard> {
    int original_dim_size = shard_vec.num_shards_per_dim[dim_ind];
    int new_dim_size = new_num_shards_per_dim[dim_ind];
    int size_ratio = new_dim_size / original_dim_size;

    std::vector<Shard> expanded_shards;
    expanded_shards.reserve(size_ratio);
    for (int i = 0; i < size_ratio; ++i) {
      int original_coord = shard[dim_ind];
      int shifted_coord = (original_coord - 1) * size_ratio + 1 + i;
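      // e.g. expanding a dimension from 2 to 4 shards (size_ratio = 2) maps
      // shard coordinate 1 to {1, 2} and coordinate 2 to {3, 4}.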
      // Copy original shard, then modify it.
      Shard new_shard = shard;
      new_shard[dim_ind] = shifted_coord;
      expanded_shards.push_back(new_shard);
    }
    return expanded_shards;
  };
  // Iterates over the dimensions of the shard, expanding at each
  // dimension.
  std::vector<Shard> total_expanded_shards = shard_vec.shards;
  for (int dim_ind = 0; dim_ind < new_num_shards_per_dim.size(); ++dim_ind) {
    std::vector<Shard> dim_expanded_shards;
    for (const auto& shard : total_expanded_shards) {
      std::vector<Shard> expanded_shards = ExpandShard(shard, dim_ind);
      // Concatenate newly created shards.
      dim_expanded_shards.insert(dim_expanded_shards.end(),
                                 expanded_shards.begin(),
                                 expanded_shards.end());
    }
    // Copy newly created shards and delete old ones.
    total_expanded_shards = dim_expanded_shards;
  }
  std::sort(total_expanded_shards.begin(), total_expanded_shards.end());
  ShardVector expanded_shard_vec;
  expanded_shard_vec.shards = total_expanded_shards;
  expanded_shard_vec.num_shards_per_dim = new_num_shards_per_dim;
  return expanded_shard_vec;
}
}  // namespace

bool ShardVector::operator==(const ShardVector& other) const {
  // Check same number of shards.
  if (this->shards.empty() && other.shards.empty()) return true;
  if (this->shards.empty() || other.shards.empty()) return false;

  // Check number of shard dimensions match.
  if (this->num_shards_per_dim.size() != other.num_shards_per_dim.size())
    return false;

  // Compute lowest common multiple for each of the shard dimensions.
  Shard first_shard_this = this->shards[0];
  Shard first_shard_other = other.shards[0];
  std::vector<int> new_sizes;
  for (size_t i = 0; i < first_shard_this.size(); ++i) {
    int lcm = this->num_shards_per_dim[i] * other.num_shards_per_dim[i] /
              MathUtil::GCD(static_cast<unsigned>(this->num_shards_per_dim[i]),
                            static_cast<unsigned>(other.num_shards_per_dim[i]));
    new_sizes.push_back(lcm);
  }
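  // e.g. if one vector has 2 shards along a dimension and the other has 3,
  // both are expanded to lcm(2, 3) = 6 shards before being compared.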

  // Expand and compare.
  return ExpandShardVector(*this, new_sizes).shards ==
         ExpandShardVector(other, new_sizes).shards;
}

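// Produces a string such as (illustrative example):
//   shards:[(1,1),(1,2)] num_shards_per_dim:(1,2)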
std::string ShardVector::ToString() const {
  std::string string = "shards:[";
  // Convert each Shard into string.
  std::vector<std::string> shard_strs;
  shard_strs.reserve(shards.size());
  for (const Shard& shard : shards)
    shard_strs.push_back("(" + absl::StrJoin(shard, ",") + ")");
  // Join shards, and append dimensions.
  absl::StrAppend(&string, absl::StrJoin(shard_strs, ","));
  absl::StrAppend(&string, "] num_shards_per_dim:(");
  absl::StrAppend(&string, absl::StrJoin(num_shards_per_dim, ",") + ")");
  return string;
}

bool ShardVector::ContainsShard(const Shard& shard) const {
  for (const auto& shard_in_vec : shards)
    if (shard_in_vec == shard) return true;
  return false;
}

// static
std::map<std::string, std::vector<int>>& Mesh::tpu_core_ids() {
  static auto tpu_core_ids = new std::map<std::string, std::vector<int>>();
  return *tpu_core_ids;
}

// static
std::string& Mesh::tpu_host_mesh() {
  static auto tpu_host_mesh = new std::string;
  return *tpu_host_mesh;
}

// static
StatusOr<Mesh> Mesh::ParseFromProto(const MeshProto& proto) {
  Mesh mesh;
  mesh.name_ = proto.name();

  for (const auto& device : proto.local_devices()) {
    mesh.local_devices_.push_back(device);
  }

  // Define local device ids.
  for (const auto& device_id : proto.local_device_ids()) {
    mesh.local_device_ids_.push_back(device_id);
  }

  for (const auto& device_id : proto.global_device_ids()) {
    mesh.global_device_ids_.push_back(device_id);
  }

  for (const auto& device : proto.global_devices()) {
    mesh.global_devices_.push_back(device);
  }

  // Assign Mesh Dimensions.
  mesh.mesh_dims_.resize(proto.mesh_dimensions_size());
  for (int i = 0; i < proto.mesh_dimensions_size(); ++i) {
    const MeshDimensionProto& dim = proto.mesh_dimensions(i);
    mesh.mesh_dims_[i].name = dim.name();
    mesh.mesh_dims_[i].size = dim.size();
  }

  // Check invariants.
  int64 mesh_size = mesh.size();
  int num_devices = proto.global_device_ids_size();
  if (mesh_size > 0 && mesh_size != num_devices) {
    TF_RETURN_WITH_CONTEXT(
        errors::InvalidArgument("Number of devices ", num_devices,
                                " not matching mesh size ", mesh_size));
  }
  return mesh;
}

// static
StatusOr<Mesh> Mesh::GetAbstractMesh(
    const std::string& name, const std::vector<MeshDimension>& mesh_dims) {
  Mesh mesh;
  mesh.name_ = name;
  mesh.mesh_dims_ = mesh_dims;

  // Check no repeated mesh dimension names.
  std::set<std::string> dims_set;
  for (const MeshDimension& dim : mesh.dims()) {
    if (dims_set.find(dim.name) != dims_set.end())
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("repeated mesh dimension"));
    if (dim.name == Layout::kAny || dim.name == Layout::kMatch ||
        dim.name == Layout::kUnshardedDim)
      TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh dimension name ",
                                                     dim.name, " is reserved"));
    dims_set.insert(dim.name);
  }

  return mesh;
}

// static
StatusOr<Mesh> Mesh::GetMesh(const std::string& name,
                             const std::vector<MeshDimension>& mesh_dims,
                             const std::vector<std::int64_t>& global_device_ids,
                             const std::vector<std::int64_t>& local_device_ids,
                             const std::vector<std::string>& local_devices,
                             const std::vector<std::string>& global_devices) {
  TF_ASSIGN_OR_RETURN(Mesh mesh, GetAbstractMesh(name, mesh_dims));
  mesh.global_device_ids_ = global_device_ids;
  mesh.local_device_ids_ = local_device_ids;
  mesh.local_devices_ = local_devices;
  mesh.global_devices_ = global_devices;

  // Check number of devices matches conditions.
  size_t global_n = mesh.global_device_ids_.size();
  size_t local_n = mesh.local_device_ids_.size();
  size_t dev_n = mesh.local_devices_.size();

  if (!(global_n >= local_n && dev_n == local_n))
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "number of global_device_ids ", std::to_string(global_n),
        " local_devices ids ", std::to_string(local_n), " and local devices ",
        std::to_string(dev_n), " not meeting requirements"));

  // If empty device list, return empty mesh.
  if (global_n == 0) return Mesh::Empty();

  if (local_n && !(global_n % local_n == 0))
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Uneven local clusters with global_ids ", std::to_string(global_n),
        " and local_devices ids ", std::to_string(local_n)));

  // Check mesh size matches number of devices.
  if (mesh.size() != global_n)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "mesh size doesn't match number of devices"));

  // Check local device invariants.
  TF_ASSIGN_OR_RETURN(const auto& parsed_devs, mesh.ParsedDevices());
  std::set<std::string> types_set;
  for (const DeviceNameUtils::ParsedName& dev : parsed_devs) {
    if (!dev.has_job || !dev.has_task || !dev.has_type)
      return errors::InvalidArgument(
          "Failed to either identify host or device type");
    types_set.insert(dev.type);
    if (types_set.size() > 1)
      return errors::InvalidArgument(
          "More than one device type per mesh not supported. Found ",
          types_set.size());
  }

  return mesh;
}

StatusOr<int64_t> Mesh::dim_size(absl::string_view name) const {
  for (const auto& mesh_dim : dims()) {
    if (name == mesh_dim.name) {
      return mesh_dim.size;
    }
  }

  std::vector<std::string> dim_names;
  for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name);

  return errors::NotFound(
      "Dimension ", name, " does not exist in mesh. ",
      "Available dimensions: ", absl::StrJoin(dim_names, ","));
}

std::vector<int64_t> Mesh::dim_sizes() const {
  std::vector<int64_t> dim_sizes;
  if (mesh_dims_.empty()) return dim_sizes;
  for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size);
  return dim_sizes;
}

bool Mesh::operator==(const Mesh& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}

bool Mesh::IsEmpty() const { return global_device_ids_.empty(); }

StatusOr<const std::vector<DeviceNameUtils::ParsedName>> Mesh::ParsedDevices()
    const {
  std::vector<DeviceNameUtils::ParsedName> parsed_devices(
      local_devices_.size());
  for (std::size_t i = 0; i < local_devices_.size(); ++i)
    if (!DeviceNameUtils::ParseFullOrLocalName(
            absl::string_view(local_devices_[i]), &parsed_devices[i]))
      return errors::InvalidArgument("Failed to parse local_devices");

  return parsed_devices;
}

StatusOr<Mesh> Mesh::ToDeviceType(const std::string& device_type) const {
  std::vector<std::string> to_local_devices;
  DeviceNameUtils::ParsedName parsed_dev;
  for (const std::string& local_dev : local_devices_) {
    if (!DeviceNameUtils::ParseFullOrLocalName(absl::string_view(local_dev),
                                               &parsed_dev)) {
      return errors::InvalidArgument("Failed to parse local devices");
    }
    // Converted mesh using full task name with job, replica and task ids.
    to_local_devices.push_back(
        DeviceNameUtils::FullName(parsed_dev.job, parsed_dev.replica,
                                  parsed_dev.task, device_type, parsed_dev.id));
    parsed_dev.Clear();
  }
  return GetMesh(name_, mesh_dims_, global_device_ids_, local_device_ids_,
                 to_local_devices, /*global_devices=*/{});
}

namespace {
std::string HostFromParsedDev(const DeviceNameUtils::ParsedName& dev) {
  return "/job:" + dev.job + "/task:" + std::to_string(dev.task);
}
}  // namespace

std::vector<std::string> Mesh::hosts() const {
  std::vector<std::string> host_list;
  if (IsEmpty()) return host_list;

  const auto parsed_devices = ParsedDevices().value();
  for (const DeviceNameUtils::ParsedName& dev : parsed_devices) {
    std::string host = HostFromParsedDev(dev);
    if (std::find(host_list.begin(), host_list.end(), host) == host_list.end())
      host_list.push_back(host);
  }
  return host_list;
}

std::string Mesh::device_type() const {
  if (IsEmpty()) return std::string();
  std::string device;
  if (!global_devices_.empty()) {
    device = global_devices_[0];
  } else {
    device = local_devices_[0];
  }
  DeviceNameUtils::ParsedName dev;
  DeviceNameUtils::ParseFullOrLocalName(device, &dev);
  return dev.type;
}

bool Mesh::IsMeshDim(const std::string& dim_name) const {
  for (const auto& mesh_dim : dims())
    if (dim_name == mesh_dim.name) return true;
  return false;
}

int Mesh::GetMeshDimIndexWithName(const std::string& mesh_name) const {
  int mesh_index = -1;
  for (int i = 0; i < dims().size(); ++i) {
    const auto mesh_dim = dim(i);
    if (mesh_dim.name == mesh_name) mesh_index = i;
  }
  assert(mesh_index >= 0);
  return mesh_index;
}

int64 Mesh::rank() const { return mesh_dims_.size(); }

int64 Mesh::size() const {
  if (mesh_dims_.empty()) return 0;

  int64 size = 1;
  for (const MeshDimension& dim : mesh_dims_) size *= dim.size;
  return size;
}

Mesh Mesh::Empty() { return Mesh(); }

MeshProto Mesh::ToProto() const {
  MeshProto mesh_proto;
  mesh_proto.set_name(name());

  for (const auto& d : local_devices_) {
    mesh_proto.add_local_devices(d);
  }

  for (const auto& i : local_device_ids_) {
    mesh_proto.add_local_device_ids(i);
  }

  for (const auto& i : global_device_ids_) {
    mesh_proto.add_global_device_ids(i);
  }

  for (const auto& dim : mesh_dims_) {
    MeshDimensionProto* mesh_dim_proto = mesh_proto.add_mesh_dimensions();
    mesh_dim_proto->set_name(dim.name);
    mesh_dim_proto->set_size(dim.size);
  }

  for (const auto& d : global_devices_) {
    mesh_proto.add_global_devices(d);
  }
  return mesh_proto;
}

std::string Mesh::ToString() const {
  if (Mesh::IsEmpty()) return kEmptyMeshString;

  // We use "|" to separate name, mesh dimensions and devices.
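  // The resulting format is roughly:
  //   <name>|<dim>=<size>,...|<global device ids>|<local device ids>|
  //   <local devices>[|<global devices>]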
  std::string mesh_str = absl::StrCat(Mesh::name(), "|");

  // Add mesh dimensions
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  // Add flattened list of local device ids
  mesh_str += absl::StrJoin(local_device_ids_, ",") + "|";

  // Add flattened list of local devices
  mesh_str += absl::StrJoin(local_devices_, ",");

  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  return mesh_str;
}

uint64 Mesh::GlobalFingerprint() const {
  if (Mesh::IsEmpty()) return Fingerprint64(kEmptyMeshString);

  std::string mesh_str;
  // Add mesh dimensions
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Ignore local_device_ids_, local_devices and name, which might not be
  // globally unique.
  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  // mesh dims | global device ids (| global devices)
  return Fingerprint64(mesh_str);
}

namespace {
MeshDimension StrToMeshDimension(const std::string& str) {
  MeshDimension mesh_dim;
  if (str.empty()) return mesh_dim;

  std::vector<std::string> mesh_dim_parts = absl::StrSplit(str, '=');

  mesh_dim.name = mesh_dim_parts[0];
  mesh_dim.size = std::stoi(mesh_dim_parts[1]);
  return mesh_dim;
}

StatusOr<Mesh> GenerateMeshDevicesForTests(
    const std::string& name, const std::vector<MeshDimension>& mesh_dims,
    const std::string& mesh_gen_instruction) {
  // Parse mesh generation instruction.
  std::vector<std::string> instruction_parts =
      absl::StrSplit(mesh_gen_instruction, '*');
  if (instruction_parts.size() != 2)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected a * in mesh_gen_instructions but found ",
        mesh_gen_instruction));
  std::string device_type = instruction_parts[1];

  // Get Mesh Size.
  int64 mesh_size = 0;
  if (!mesh_dims.empty()) {
    mesh_size = 1;
    for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size;
  }

  // Generate device ids.
  std::vector<int64_t> global_device_ids;
  std::vector<int64_t> local_device_ids;
  std::vector<std::string> local_devices;
  for (std::size_t i = 0; i < mesh_size; ++i) {
    global_device_ids.push_back(i);
    local_device_ids.push_back(i);
    local_devices.push_back("/job:localhost/task:0/device:" + device_type +
                            ":" + std::to_string(i));
  }

  TF_ASSIGN_OR_RETURN(
      Mesh mesh,
      Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
                    local_devices, /*global_devices=*/{}));
  return mesh;
}
}  // namespace

// static
StatusOr<Mesh> Mesh::FromString(const std::string& str) {
  if (str == kEmptyMeshString) return Mesh::Empty();

  std::vector<std::string> mesh_parts = absl::StrSplit(str, '|');

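  // A 3-part string (e.g. "name|x=2,y=2|*CPU", illustrative) asks for the
  // devices to be autogenerated for tests; 5- and 6-part strings list device
  // ids and devices explicitly, with the optional sixth part holding global
  // devices.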
  // Check formatting error.
  if (mesh_parts.size() != 3 && mesh_parts.size() != 5 &&
      mesh_parts.size() != 6)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected either 3, 5 or 6 mesh parts but found ", mesh_parts.size()));

  // Populate mesh.
  std::string name = mesh_parts[0];

  // Add mesh dimensions.
  std::vector<MeshDimension> mesh_dims;
  if (!mesh_parts[1].empty()) {
    std::vector<std::string> mesh_dim_strs =
        absl::StrSplit(mesh_parts[1], ',');
    mesh_dims.reserve(mesh_dim_strs.size());
    for (const std::string& mesh_dim_str : mesh_dim_strs)
      mesh_dims.push_back(StrToMeshDimension(mesh_dim_str));
  }

  // Check if mesh is set to be autogenerated.
  if (mesh_parts.size() == 3)
    return GenerateMeshDevicesForTests(name, mesh_dims, mesh_parts[2]);

  // Add global device ids list.
  std::vector<int64_t> global_device_ids;
  if (!mesh_parts[2].empty()) {
    std::vector<std::string> global_device_ids_strs =
        absl::StrSplit(mesh_parts[2], ',');

    global_device_ids.reserve(global_device_ids_strs.size());
    for (const std::string& id : global_device_ids_strs)
      global_device_ids.push_back(std::stoi(id));
  }

  // Add local device ids list.
  std::vector<int64_t> local_device_ids;
  if (!mesh_parts[3].empty()) {
    std::vector<std::string> local_device_ids_strs =
        absl::StrSplit(mesh_parts[3], ',');

    local_device_ids.reserve(local_device_ids_strs.size());
    for (const std::string& id : local_device_ids_strs)
      local_device_ids.push_back(std::stoi(id));
  }
  // Add local devices.
  std::vector<std::string> local_devices;
  if (!mesh_parts[4].empty())
    local_devices = absl::StrSplit(mesh_parts[4], ',');

  std::vector<std::string> global_devices;
  if (mesh_parts.size() == 6) {
    // Add global devices.
    if (!mesh_parts[5].empty())
      global_devices = absl::StrSplit(mesh_parts[5], ',');
  }

  TF_ASSIGN_OR_RETURN(
      Mesh mesh,
      Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
                    local_devices, global_devices));
  return mesh;
}

int64 Mesh::num_devices() const { return global_device_ids_.size(); }

StatusOr<const DeviceLocation> Mesh::device_location(int offset) const {
  if (offset < 0 || offset > size() - 1)
    return errors::InvalidArgument(
        "Mesh offset cannot be negative or exceed Mesh's size. Offset: ",
        offset, " and Mesh size: ", size());

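  // Unravels the flat offset into per-dimension coordinates (row-major),
  // e.g. for a mesh [x=2, y=3], offset 4 maps to location (1, 1).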
  DeviceLocation dev_loc;
  std::vector<int64> mesh_dim_lengths = dim_sizes();
  int64 i = mesh_dim_lengths.size() - 1;
  while (i >= 0) {
    dev_loc.insert(dev_loc.begin(), offset % mesh_dim_lengths[i]);
    offset /= mesh_dim_lengths[i];
    --i;
  }
  return dev_loc;
}

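// Inverse of device_location(): for the same mesh [x=2, y=3], location (1, 1)
// flattens back to offset 1 * 3 + 1 = 4 (row-major order).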
int64 Mesh::GetFlattenedCoordinate(const DeviceLocation& loc) const {
  const std::vector<int64> mesh_dim_sizes = dim_sizes();
  int64 i = mesh_dim_sizes.size() - 1;
  int64 acc = 1;
  int64 device_pos = 0;
  while (i >= 0) {
    device_pos += loc[i] * acc;
    acc *= mesh_dim_sizes[i];
    --i;
  }
  return device_pos;
}

StatusOr<int32> Mesh::idx_for_dim(absl::string_view dim_name) const {
  for (int i = 0; i < mesh_dims_.size(); ++i) {
    if (mesh_dims_[i].name == dim_name) return i;
  }
  return errors::InvalidArgument("dim name :", dim_name,
                                 " does not exist on mesh : ", ToString());
}

StatusOr<Layout> Layout::GetLayout(
    const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh) {
  // Re-format sharding specs.
  std::vector<ShardingSpec> sharding_specs;
  sharding_specs.reserve(sharding_spec_strs.size());
  for (const std::string& spec_str : sharding_spec_strs) {
    ShardingSpec spec;
    spec.set_sharding_spec(spec_str);
    sharding_specs.push_back(spec);
  }
  return GetLayout(sharding_specs, mesh);
}

StatusOr<Layout> Layout::GetLayout(
    const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh) {
  Layout layout;
  // Append mesh, then check sharding_specs are legal.
  layout.mesh_ = mesh;

  // Check sharding_specs are either mesh dimension or special value.
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny ||
          sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) ||
          sharding_spec == "scalar"))
      TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
          "sharding spec (", sharding_spec,
          ") refers to mesh dimension not contained in mesh ",
          mesh.ToString()));
  }
  // Check that no two tensor dimensions are sharded over the same mesh
  // dimension.
  std::set<std::string> dims_set;
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue;
    // If scalar, delete all sharding specs.
    if (sharding_spec == "scalar") {
      if (sharding_specs.size() > 1)
        TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
            "A scalar sharding_spec can only be used as a single sharding_spec "
            "instruction, not as part of a list of sharding_specs as attempted "
            "here with ",
            sharding_specs.size(), " sharding_specs"))
      // Return layout with empty spec to represent scalar behavior.
      return layout;
    }
    if (dims_set.find(sharding_spec) != dims_set.end())
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("Attempted to shard two or more tensor "
                                  "dimensions over mesh dimension ",
                                  sharding_spec))
    dims_set.insert(sharding_spec);
  }
  // After checking sharding_specs are legal, append and return layout.
  layout.sharding_specs_ = sharding_specs;
  return layout;
}

Layout Layout::Empty() {
  Layout result;
  return result;
}

bool Layout::IsEmpty() const { return mesh_.IsEmpty(); }

namespace {
Mesh ReducedAbstractMesh(const Layout* layout) {
  const std::vector<std::string>& shard_spec_strs =
      layout->sharding_spec_strs();
  std::vector<MeshDimension> reduced_mesh_dims;
  reduced_mesh_dims.reserve(layout->mesh().dims().size());
  for (const MeshDimension& mesh_dim : layout->mesh().dims()) {
    bool IsMeshDimInShardingSpecs =
        std::find(shard_spec_strs.begin(), shard_spec_strs.end(),
                  mesh_dim.name) != shard_spec_strs.end();
    // If dimension not in sharding_spec, flip size to 1.
    MeshDimension reduced_dim =
        IsMeshDimInShardingSpecs ? mesh_dim : MeshDimension(mesh_dim.name, 1);
    reduced_mesh_dims.push_back(reduced_dim);
  }
  return Mesh::GetAbstractMesh("", reduced_mesh_dims).value();
}

}  // namespace

Mesh Layout::ReducedMesh() const {
  // Set replicated mesh dimensions to size 1, and create reduced abstract
  // mesh.
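  // e.g. a layout with sharding_specs [x, unsharded] on a mesh [x=2, y=2]
  // reduces to a mesh [x=2, y=1] containing only the devices that hold
  // distinct shards.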
  Mesh reduced_mesh = ReducedAbstractMesh(this);

  // Populate reduced mesh with global devices from original mesh.
  std::vector<int64_t> reduced_global_device_ids;
  std::vector<std::string> reduced_global_devs;
  for (const DeviceLocation& loc : ComputeDeviceLocations(&reduced_mesh)) {
    int64 pos = mesh().GetFlattenedCoordinate(loc);
    reduced_global_device_ids.push_back(mesh().global_device_ids().at(pos));
    if (!mesh().global_devices().empty()) {
      reduced_global_devs.push_back(mesh().global_devices().at(pos));
    }
  }

  // Track the set of global device IDs in the abstract mesh.
  std::set<int64_t> reduced_global_device_ids_set(
      reduced_global_device_ids.begin(), reduced_global_device_ids.end());

  // Populate reduced mesh with local devices in the same order as the original
  // mesh.
  std::vector<int64_t> reduced_local_device_ids;
  std::vector<std::string> reduced_local_devs;
  for (size_t i = 0; i < mesh().local_device_ids().size(); ++i) {
    int64_t device_id = mesh().local_device_ids().at(i);
    if (reduced_global_device_ids_set.find(device_id) !=
        reduced_global_device_ids_set.end()) {
      reduced_local_device_ids.push_back(device_id);
      reduced_local_devs.push_back(mesh().local_devices().at(i));
    }
  }

  return Mesh::GetMesh(reduced_mesh.name(), reduced_mesh.dims(),
                       reduced_global_device_ids, reduced_local_device_ids,
                       reduced_local_devs, reduced_global_devs)
      .value();
}

namespace {
Layout ReducedLayout(const Layout* layout) {
  // Change format sharding specs.
  std::vector<ShardingSpec> shard_specs(layout->sharding_specs().size());
  for (size_t i = 0; i < shard_specs.size(); ++i)
    shard_specs[i] = layout->dim(i);
  // Retrieve layout.
  return Layout::GetLayout(shard_specs, layout->ReducedMesh()).value();
}

// Returns the index of the given mesh dimension, or an error if the dimension
// is not in the mesh.
StatusOr<int> IndexOfMeshDimension(const Mesh& mesh,
                                   const std::string& dim_name) {
  for (size_t i = 0; i < mesh.dims().size(); ++i)
    if (dim_name == mesh.dims()[i].name) return i;
  return errors::InvalidArgument("Mesh dimension not found");
}
}  // namespace

ShardVector Layout::GetShardVector() const {
  // Change format sharding specs.
  std::vector<ShardingSpec> shard_specs(sharding_specs().size());
  for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i);
  // Obtain a shard position (i.e. sharded section of a tensor) from a mesh
  // location, using the sharding specs.
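  // e.g. on a mesh [x=2, y=2] with sharding_specs [x, unsharded], the device
  // at location (1, 0) maps to shard (2, 1); shard coordinates are 1-based.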
  auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard {
    Shard shard;
    for (size_t i = 0; i < shard_specs.size(); ++i) {
      // If unsharded, there is only one shard, that is 1.
      std::string spec = shard_specs[i].sharding_spec();
      if (spec == Layout::kUnshardedDim) {
        shard.push_back(1);
      } else {
        int mesh_index =
            IndexOfMeshDimension(mesh(), sharding_spec(i)).value();
        int shard_number = loc[mesh_index] + 1;
        shard.push_back(shard_number);
      }
    }
    return shard;
  };
  // Obtain dims of shard vector.
  auto ShardVectorDims = [&]() -> std::vector<int> {
    std::vector<int> num_shards_per_dim(shard_specs.size());
    for (size_t i = 0; i < sharding_specs().size(); ++i) {
      ShardingSpec spec = sharding_specs()[i];
      if (Layout::IsShardedSpec(spec)) {
        StatusOr<int64> dim_size = mesh().dim_size(spec.sharding_spec());
        num_shards_per_dim[i] = dim_size.value();
      } else {
        num_shards_per_dim[i] = 1;
      }
    }
    return num_shards_per_dim;
  };
  // Compute mesh locations and obtain shards from them.
  ShardVector shard_vec;
  for (const DeviceLocation& mesh_loc : ComputeDeviceLocations(&mesh()))
    shard_vec.shards.push_back(GetShardFromDeviceLocation(mesh_loc));
  // Calculate dims.
  shard_vec.num_shards_per_dim = ShardVectorDims();
  return shard_vec;
}

std::map<std::string, ShardVector> Layout::HostShardMap() const {
  Layout reduced_layout = ReducedLayout(this);
  Mesh reduced_mesh = reduced_layout.mesh();
  using HostName = std::string;

  // Build a map: {Host : Shards}
  std::map<HostName, ShardVector> host_shards_map;
  ShardVector shard_vec_in_red_layout = reduced_layout.GetShardVector();

  const auto parsed_devs = reduced_mesh.ParsedDevices().value();
  for (size_t i = 0; i < parsed_devs.size(); ++i) {
    HostName host = HostFromParsedDev(parsed_devs[i]);
    Shard shard_in_device = shard_vec_in_red_layout.shards[i];

    // Check if host in hashtable and append shard.
    auto it = host_shards_map.find(host);
    if (it == host_shards_map.end()) {
      ShardVector shard_vec_in_host;
      shard_vec_in_host.shards.push_back(shard_in_device);
      shard_vec_in_host.num_shards_per_dim =
          shard_vec_in_red_layout.num_shards_per_dim;
      host_shards_map.insert(
          std::pair<HostName, ShardVector>(host, shard_vec_in_host));
    } else {
      bool isShardInShardVector = it->second.ContainsShard(shard_in_device);
      if (!isShardInShardVector) {
        it->second.shards.push_back(shard_in_device);
      }
    }
  }
  // Sort shards inside each host.
  for (auto it = host_shards_map.begin(); it != host_shards_map.end(); ++it) {
    std::sort(it->second.shards.begin(), it->second.shards.end());
  }
  return host_shards_map;
}

const std::string& Layout::sharding_spec(int idx) const {
  return sharding_specs_[idx].sharding_spec();
}

std::vector<int32> Layout::num_shards() const {
  std::vector<int32> num_shards;
  num_shards.reserve(sharding_specs_.size());
  for (const auto& sharding_spec : sharding_specs_) {
    num_shards.push_back(num_shards_for_dim(sharding_spec));
  }
  return num_shards;
}

size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const {
  absl::string_view name = dim.sharding_spec();
  if (name == Layout::kUnshardedDim) return 1;
  if (name == Layout::kMatch) return -1;

  return mesh().dim_size(name).value();
}

bool Layout::IsFullyReplicated() const {
  for (const auto& sharding_spec : sharding_specs_) {
    if (num_shards_for_dim(sharding_spec) > 1) {
      return false;
    }
  }
  return true;
}

bool Layout::IsLastDimReplicated() const {
  return (sharding_specs_.empty()) ||
         (num_shards_for_dim(sharding_specs_.back()) == 1);
}

bool Layout::IsBatchParallel() const {
  if (sharding_specs_.empty()) {
    return true;
  }

  for (int i = 1; i < sharding_specs_.size(); ++i) {
    const auto& dim = sharding_specs_[i];
    if (num_shards_for_dim(dim) != 1) {
      return false;
    }
  }
  return true;
}

// TODO(samuelslee) Replace this with IsBatchParallel() everywhere.
bool Layout::IsBatchParallel(int non_batch_rank) const {
  if (sharding_specs_.empty()) return true;
  for (int i = rank() - non_batch_rank; i < rank(); ++i) {
    if (num_shards_for_dim(sharding_specs_[i]) != 1) return false;
  }
  return true;
}

LayoutProto Layout::ToProto() const {
  LayoutProto proto;
  *proto.mutable_mesh_config() = mesh_.ToProto();
  for (const auto& dim : sharding_specs_) {
    *proto.add_sharding_specs() = dim;
  }
  return proto;
}

bool Layout::IsEquivalent(const Layout& b) const {
  if (this->rank() != b.rank()) return false;
  if (this->mesh() != b.mesh()) return false;
  for (int i = 0; i < this->rank(); ++i) {
    if (this->sharding_specs_[i].sharding_spec() !=
        b.sharding_specs_[i].sharding_spec()) {
      if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) ||
          (b.num_shards_for_dim(b.sharding_specs_[i]) != 1))
        return false;
    }
  }
  return true;
}

bool Layout::operator==(const Layout& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}

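// e.g. with num_shards() == [2, 1], a local shape [3, 4] corresponds to a
// global shape [6, 4]; each dimension is simply scaled by its shard count.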
std::vector<int64_t> Layout::GlobalShapeFromLocalShape(
    const std::vector<int64_t>& local_shape) const {
  if (IsFullyReplicated()) {
    return local_shape;
  }
  std::vector<int64_t> global_shape;
  global_shape.reserve(sharding_specs().size());
  for (int i = 0; i < sharding_specs().size(); ++i) {
    int64_t l_shape = local_shape.empty() ? 1 : local_shape[i];
    int64_t dim_shards = num_shards()[i];
    global_shape.emplace_back(l_shape * dim_shards);
  }
  return global_shape;
}

std::vector<int64_t> Layout::LocalShapeFromGlobalShape(
    absl::Span<const int64_t> global_shape) const {
  if (IsFullyReplicated()) {
    return std::vector<int64_t>(global_shape.begin(), global_shape.end());
  }
  std::vector<int32> shards = num_shards();
  std::vector<int64_t> local_shape;
  for (int i = 0; i < sharding_specs().size(); ++i) {
    int64_t dim_shards = shards[i];
    // TODO(hthu): Shape might not be always divisible.
    local_shape.emplace_back(global_shape[i] / dim_shards);
  }
  return local_shape;
}

PartialTensorShape Layout::LocalShapeFromGlobalShape(
    const PartialTensorShape& global_shape) const {
  if (IsFullyReplicated() || global_shape.dims() == -1) {
    return global_shape;
  }
  std::vector<int32> shards = num_shards();
  PartialTensorShape local_shape({});
  for (int spec_index = 0; spec_index < sharding_specs().size();
       ++spec_index) {
    int64_t dim_size = global_shape.dim_size(spec_index);
    local_shape.AddDim(dim_size == -1 ? -1 : dim_size / shards[spec_index]);
  }
  return local_shape;
}

StatusOr<Layout> Layout::FromProto(const LayoutProto& proto) {
  Layout layout;
  for (const auto& spec : proto.sharding_specs())
    layout.sharding_specs_.push_back(spec);

  TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config()));
  layout.mesh_ = std::move(mesh);

  return GetLayout(layout.sharding_specs_, layout.mesh_);
}

Layout Layout::ReplicatedOnMesh(const Mesh& mesh, int rank) {
  std::vector<std::string> specs(rank, kUnshardedDim);
  return Layout::GetLayout(specs, mesh).value();
}

Layout Layout::AnyOnMesh(const Mesh& mesh, int rank) {
  std::vector<std::string> specs(rank, kAny);
  return Layout::GetLayout(specs, mesh).value();
}

StatusOr<Layout> Layout::Transposed2D(const Layout& layout) {
  if (layout.rank() < 2) {
    return errors::InvalidArgument("Transposed2D requires rank to be >= 2");
  }
  std::vector<std::string> transposed_specs = layout.sharding_spec_strs();
  std::iter_swap(transposed_specs.end() - 2, transposed_specs.end() - 1);
  return Layout::GetLayout(transposed_specs, layout.mesh()).value();
}

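// Parses strings of the (roughly sketched) form
//   "sharding_specs:x,unsharded, mesh:<serialized mesh>"
// i.e. the format produced by Layout::ToString(), where every sharding spec
// is followed by a ',' (hence the trailing empty token dropped below).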
// static
StatusOr<Layout> Layout::FromString(std::string layout_str) {
  if (layout_str == kEmptyLayoutString) return Layout::Empty();

  // Split into sharding spec and mesh parts.
  std::vector<absl::string_view> layout_parts =
      absl::StrSplit(layout_str, ' ');
  // Check formatting error.
  if (layout_parts.size() != 2) {
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected 2 items but found ", layout_parts.size(), ": ",
        layout_parts[0]));
  }
  // Strip prefixes.
  absl::string_view sharding_spec_str = layout_parts[0];
  absl::ConsumePrefix(&sharding_spec_str, "sharding_specs:");

  absl::string_view mesh_str = layout_parts[1];
  absl::ConsumePrefix(&mesh_str, "mesh:");

  // Add sharding specs.
  std::vector<std::string> sharding_spec_strs =
      absl::StrSplit(sharding_spec_str, ',');
  sharding_spec_strs.pop_back();

  // Add mesh.
  TF_ASSIGN_OR_RETURN(Mesh mesh, Mesh::FromString(string(mesh_str)));
  // Try to create layout.
  TF_ASSIGN_OR_RETURN(Layout layout,
                      Layout::GetLayout(sharding_spec_strs, mesh));
  return layout;
}

std::vector<std::string> Layout::sharding_spec_strs() const {
  std::vector<std::string> sharding_spec_strs(sharding_specs().size());
  for (size_t i = 0; i < sharding_specs().size(); ++i)
    sharding_spec_strs[i] = sharding_spec(i);
  return sharding_spec_strs;
}

std::string Layout::ToString() const {
  if (Layout::IsEmpty()) return kEmptyLayoutString;

  std::string layout_str = "sharding_specs:";
  // Print sharding specs.
  for (const ShardingSpec& dim : sharding_specs_) {
    std::string dim_name = dim.sharding_spec();
    absl::StrAppend(&layout_str, dim_name + ",");
  }
  // Append mesh.
  absl::StrAppend(&layout_str, " mesh:", mesh_.ToString());
  return layout_str;
}

Layout Layout::GetLayoutWithReducedDims(
    const absl::flat_hash_set<int>& reduced_dims, bool keep_dims) const {
  dtensor::LayoutProto output_layout;
  *output_layout.mutable_mesh_config() = mesh().ToProto();

  for (int i = 0; i < rank(); ++i) {
    // reduced_dims may contain negative values.
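    // e.g. for a rank-3 layout, a reduced_dims entry of -1 refers to
    // dimension 2, which is why both i and i - rank() are checked.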
    if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) {
      *output_layout.add_sharding_specs() = dim(i);
    } else if (keep_dims) {
      auto* replicated_dim = output_layout.add_sharding_specs();
      replicated_dim->set_sharding_spec(kUnshardedDim);
    }
  }
  return Layout::FromProto(output_layout).value();
}

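// e.g. for sharding_specs [x, y, unsharded], Truncate(1, /*end=*/false) keeps
// [x] while Truncate(1, /*end=*/true) keeps [y, unsharded].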
Layout Layout::Truncate(int64 split_point, bool end) const {
  if ((split_point == 0 && end) || (split_point == rank() && !end))
    return *this;

  dtensor::LayoutProto output_layout;
  *output_layout.mutable_mesh_config() = mesh().ToProto();

  if (end) {
    for (int i = split_point; i < rank(); ++i)
      *output_layout.add_sharding_specs() = dim(i);
  } else {
    for (int i = 0; i < split_point; ++i)
      *output_layout.add_sharding_specs() = dim(i);
  }
  return Layout::FromProto(output_layout).value();
}

namespace {
// Adds unsharded sharding specs to layout.
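// e.g. padding a rank-2 layout [x, y] to rank 4 with is_padding_before = true
// yields [unsharded, unsharded, x, y].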
Layout PadLayout(const int64 rank, const bool is_padding_before,
                 const Layout& layout) {
  if (rank <= layout.rank()) return layout;

  // Create list of padding sharding specs.
  const int n = rank - layout.rank();
  std::vector<ShardingSpec> new_specs(n);
  for (int i = 0; i < n; ++i)
    new_specs[i].set_sharding_spec(Layout::kUnshardedDim);

  // Define concatenation point of layout specs.
  auto concat_point = is_padding_before ? new_specs.end() : new_specs.begin();

  // Concatenate old layout specs and new unsharded specs.
  new_specs.insert(concat_point, layout.sharding_specs().begin(),
                   layout.sharding_specs().end());
  return Layout::GetLayout(new_specs, layout.mesh()).value();
}
}  // namespace

Layout Layout::LeftPad(int64 rank) const {
  bool is_padding_before = true;
  return PadLayout(rank, is_padding_before, *this);
}

StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "unable to concatenate layouts as they are on different meshes.");

  absl::flat_hash_set<std::string> layout_a_mesh_dims;
  for (int i = 0; i < layout_a.rank(); ++i)
    if (layout_a.sharding_spec(i) != Layout::kUnshardedDim)
      layout_a_mesh_dims.emplace(layout_a.sharding_spec(i));

  for (int i = 0; i < layout_b.rank(); ++i)
    if (layout_b.sharding_spec(i) != Layout::kUnshardedDim &&
        layout_a_mesh_dims.contains(layout_b.sharding_spec(i)))
      return errors::InvalidArgument(
          "unable to concatenate layouts as they use the same mesh "
          "dimension: ",
          layout_b.sharding_spec(i), " is used in both layouts.");

  LayoutProto layout_proto_a = layout_a.ToProto();
  LayoutProto layout_proto_b = layout_b.ToProto();
  LayoutProto output_layout_proto;

  *output_layout_proto.mutable_mesh_config() = layout_proto_a.mesh_config();
  for (int i = 0; i < layout_proto_a.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_a.sharding_specs(i);
  for (int i = 0; i < layout_proto_b.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_b.sharding_specs(i);
  return Layout::FromProto(output_layout_proto);
}

}  // namespace dtensor
}  // namespace tensorflow