1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/group_assignment.h" |
17 | |
18 | #include <cstdint> |
19 | #include <set> |
20 | #include <string> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "absl/container/flat_hash_map.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
27 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
28 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
29 | #include "tensorflow/core/platform/errors.h" |
30 | #include "tensorflow/core/platform/logging.h" |
31 | #include "tensorflow/core/platform/status.h" |
32 | #include "tensorflow/core/platform/str_util.h" |
33 | #include "tensorflow/core/platform/strcat.h" |
34 | #include "tensorflow/core/platform/types.h" |
35 | #include "tensorflow/dtensor/cc/dstatus.h" |
36 | |
37 | namespace tensorflow { |
38 | namespace dtensor { |
39 | |
40 | GroupAssignment::ReplicaToDeviceMap |
41 | GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(int num_slices, |
42 | int slice_size) { |
43 | absl::flat_hash_map<ReplicaId, DeviceId> map; |
44 | for (int i = 0; i < num_slices; ++i) { |
45 | for (int j = 0; j < slice_size; ++j) { |
46 | map[ReplicaId{i * slice_size + j}] = DeviceId{i, j}; |
47 | } |
48 | } |
49 | return ReplicaToDeviceMap(std::move(map)); |
50 | } |
51 | |
52 | GroupAssignment::ReplicaToDeviceMap::ReplicaToDeviceMap( |
53 | absl::flat_hash_map<ReplicaId, DeviceId> map) |
54 | : map_(std::move(map)) { |
55 | std::set<int> slice_ids; |
56 | for (const auto& entry : map_) { |
57 | slice_ids.insert(entry.second.slice_id); |
58 | } |
59 | CHECK_GT(slice_ids.size(), 0); // Crash OK |
60 | CHECK_EQ(map_.size() % slice_ids.size(), 0); // Crash OK |
61 | num_slices_ = slice_ids.size(); |
62 | } |
63 | |
64 | GroupAssignment::ReplicaGroups::ReplicaGroups( |
65 | std::vector<std::vector<int>> replica_ids) |
66 | : replica_ids_(std::move(replica_ids)) { |
67 | int n = replica_ids_.size(); |
68 | CHECK_GT(n, 0); // Crash OK |
69 | int g = replica_ids_.front().size(); |
70 | CHECK_GT(g, 0); // Crash OK |
71 | std::set<int> seen_replica_ids; |
72 | for (std::vector<int>& group : replica_ids_) { |
73 | CHECK_EQ(group.size(), g); // Crash OK |
74 | for (int replica_id : group) { |
75 | CHECK_GE(replica_id, 0); // Crash OK |
76 | bool inserted = seen_replica_ids.insert(replica_id).second; |
77 | CHECK(inserted); // Crash OK |
78 | } |
79 | } |
80 | } |
81 | |
82 | mlir::DenseIntElementsAttr GroupAssignment::ReplicaGroups::ToMLIR( |
83 | mlir::MLIRContext& context) const { |
84 | auto shaped_type = mlir::RankedTensorType::get( |
85 | {num_groups(), group_size()}, mlir::IntegerType::get(&context, 32)); |
86 | |
87 | llvm::SmallVector<int32, 4> flat_replica_ids; |
88 | flat_replica_ids.reserve(num_replica_ids()); |
89 | for (const std::vector<int>& group : replica_ids()) { |
90 | flat_replica_ids.insert(flat_replica_ids.end(), group.begin(), group.end()); |
91 | } |
92 | |
93 | return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids); |
94 | } |
95 | |
96 | std::string GroupAssignment::ReplicaGroups::ToString() const { |
97 | return strings::StrCat( |
98 | "[" , |
99 | str_util::Join(replica_ids(), ", " , |
100 | [](std::string* str, const std::vector<int>& group) { |
101 | strings::StrAppend(str, "[" , str_util::Join(group, ", " ), |
102 | "]" ); |
103 | }), |
104 | "]" ); |
105 | } |
106 | |
// Parses a group assignment out of an MLIR dense-int attribute and pairs it
// with a replica-to-device map.
//
// Validates that the attribute is a rank-2 tensor of shape
// [num_groups, group_size] whose elements form a permutation of
// [0, num_replica_ids), and that the total replica count matches the number
// of cores in `replica_to_device_map`. On success, returns a GroupAssignment
// whose per-host and per-slice views have already been computed by
// GlobalToSlices(); otherwise returns InvalidArgument.
StatusOr<GroupAssignment> GroupAssignment::FromMLIR(
    const mlir::DenseIntElementsAttr& group_assignment_attr,
    ReplicaToDeviceMap replica_to_device_map) {
  mlir::ShapedType shaped_type = group_assignment_attr.getType();
  if (!shaped_type.hasRank()) {
    return errors::InvalidArgument("group_assignment_attr must have a rank");
  }
  if (shaped_type.getRank() != 2) {
    return errors::InvalidArgument(
        "group_assignment_attr must have a rank of 2, got ",
        shaped_type.getRank());
  }
  llvm::ArrayRef<int64_t> shape = shaped_type.getShape();
  // Dimension 0 is the number of groups; dimension 1 is the group size.
  int num_groups = shape[0];
  if (num_groups <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have at least 1 group, got ", num_groups);
  }
  int group_size = shape[1];
  if (group_size <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have non-empty groups, got ", group_size,
        " replica IDs per group");
  }
  int num_replica_ids = num_groups * group_size;
  // Every core in the device map must appear exactly once in the assignment.
  if (num_replica_ids != replica_to_device_map.num_cores()) {
    return errors::InvalidArgument("group_assignment_attr must have ",
                                   replica_to_device_map.num_cores(),
                                   " replica IDs, got ", num_replica_ids);
  }

  // Translate the flat group assignment to a 2D array.
  std::vector<std::vector<int>> replica_ids;
  replica_ids.resize(num_groups, std::vector<int>(group_size));
  std::set<int> seen_replica_ids;
  if (group_assignment_attr.getNumElements() != num_replica_ids) {
    return errors::InvalidArgument(
        "group_assignments_attr num elements was not equal to the number of "
        "replica ids.");
  }
  for (const auto& it :
       llvm::enumerate(group_assignment_attr.getValues<llvm::APInt>())) {
    int index = it.index();
    int replica_id = it.value().getSExtValue();

    // If all replica IDs are within this range and distinct, they must be a
    // permutation of [0, ..., num_replica_ids).
    if (replica_id < 0 || replica_id >= num_replica_ids) {
      return errors::InvalidArgument("Out of range replica ID: ", replica_id);
    }
    if (!seen_replica_ids.insert(replica_id).second) {
      return errors::InvalidArgument(
          "All replica IDs in group_assigment must be distinct, seeing ",
          replica_id, " more than once");
    }

    // Row-major placement: flat index -> (group, position within group).
    replica_ids[index / group_size][index % group_size] = replica_id;
  }

  GroupAssignment group_assignment(
      /*global=*/ReplicaGroups(std::move(replica_ids)),
      std::move(replica_to_device_map));
  // Precompute the per-host and per-slice views before handing it back.
  TF_RETURN_IF_ERROR(group_assignment.GlobalToSlices());
  return group_assignment;
}
172 | |
173 | std::string GroupAssignment::ToString() const { |
174 | return strings::StrCat( |
175 | "GroupAssignment global: " , global_.ToString(), "; hosts: " , |
176 | hosts_.empty() |
177 | ? "<none>" |
178 | : str_util::Join(hosts_, ", " , |
179 | [](std::string* str, const ReplicaGroups& groups) { |
180 | strings::StrAppend(str, groups.ToString()); |
181 | }), |
182 | "; slices: " , |
183 | slices_.empty() |
184 | ? "<none>" |
185 | : str_util::Join(slices_, ", " , |
186 | [](std::string* str, const ReplicaGroups& groups) { |
187 | strings::StrAppend(str, groups.ToString()); |
188 | })); |
189 | } |
190 | |
191 | bool GroupAssignment::IsWithinSlices() const { |
192 | // This function returns true iff no group in the global view gets split in |
193 | // `GlobalToSlices`, i.e., the total group count remains the same. |
194 | int total_num_groups = 0; |
195 | for (int i = 0; i < num_slices(); i++) { |
196 | total_num_groups += num_groups(i).value(); |
197 | } |
198 | if (total_num_groups != num_groups()) return false; |
199 | return total_num_groups == num_groups(); |
200 | } |
201 | |
// Derives the per-host (`hosts_`) and per-slice (`slices_`) views from the
// global replica groups. Each global group is partitioned by the slice its
// replicas live on; per slice, the partition becomes a subgroup of core IDs,
// and the first replica of the partition represents the group on the host
// mesh. Returns InvalidArgument if the replica-to-device map is empty.
Status GroupAssignment::GlobalToSlices() {
  VLOG(2) << "Original group assignment: " << ToString();

  int num_slices = replica_to_device_map_.num_slices();
  if (num_slices == 0) {
    return errors::InvalidArgument("Unexpectedly empty replica_to_device_map.");
  }

  // For each replica group in global replica groups, divide its replicas based
  // on which slices they come from. Then, for each slice, collect subgroups
  // from every such division and form a new ReplicaGroup for that slice.
  // Outer index: slice; inner: list of groups; innermost: IDs in a group.
  std::vector<std::vector<std::vector<int>>> replica_groups_per_host;
  std::vector<std::vector<std::vector<int>>> replica_groups_per_slice;
  replica_groups_per_host.resize(num_slices, {});
  replica_groups_per_slice.resize(num_slices, {});

  for (const std::vector<int>& replica_group : replica_ids()) {
    // Per-slice partitions of this one global group: by replica ID (for the
    // host view) and by core ID (for the slice view).
    std::vector<std::vector<int>> replica_group_divided_by_host;
    replica_group_divided_by_host.resize(num_slices, {});
    std::vector<std::vector<int>> replica_group_divided_by_slice;
    replica_group_divided_by_slice.resize(num_slices, {});

    for (int replica_id : replica_group) {
      // TODO(b/183426911): Use DeviceId::core_id in ReplicaGroup directly for
      // now. Integrate with device assignment with proper typing.
      DeviceId device_id = replica_to_device_map_.device_id(replica_id);
      replica_group_divided_by_host[device_id.slice_id].push_back(replica_id);
      replica_group_divided_by_slice[device_id.slice_id].push_back(
          device_id.core_id);
    }

    for (int i = 0; i < num_slices; ++i) {
      if (!replica_group_divided_by_host[i].empty()) {
        // Host meshes have the same global device and replica IDs as TPU
        // meshes. Let the first replica in every group do a host collective.
        replica_groups_per_host[i].push_back(
            std::vector<int>(1, replica_group_divided_by_host[i].front()));
      }
      if (!replica_group_divided_by_slice[i].empty()) {
        // Safe to move: this partition is rebuilt fresh for each global group.
        replica_groups_per_slice[i].push_back(
            std::move(replica_group_divided_by_slice[i]));
      }
    }
  }

  // Materialize one ReplicaGroups per slice for both views.
  hosts_.reserve(num_slices);
  slices_.reserve(num_slices);
  for (int i = 0; i < num_slices; ++i) {
    hosts_.push_back(ReplicaGroups(std::move(replica_groups_per_host[i])));
    slices_.push_back(ReplicaGroups(std::move(replica_groups_per_slice[i])));
  }

  VLOG(2) << "Divided group assignment: " << ToString();
  return OkStatus();
}
257 | |
258 | } // namespace dtensor |
259 | } // namespace tensorflow |
260 | |