/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/group_assignment.h"

#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/dtensor/cc/dstatus.h"

namespace tensorflow {
namespace dtensor {

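// Builds the default replica-to-device mapping: replicas are assigned to
// devices in row-major order, so replica `i * slice_size + j` runs on core `j`
// of slice `i`. For example, with num_slices = 2 and slice_size = 4, replica 5
// maps to DeviceId{1, 1}.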
GroupAssignment::ReplicaToDeviceMap
GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(int num_slices,
                                                               int slice_size) {
  absl::flat_hash_map<ReplicaId, DeviceId> map;
  for (int i = 0; i < num_slices; ++i) {
    for (int j = 0; j < slice_size; ++j) {
      map[ReplicaId{i * slice_size + j}] = DeviceId{i, j};
    }
  }
  return ReplicaToDeviceMap(std::move(map));
}

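// Wraps an explicit replica-to-device mapping. The number of distinct slice
// IDs must be positive and the total number of replicas must be divisible by
// the number of slices; both conditions are CHECK-enforced.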
GroupAssignment::ReplicaToDeviceMap::ReplicaToDeviceMap(
    absl::flat_hash_map<ReplicaId, DeviceId> map)
    : map_(std::move(map)) {
  std::set<int> slice_ids;
  for (const auto& entry : map_) {
    slice_ids.insert(entry.second.slice_id);
  }
  CHECK_GT(slice_ids.size(), 0);                // Crash OK
  CHECK_EQ(map_.size() % slice_ids.size(), 0);  // Crash OK
  num_slices_ = slice_ids.size();
}

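// Validates a global list of replica groups: there must be at least one group,
// every group must have the same positive size, and every replica ID must be
// non-negative and appear at most once across all groups.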
GroupAssignment::ReplicaGroups::ReplicaGroups(
    std::vector<std::vector<int>> replica_ids)
    : replica_ids_(std::move(replica_ids)) {
  int n = replica_ids_.size();
  CHECK_GT(n, 0);  // Crash OK
  int g = replica_ids_.front().size();
  CHECK_GT(g, 0);  // Crash OK
  std::set<int> seen_replica_ids;
  for (std::vector<int>& group : replica_ids_) {
    CHECK_EQ(group.size(), g);  // Crash OK
    for (int replica_id : group) {
      CHECK_GE(replica_id, 0);  // Crash OK
      bool inserted = seen_replica_ids.insert(replica_id).second;
      CHECK(inserted);  // Crash OK
    }
  }
}

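// Lowers the replica groups to a rank-2 (num_groups x group_size) dense i32
// attribute, flattening the groups in row-major order.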
mlir::DenseIntElementsAttr GroupAssignment::ReplicaGroups::ToMLIR(
    mlir::MLIRContext& context) const {
  auto shaped_type = mlir::RankedTensorType::get(
      {num_groups(), group_size()}, mlir::IntegerType::get(&context, 32));

  llvm::SmallVector<int32, 4> flat_replica_ids;
  flat_replica_ids.reserve(num_replica_ids());
  for (const std::vector<int>& group : replica_ids()) {
    flat_replica_ids.insert(flat_replica_ids.end(), group.begin(), group.end());
  }

  return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids);
}

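// Prints the groups as a nested list, e.g. two groups of two replicas render
// as "[[0, 1], [2, 3]]".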
std::string GroupAssignment::ReplicaGroups::ToString() const {
  return strings::StrCat(
      "[",
      str_util::Join(replica_ids(), ", ",
                     [](std::string* str, const std::vector<int>& group) {
                       strings::StrAppend(str, "[", str_util::Join(group, ", "),
                                          "]");
                     }),
      "]");
}

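// Parses a group assignment from its MLIR attribute form. The attribute must
// be a rank-2 tensor of non-negative, distinct replica IDs that form a
// permutation of [0, num_replica_ids), with one replica ID per device in
// `replica_to_device_map`. On success, the global groups are also pre-split
// into per-host and per-slice views via `GlobalToSlices`.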
StatusOr<GroupAssignment> GroupAssignment::FromMLIR(
    const mlir::DenseIntElementsAttr& group_assignment_attr,
    ReplicaToDeviceMap replica_to_device_map) {
  mlir::ShapedType shaped_type = group_assignment_attr.getType();
  if (!shaped_type.hasRank()) {
    return errors::InvalidArgument("group_assignment_attr must have a rank");
  }
  if (shaped_type.getRank() != 2) {
    return errors::InvalidArgument(
        "group_assignment_attr must have a rank of 2, got ",
        shaped_type.getRank());
  }
  llvm::ArrayRef<int64_t> shape = shaped_type.getShape();
  int num_groups = shape[0];
  if (num_groups <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have at least 1 group, got ", num_groups);
  }
  int group_size = shape[1];
  if (group_size <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have non-empty groups, got ", group_size,
        " replica IDs per group");
  }
  int num_replica_ids = num_groups * group_size;
  if (num_replica_ids != replica_to_device_map.num_cores()) {
    return errors::InvalidArgument("group_assignment_attr must have ",
                                   replica_to_device_map.num_cores(),
                                   " replica IDs, got ", num_replica_ids);
  }

  // Translate the flat group assignment to a 2D array.
  std::vector<std::vector<int>> replica_ids;
  replica_ids.resize(num_groups, std::vector<int>(group_size));
  std::set<int> seen_replica_ids;
  if (group_assignment_attr.getNumElements() != num_replica_ids) {
    return errors::InvalidArgument(
        "group_assignment_attr must have exactly ", num_replica_ids,
        " elements, one per replica ID, got ",
        group_assignment_attr.getNumElements());
  }
  for (const auto& it :
       llvm::enumerate(group_assignment_attr.getValues<llvm::APInt>())) {
    int index = it.index();
    int replica_id = it.value().getSExtValue();

    // If all replica IDs are within this range and distinct, they must be a
    // permutation of [0, ..., num_replica_ids).
    if (replica_id < 0 || replica_id >= num_replica_ids) {
      return errors::InvalidArgument("Out of range replica ID: ", replica_id);
    }
    if (!seen_replica_ids.insert(replica_id).second) {
      return errors::InvalidArgument(
          "All replica IDs in group_assignment must be distinct, seeing ",
          replica_id, " more than once");
    }

    replica_ids[index / group_size][index % group_size] = replica_id;
  }

  GroupAssignment group_assignment(
      /*global=*/ReplicaGroups(std::move(replica_ids)),
      std::move(replica_to_device_map));
  TF_RETURN_IF_ERROR(group_assignment.GlobalToSlices());
  return group_assignment;
}

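// Renders the global groups along with the derived per-host and per-slice
// groups ("<none>" before `GlobalToSlices` has populated them).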
std::string GroupAssignment::ToString() const {
  return strings::StrCat(
      "GroupAssignment global: ", global_.ToString(), "; hosts: ",
      hosts_.empty()
          ? "<none>"
          : str_util::Join(hosts_, ", ",
                           [](std::string* str, const ReplicaGroups& groups) {
                             strings::StrAppend(str, groups.ToString());
                           }),
      "; slices: ",
      slices_.empty()
          ? "<none>"
          : str_util::Join(slices_, ", ",
                           [](std::string* str, const ReplicaGroups& groups) {
                             strings::StrAppend(str, groups.ToString());
                           }));
}

bool GroupAssignment::IsWithinSlices() const {
  // This function returns true iff no group in the global view gets split in
  // `GlobalToSlices`, i.e., the total group count remains the same.
  int total_num_groups = 0;
  for (int i = 0; i < num_slices(); ++i) {
    total_num_groups += num_groups(i).value();
  }
  return total_num_groups == num_groups();
}

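// Splits the global replica groups into per-slice device groups (`slices_`)
// and per-host groups (`hosts_`) used for the host-side counterpart of each
// collective.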
Status GroupAssignment::GlobalToSlices() {
  VLOG(2) << "Original group assignment: " << ToString();

  int num_slices = replica_to_device_map_.num_slices();
  if (num_slices == 0) {
    return errors::InvalidArgument(
        "Unexpectedly empty replica_to_device_map.");
  }

  // For each replica group in the global replica groups, divide its replicas
  // based on which slice they come from. Then, for each slice, collect the
  // subgroups from every such division and form a new ReplicaGroups for that
  // slice.
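  //
  // For example, with DefaultReplicaToDeviceMap(/*num_slices=*/2,
  // /*slice_size=*/2) and a single global group [[0, 1, 2, 3]], the host
  // groups become [[0]] on slice 0 and [[2]] on slice 1 (the first replica of
  // the group on each slice), and the device groups on each slice become
  // [[0, 1]] (local core IDs).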
  std::vector<std::vector<std::vector<int>>> replica_groups_per_host;
  std::vector<std::vector<std::vector<int>>> replica_groups_per_slice;
  replica_groups_per_host.resize(num_slices, {});
  replica_groups_per_slice.resize(num_slices, {});

  for (const std::vector<int>& replica_group : replica_ids()) {
    std::vector<std::vector<int>> replica_group_divided_by_host;
    replica_group_divided_by_host.resize(num_slices, {});
    std::vector<std::vector<int>> replica_group_divided_by_slice;
    replica_group_divided_by_slice.resize(num_slices, {});

    for (int replica_id : replica_group) {
      // TODO(b/183426911): Use DeviceId::core_id in ReplicaGroup directly for
      // now. Integrate with device assignment with proper typing.
      DeviceId device_id = replica_to_device_map_.device_id(replica_id);
      replica_group_divided_by_host[device_id.slice_id].push_back(replica_id);
      replica_group_divided_by_slice[device_id.slice_id].push_back(
          device_id.core_id);
    }

    for (int i = 0; i < num_slices; ++i) {
      if (!replica_group_divided_by_host[i].empty()) {
        // Host meshes have the same global device and replica IDs as TPU
        // meshes. Let the first replica in every group do a host collective.
        replica_groups_per_host[i].push_back(
            std::vector<int>(1, replica_group_divided_by_host[i].front()));
      }
      if (!replica_group_divided_by_slice[i].empty()) {
        replica_groups_per_slice[i].push_back(
            std::move(replica_group_divided_by_slice[i]));
      }
    }
  }

  hosts_.reserve(num_slices);
  slices_.reserve(num_slices);
  for (int i = 0; i < num_slices; ++i) {
    hosts_.push_back(ReplicaGroups(std::move(replica_groups_per_host[i])));
    slices_.push_back(ReplicaGroups(std::move(replica_groups_per_slice[i])));
  }

  VLOG(2) << "Divided group assignment: " << ToString();
  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow