1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ |
18 | |
19 | #include <ostream> |
20 | #include <string> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "absl/container/flat_hash_map.h" |
25 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
26 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
27 | #include "tensorflow/core/platform/errors.h" |
28 | #include "tensorflow/core/platform/logging.h" |
29 | #include "tensorflow/core/platform/status.h" |
30 | #include "tensorflow/core/platform/types.h" |
31 | #include "tensorflow/dtensor/cc/dstatus.h" |
32 | |
33 | namespace tensorflow { |
34 | namespace dtensor { |
35 | |
36 | // Arranges all replica IDs in a DTensor mesh in groups, used as an attribute |
37 | // on collective operations. |
38 | // |
39 | // A group assignment has two views: |
40 | // |
41 | // - The global mesh view contains replica IDs from all participant TPU slices. |
42 | // These replica IDs are identical to global device IDs in a DTensor mesh. |
43 | // - The local slice view contains per-slice device IDs understood and used by |
44 | // the TPU runtime on each slice. These device IDs are used to set replica |
45 | // IDs on each slice. |
46 | // |
47 | // Some notable common cases: |
48 | // |
49 | // - In a single-slice case, `slice_size` is set to the actual slice size |
50 | // (e.g., 32 for 4x4 DF). The global and local views are identical. |
51 | // - In a special topology case, `slice_size` is set to 8. |
52 | // - In a multi-topology case, `slice_size` is set to the size of a single |
53 | // topology. |
54 | // All topologies must have the same size. |
55 | class GroupAssignment { |
56 | public: |
57 | using ReplicaId = int; |
58 | |
59 | struct DeviceId { |
60 | public: |
61 | int slice_id; |
62 | int core_id; // within `slice_id` |
63 | }; |
64 | |
65 | // Maps global replica IDs to local device IDs consisting of a slice ID and a |
66 | // core-on-slice ID. |
67 | class ReplicaToDeviceMap { |
68 | public: |
69 | // Creates a default map that orders devices according to TF task IDs |
70 | // followed by device ordinals. |
71 | static ReplicaToDeviceMap DefaultReplicaToDeviceMap(int num_slices, |
72 | int slice_size); |
73 | |
74 | // Constructs a map directly, checking it's valid. |
75 | explicit ReplicaToDeviceMap(absl::flat_hash_map<ReplicaId, DeviceId> map); |
76 | |
77 | int num_slices() { return num_slices_; } |
78 | int num_cores() { return map_.size(); } |
79 | DeviceId device_id(ReplicaId replica_id) { return map_[replica_id]; } |
80 | |
81 | private: |
82 | absl::flat_hash_map<ReplicaId, DeviceId> map_; |
83 | int num_slices_; |
84 | }; |
85 | |
86 | // Creates a group assignment by converting from an MLIR attribute. |
87 | static StatusOr<GroupAssignment> FromMLIR( |
88 | const mlir::DenseIntElementsAttr& group_assignment_attr, |
89 | ReplicaToDeviceMap replica_to_device_map); |
90 | |
91 | // Creates an MLIR attribute using the global view. |
92 | mlir::DenseIntElementsAttr GlobalToMLIR(mlir::MLIRContext& context) const { |
93 | return global_.ToMLIR(context); |
94 | } |
95 | |
96 | // Creates an MLIR attribute for a particular slice. |
97 | // Callers should make sure `slice_id` is >= 0 and < num_slices(). |
98 | StatusOr<mlir::DenseIntElementsAttr> SliceToMLIR(mlir::MLIRContext& context, |
99 | int slice_id) const { |
100 | if (slice_id < 0 || slice_id >= num_slices()) |
101 | return errors::InvalidArgument("slide_id was not within bounds." ); |
102 | return slices_[slice_id].ToMLIR(context); |
103 | } |
104 | |
105 | // Returns a string representation for debugging. |
106 | std::string ToString() const; |
107 | |
108 | // Returns true if every group in the global view only has replica IDs from |
109 | // the same slice. |
110 | bool IsWithinSlices() const; |
111 | |
112 | // Returns the number of slices in the local view. |
113 | int num_slices() const { return slices_.size(); } |
114 | |
115 | // These methods return attributes of the global view. |
116 | int num_groups() const { return global_.num_groups(); } |
117 | int group_size() const { return global_.group_size(); } |
118 | int num_replica_ids() const { return global_.num_replica_ids(); } |
119 | const std::vector<std::vector<int>>& replica_ids() const { |
120 | return global_.replica_ids(); |
121 | } |
122 | |
123 | // These methods return attributes of a particular slice. |
124 | // Callers should make sure `slice_id` is >= 0 and < num_slices(). |
125 | StatusOr<int> num_groups(int slice_id) const { |
126 | if (slice_id < 0 || slice_id >= num_slices()) |
127 | return errors::InvalidArgument("slide_id was not within bounds." ); |
128 | return slices_[slice_id].num_groups(); |
129 | } |
130 | StatusOr<int> group_size(int slice_id) const { |
131 | if (slice_id < 0 || slice_id >= num_slices()) |
132 | return errors::InvalidArgument("slide_id was not within bounds." ); |
133 | return slices_[slice_id].group_size(); |
134 | } |
135 | const std::vector<std::vector<int>>& replica_ids(int slice_id) const { |
136 | return slices_[slice_id].replica_ids(); |
137 | } |
138 | |
139 | // Returns the replica groups for collectives running on a particular host. |
140 | // Callers should make sure `slice_id` is >= 0 and < num_slices(). |
141 | const std::vector<std::vector<int>>& host_replica_ids(int slice_id) const { |
142 | return hosts_[slice_id].replica_ids(); |
143 | } |
144 | |
145 | private: |
146 | // Groups of consecutive replica IDs starting at 0. |
147 | class ReplicaGroups { |
148 | public: |
149 | // Creates an object, enforcing the requirements on `replica_ids_`. |
150 | explicit ReplicaGroups(std::vector<std::vector<int>> replica_ids); |
151 | |
152 | mlir::DenseIntElementsAttr ToMLIR(mlir::MLIRContext& context) const; |
153 | |
154 | std::string ToString() const; |
155 | |
156 | int num_groups() const { return replica_ids_.size(); } |
157 | int group_size() const { return replica_ids_.front().size(); } |
158 | int num_replica_ids() const { return num_groups() * group_size(); } |
159 | const std::vector<std::vector<int>>& replica_ids() const { |
160 | return replica_ids_; |
161 | } |
162 | |
163 | private: |
164 | // N groups of replica IDs, N > 0. All groups have the same size G, G > 0. |
165 | // All replica IDs are distinct values >= 0; |
166 | std::vector<std::vector<int>> replica_ids_; // replica ID order matters |
167 | }; |
168 | |
169 | // Creates an object but leaves `slices_` empty. `GlobalToSlices` should be |
170 | // called next to fill in `slices_`. |
171 | explicit GroupAssignment(ReplicaGroups global, |
172 | ReplicaToDeviceMap replica_to_device_map) |
173 | : global_(std::move(global)), |
174 | replica_to_device_map_(std::move(replica_to_device_map)) {} |
175 | |
176 | // Divides the global view along slice boundaries and fill in the slice view. |
177 | Status GlobalToSlices(); |
178 | |
179 | ReplicaGroups global_; |
180 | std::vector<ReplicaGroups> hosts_; // sorted by increasing slice ID |
181 | std::vector<ReplicaGroups> slices_; // sorted by increasing slice ID |
182 | ReplicaToDeviceMap replica_to_device_map_; |
183 | }; |
184 | |
185 | } // namespace dtensor |
186 | } // namespace tensorflow |
187 | |
188 | #endif // TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ |
189 | |