1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
17#define TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
18
19#include <ostream>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/container/flat_hash_map.h"
25#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
26#include "mlir/IR/MLIRContext.h" // from @llvm-project
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/logging.h"
29#include "tensorflow/core/platform/status.h"
30#include "tensorflow/core/platform/types.h"
31#include "tensorflow/dtensor/cc/dstatus.h"
32
33namespace tensorflow {
34namespace dtensor {
35
36// Arranges all replica IDs in a DTensor mesh in groups, used as an attribute
37// on collective operations.
38//
39// A group assignment has two views:
40//
41// - The global mesh view contains replica IDs from all participant TPU slices.
42// These replica IDs are identical to global device IDs in a DTensor mesh.
43// - The local slice view contains per-slice device IDs understood and used by
44// the TPU runtime on each slice. These device IDs are used to set replica
45// IDs on each slice.
46//
47// Some notable common cases:
48//
49// - In a single-slice case, `slice_size` is set to the actual slice size
50// (e.g., 32 for 4x4 DF). The global and local views are identical.
51// - In a special topology case, `slice_size` is set to 8.
52// - In a multi-topology case, `slice_size` is set to the size of a single
53// topology.
54// All topologies must have the same size.
55class GroupAssignment {
56 public:
57 using ReplicaId = int;
58
59 struct DeviceId {
60 public:
61 int slice_id;
62 int core_id; // within `slice_id`
63 };
64
65 // Maps global replica IDs to local device IDs consisting of a slice ID and a
66 // core-on-slice ID.
67 class ReplicaToDeviceMap {
68 public:
69 // Creates a default map that orders devices according to TF task IDs
70 // followed by device ordinals.
71 static ReplicaToDeviceMap DefaultReplicaToDeviceMap(int num_slices,
72 int slice_size);
73
74 // Constructs a map directly, checking it's valid.
75 explicit ReplicaToDeviceMap(absl::flat_hash_map<ReplicaId, DeviceId> map);
76
77 int num_slices() { return num_slices_; }
78 int num_cores() { return map_.size(); }
79 DeviceId device_id(ReplicaId replica_id) { return map_[replica_id]; }
80
81 private:
82 absl::flat_hash_map<ReplicaId, DeviceId> map_;
83 int num_slices_;
84 };
85
86 // Creates a group assignment by converting from an MLIR attribute.
87 static StatusOr<GroupAssignment> FromMLIR(
88 const mlir::DenseIntElementsAttr& group_assignment_attr,
89 ReplicaToDeviceMap replica_to_device_map);
90
91 // Creates an MLIR attribute using the global view.
92 mlir::DenseIntElementsAttr GlobalToMLIR(mlir::MLIRContext& context) const {
93 return global_.ToMLIR(context);
94 }
95
96 // Creates an MLIR attribute for a particular slice.
97 // Callers should make sure `slice_id` is >= 0 and < num_slices().
98 StatusOr<mlir::DenseIntElementsAttr> SliceToMLIR(mlir::MLIRContext& context,
99 int slice_id) const {
100 if (slice_id < 0 || slice_id >= num_slices())
101 return errors::InvalidArgument("slide_id was not within bounds.");
102 return slices_[slice_id].ToMLIR(context);
103 }
104
105 // Returns a string representation for debugging.
106 std::string ToString() const;
107
108 // Returns true if every group in the global view only has replica IDs from
109 // the same slice.
110 bool IsWithinSlices() const;
111
112 // Returns the number of slices in the local view.
113 int num_slices() const { return slices_.size(); }
114
115 // These methods return attributes of the global view.
116 int num_groups() const { return global_.num_groups(); }
117 int group_size() const { return global_.group_size(); }
118 int num_replica_ids() const { return global_.num_replica_ids(); }
119 const std::vector<std::vector<int>>& replica_ids() const {
120 return global_.replica_ids();
121 }
122
123 // These methods return attributes of a particular slice.
124 // Callers should make sure `slice_id` is >= 0 and < num_slices().
125 StatusOr<int> num_groups(int slice_id) const {
126 if (slice_id < 0 || slice_id >= num_slices())
127 return errors::InvalidArgument("slide_id was not within bounds.");
128 return slices_[slice_id].num_groups();
129 }
130 StatusOr<int> group_size(int slice_id) const {
131 if (slice_id < 0 || slice_id >= num_slices())
132 return errors::InvalidArgument("slide_id was not within bounds.");
133 return slices_[slice_id].group_size();
134 }
135 const std::vector<std::vector<int>>& replica_ids(int slice_id) const {
136 return slices_[slice_id].replica_ids();
137 }
138
139 // Returns the replica groups for collectives running on a particular host.
140 // Callers should make sure `slice_id` is >= 0 and < num_slices().
141 const std::vector<std::vector<int>>& host_replica_ids(int slice_id) const {
142 return hosts_[slice_id].replica_ids();
143 }
144
145 private:
146 // Groups of consecutive replica IDs starting at 0.
147 class ReplicaGroups {
148 public:
149 // Creates an object, enforcing the requirements on `replica_ids_`.
150 explicit ReplicaGroups(std::vector<std::vector<int>> replica_ids);
151
152 mlir::DenseIntElementsAttr ToMLIR(mlir::MLIRContext& context) const;
153
154 std::string ToString() const;
155
156 int num_groups() const { return replica_ids_.size(); }
157 int group_size() const { return replica_ids_.front().size(); }
158 int num_replica_ids() const { return num_groups() * group_size(); }
159 const std::vector<std::vector<int>>& replica_ids() const {
160 return replica_ids_;
161 }
162
163 private:
164 // N groups of replica IDs, N > 0. All groups have the same size G, G > 0.
165 // All replica IDs are distinct values >= 0;
166 std::vector<std::vector<int>> replica_ids_; // replica ID order matters
167 };
168
169 // Creates an object but leaves `slices_` empty. `GlobalToSlices` should be
170 // called next to fill in `slices_`.
171 explicit GroupAssignment(ReplicaGroups global,
172 ReplicaToDeviceMap replica_to_device_map)
173 : global_(std::move(global)),
174 replica_to_device_map_(std::move(replica_to_device_map)) {}
175
176 // Divides the global view along slice boundaries and fill in the slice view.
177 Status GlobalToSlices();
178
179 ReplicaGroups global_;
180 std::vector<ReplicaGroups> hosts_; // sorted by increasing slice ID
181 std::vector<ReplicaGroups> slices_; // sorted by increasing slice ID
182 ReplicaToDeviceMap replica_to_device_map_;
183};
184
185} // namespace dtensor
186} // namespace tensorflow
187
188#endif // TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
189