/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/collectives_common.h"

#include <map>
#include <string>
#include <vector>

namespace tensorflow {
namespace dtensor {

// A map from a unique set of kept mesh dimension values (a partition) to the
// IDs of devices in that partition.
//
// Users will typically ignore the key and use the map values as the group
// assignment for collective operations. This is intentionally a std::map
// instead of an absl::flat_hash_map: std::map guarantees a deterministic
// iteration order, so all hosts in a multi-host cluster generate the same
// grouping, and therefore the same XLA program fingerprint, independently.
using AllReducePartitions = std::map<DeviceLocation, std::vector<int32>>;

// Computes AllReduce partitions using reduced mesh dimension names.
//
// Reduction groups are formed across all _non_-reduced dimensions. For
// example, in the following scenario:
//
//   output_layout.dims() = [a, b]
//   output_layout.mesh() = [(x, 8), (y, 4)]
//   reduced_dims = {x}
//
// We first reduce over `a` locally on each device, producing 32 local
// reductions. We then AllReduce within each of the 4 partitions. Each
// partition corresponds to one unique value of `y` and has 8 devices. The end
// result is sharded over the `y` mesh dimension and replicated 8 times.
//
// The returned map has four entries, with keys [0] through [3] (the unique
// values of `y`). Each key maps to the IDs of the devices with that `y` value.
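//
// For illustration (assuming device IDs are assigned with the last mesh
// dimension varying fastest), the example above would produce the mapping
//
//   [0] -> {0, 4, 8, ..., 28}    [1] -> {1, 5, 9, ..., 29}
//   [2] -> {2, 6, 10, ..., 30}   [3] -> {3, 7, 11, ..., 31}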
StatusOr<AllReducePartitions> GetAllReducePartitionsFromReducedDims(
    const dtensor::Layout& output_layout,
    const absl::flat_hash_set<std::string>& reduced_dims) {
  AllReducePartitions partitions;
  for (int64 device = 0; device < output_layout.num_devices(); ++device) {
    TF_ASSIGN_OR_RETURN(const DeviceLocation device_loc,
                        output_layout.device_location(device));
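    // The device's coordinates along the non-reduced mesh dimensions form the
    // partition key; devices that share those coordinates belong to the same
    // AllReduce group.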
    DeviceLocation kept_dims;
    for (int64 dim_idx = 0; dim_idx < device_loc.size(); ++dim_idx) {
      if (!reduced_dims.contains(output_layout.mesh().dim_name(dim_idx))) {
        kept_dims.push_back(device_loc[dim_idx]);
      }
    }
    partitions[kept_dims].push_back(device);
  }
  return partitions;
}
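
// Illustrative usage (a sketch, not part of this file): for the example
// layout above, the AllReduce groups for a reduction over mesh dimension `x`
// could be obtained as
//
//   TF_ASSIGN_OR_RETURN(AllReducePartitions partitions,
//                       GetAllReducePartitionsFromReducedDims(layout, {"x"}));
//
// where each map value is one group of 8 device IDs.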

// Use the first device in the mesh to extract the device type. For example:
//
//   device_path = "/job:localhost/replica:0/task:0/device:TPU:0"
//   device_type = "/job:localhost/replica:0/task:0/device:TPU"
//   device_id = 0
//
// The device ID can be obtained through DeviceId as a runtime input. We may
// need it in the future to enable device ID-based branch divergence.
StatusOr<std::string> DeviceTypeFromMesh(const Mesh& mesh) {
  std::string device_path =
      mesh.is_remote() ? mesh.global_devices()[0] : mesh.local_devices()[0];
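  // Strip the trailing device index, i.e. everything after the last ':', to
  // leave only the device type.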
  size_t device_path_pos = device_path.find_last_of(':');
  if (device_path_pos == std::string::npos) {
    return errors::InvalidArgument("Unexpected device path: ", device_path);
  }
  return device_path.substr(0, device_path_pos);
}

}  // namespace dtensor
}  // namespace tensorflow