/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
#define TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

// Definitions for DTensor mesh & layout.
//
// A mesh describes how a set of devices is partitioned.
// A layout describes how a distributed tensor is partitioned across a mesh (and
// thus across devices). Defining tensor layouts in terms of mesh dimensions
// allows us to efficiently determine the communication required when computing
// an operation with tensors of different layouts.
namespace tensorflow {
namespace dtensor {

// The location of a device in a mesh.
//
// Each device has a unique location in the mesh, which is indicated by the
// offset in each mesh dimension. For example, a mesh
//
//   [x:4, y:3, z:2]
//
// must consist of 24 devices placed densely into the corresponding 3D space.
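//
// For illustration, the device at offsets x=1, y=2, z=0 in the mesh above has
// the location:
//
//   DeviceLocation loc = {1, 2, 0};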
using DeviceLocation = absl::InlinedVector<int64, 4>;

// A shard refers to a partition of a tensor. Shards are arranged in a
// ShardVector, which contains a list of Shards and a list of integers
// representing the number of shards in each dimension.
//
// Example: layout = sharding_specs:x,y, mesh:|x=2,y=2|. This can be represented
// with a ShardVector:
// - shards = (1,1), (1,2), (2,1), (2,2)
// - num_shards_per_dim = (2,2).
//
// The number of elements in each shard matches the tensor rank.
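//
// A minimal sketch populating a ShardVector for the example above:
//
//   ShardVector shard_vec;
//   shard_vec.shards = {{1, 1}, {1, 2}, {2, 1}, {2, 2}};
//   shard_vec.num_shards_per_dim = {2, 2};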
using Shard = std::vector<int>;

struct ShardVector {
  bool operator==(const ShardVector& other) const;
  bool operator!=(const ShardVector& other) const { return !(*this == other); }
  std::string ToString() const;

  bool ContainsShard(const Shard& shard) const;

  std::vector<Shard> shards;
  std::vector<int> num_shards_per_dim;
};

struct MeshDimension {
  // Takes `name` by value so the member initialization below actually moves;
  // std::move on a const reference would silently copy.
  MeshDimension(std::string name, int64 size)
      : name(std::move(name)), size(size) {}
  MeshDimension() = default;

  std::string name;
  int64 size;
};

class Mesh {
 public:
  // Failed serializations are represented with an empty string, therefore we
  // use this string representation of an empty mesh instead to avoid
  // confusion.
  static constexpr const char* kEmptyMeshString = "empty_mesh";
  static Mesh Empty();
  bool IsEmpty() const;
  Mesh() = default;

  // Parses from MeshProto.
  static StatusOr<Mesh> ParseFromProto(const MeshProto& proto);
  // Parses from a human-readable string version of the mesh, currently used
  // to represent meshes in MLIR:
  //  mesh = <name|List[MeshDim]|List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  mesh =
  //  <name|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
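  //
  // Example usage (a sketch; `mesh_str` holds a string like the one above, and
  // error handling is elided):
  //
  //   StatusOr<Mesh> mesh = Mesh::FromString(mesh_str);
  //   if (mesh.ok()) LOG(INFO) << mesh->ToString();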
  static StatusOr<Mesh> FromString(const std::string& str);
  std::string ToString() const;
  MeshProto ToProto() const;

  // Creates a mesh without specific devices associated with it (aka an
  // abstract mesh). This is an experimental API; use it only if strictly
  // needed.
  static StatusOr<Mesh> GetAbstractMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims);
  // Creates a fully defined mesh.
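  // Example (a sketch; the device names are illustrative):
  //
  //   StatusOr<Mesh> mesh = Mesh::GetMesh(
  //       "mesh", /*mesh_dims=*/{MeshDimension("x", 2)},
  //       /*global_device_ids=*/{0, 1}, /*local_device_ids=*/{0, 1},
  //       /*local_devices=*/{"/job:localhost/task:0/device:CPU:0",
  //                          "/job:localhost/task:0/device:CPU:1"},
  //       /*global_devices=*/{});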
  static StatusOr<Mesh> GetMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims,
      const std::vector<std::int64_t>& global_device_ids,
      const std::vector<std::int64_t>& local_device_ids,
      const std::vector<std::string>& local_devices,
      const std::vector<std::string>& global_devices);

  bool is_cpu_mesh() const { return device_type() == "CPU"; }
  bool is_epu_mesh() const { return device_type() == "EPU"; }
  bool is_tpu_mesh() const { return device_type() == "TPU"; }
  // Returns whether the mesh is a remote mesh.
  bool is_remote() const {
    return local_device_ids_.empty() && !global_device_ids_.empty();
  }

  // Device information methods.
  std::string device_type() const;
  // Takes an index in the flattened list of devices and returns a location
  // in the mesh.
  StatusOr<const DeviceLocation> device_location(int offset) const;
  int64 num_devices() const;
  absl::Span<const std::string> local_devices() const { return local_devices_; }
  absl::Span<const int64_t> local_device_ids() const {
    return local_device_ids_;
  }
  // Parses names of local_devices according to TF's DeviceNameUtils.
  StatusOr<const std::vector<DeviceNameUtils::ParsedName>> ParsedDevices()
      const;
  // Converts to the given device type.
  StatusOr<Mesh> ToDeviceType(const std::string& device_type) const;
  std::vector<std::string> hosts() const;

  // Consumes a location in the mesh and returns its corresponding index in
  // the flattened list of devices.
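  // For example, assuming row-major ordering of mesh dimensions (consistent
  // with device_location), in a mesh |x=2,y=3| the location {1, 2} maps to
  // flattened offset 1 * 3 + 2 = 5.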
  int64 GetFlattenedCoordinate(const DeviceLocation& loc) const;

  const MeshDimension& dim(int64 index) const { return mesh_dims_[index]; }
  std::vector<MeshDimension> dims() const { return mesh_dims_; }
  // Returns the size of the mesh dimension with the given name.
  StatusOr<int64> dim_size(absl::string_view name) const;
  // Returns list of mesh dimension sizes.
  std::vector<int64> dim_sizes() const;
  const std::string& dim_name(int64 index) const {
    return mesh_dims_[index].name;
  }
  int64_t min_global_device_id() const {
    DCHECK(!global_device_ids_.empty());
    return *std::min_element(global_device_ids_.begin(),
                             global_device_ids_.end());
  }

  absl::Span<const int64_t> global_device_ids() const {
    return global_device_ids_;
  }

  const std::vector<std::string>& global_devices() const {
    return global_devices_;
  }
  // Returns index of given dim_name in the mesh.
  StatusOr<int32> idx_for_dim(absl::string_view dim_name) const;

  // Returns the index of the MeshDimension in the mesh whose name is
  // `mesh_name`.
  int GetMeshDimIndexWithName(const std::string& mesh_name) const;
  bool IsMeshDim(const std::string& dim_name) const;

  int64 rank() const;
  int64 size() const;
  const std::string& name() const { return name_; }

  // Globally unique fingerprint; the same on different workers.
  uint64 GlobalFingerprint() const;

  bool operator==(const Mesh& b) const;
  bool operator!=(const Mesh& b) const { return !((*this) == b); }
  bool operator<(const Mesh& b) const {
    return this->ToString() < b.ToString();
  }

  template <typename H>
  friend H AbslHashValue(H h, const Mesh& m) {
    return H::combine(std::move(h), m.ToString());
  }

  // A map from mesh names to their corresponding core ID mappings. The core ID
  // mapping is stored as a vector: the i-th element is the ID of the core
  // corresponding to global device ID i in this mesh.
  //
  // The entry stored under the empty name key (the so-called "default mapping"
  // in some comments) is special. It is always set at the end of TPU
  // initialization. It represents the mapping for any mesh whose global device
  // IDs follow TF task-device ordinals. Legacy and test meshes created without
  // using the `create_tpu_mesh` helper follow that rule and can use this entry.
  static std::map<std::string, std::vector<int>>& tpu_core_ids();

  // The host mesh associated with any user-defined TPU mesh.
  static std::string& tpu_host_mesh();

 private:
  std::string name_;
  std::vector<MeshDimension> mesh_dims_;
  std::vector<std::string> local_devices_;
  std::vector<int64_t> local_device_ids_;
  std::vector<int64_t> global_device_ids_;
  std::vector<std::string> global_devices_;
};

class Layout {
 public:
  static constexpr const char* kUnshardedDim = "unsharded";
  // This spec should only be used to express no preferred sharding in the
  // layout propagation algorithm.
  static constexpr const char* kAny = "any";
  // Failed serializations are represented with an empty string, therefore we
  // use this string representation of an empty layout instead to avoid
  // confusion.
  static constexpr const char* kEmptyLayoutString = "empty_layout";
  // Used for the relayout operation, to allow relayout to act as an identity
  // on the layout for the given dimension.
  static constexpr const char* kMatch = "match";

  // Returns the empty layout.
  static Layout Empty();

  // Parses from LayoutProto.
  static StatusOr<Layout> FromProto(const LayoutProto& proto);
  // Parses from a human-readable string version of the layout, currently used
  // to represent layouts in MLIR:
  //  layout = <sharding_specs:List[specs] mesh:name|List[MeshDim]|
  //  List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  layout = <sharding_specs:x,not_sharded mesh:name|x=2,y=2|0,1,2,3|0,1,2,3|
  //  /job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,
  //  /job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
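  //
  // Example usage (a sketch; `layout_str` holds a string like the one above,
  // and error handling is elided):
  //
  //   StatusOr<Layout> layout = Layout::FromString(layout_str);
  //   if (layout.ok()) LOG(INFO) << layout->ToString();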
  static StatusOr<Layout> FromString(std::string layout_str);
  // Creates a human-readable string version of a layout.
  std::string ToString() const;
  LayoutProto ToProto() const;

  const Mesh& mesh() const { return mesh_; }
  static Layout ReplicatedOnMesh(const Mesh& mesh, int rank);
  static Layout AnyOnMesh(const Mesh& mesh, int rank);
  // Creates a mesh of unique shards.
  Mesh ReducedMesh() const;
  void set_mesh(Mesh mesh) { mesh_ = std::move(mesh); }

  // Returns a layout for the transposed matrix of the given layout. This
  // assumes that only the last two dimensions are used for matrix computation
  // and all dimensions before are batch dimensions.
  static StatusOr<Layout> Transposed2D(const Layout& layout);
  static bool IsUnshardedDimension(const absl::string_view name) {
    return name == kUnshardedDim;
  }
  static bool IsShardedDimension(const absl::string_view name) {
    return !IsUnshardedDimension(name);
  }
  static bool IsUnshardedSpec(const ShardingSpec& spec) {
    return IsUnshardedDimension(spec.sharding_spec());
  }
  static bool IsShardedSpec(const ShardingSpec& spec) {
    return !IsUnshardedDimension(spec.sharding_spec());
  }
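  // Creates a layout from sharding specs (one per tensor dimension) and a
  // mesh. Example (a sketch): shard tensor dimension 0 over mesh dimension
  // "x" and replicate tensor dimension 1:
  //
  //   StatusOr<Layout> layout =
  //       Layout::GetLayout({"x", Layout::kUnshardedDim}, mesh);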
  static StatusOr<Layout> GetLayout(
      const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh);
  static StatusOr<Layout> GetLayout(
      const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh);

  // Makes a new layout from this one dropping the given dimensions.
  // If keep_dims is true, the dimensions are replicated rather than
  // deleted.
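  //
  // For example, reducing dim 0 of a layout with specs {"x", "y"} yields
  // specs {"y"} when keep_dims is false, and {unsharded, "y"} when keep_dims
  // is true.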
  Layout GetLayoutWithReducedDims(const absl::flat_hash_set<int>& reduced_dims,
                                  bool keep_dims) const;

  // Truncates a layout at the front or back, depending on the value of `end`:
  // end = false returns the layout up to the split point,
  // end = true returns the layout from the split point.
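  //
  // For example, for a layout with specs {"x", "y", unsharded}, Truncate(1)
  // returns a layout with specs {"x"}, while Truncate(1, /*end=*/true)
  // returns one with specs {"y", unsharded}.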
  Layout Truncate(int64 split_point, bool end = false) const;

  // Left-pads the layout to the given rank.
  Layout LeftPad(int64 rank) const;

  bool IsFullyReplicated() const;
  bool IsLastDimReplicated() const;
  // Checks that the last N-1 dimensions are replicated.
  bool IsBatchParallel() const;
  // Checks that the dimensions from [-non_batch_rank, end) are replicated.
  bool IsBatchParallel(int non_batch_rank) const;
  bool IsEmpty() const;

  // Computes the global shape from the layout and the provided local shape.
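  // For example (assuming even sharding), with sharding_specs:x,unsharded on
  // mesh |x=2,y=2| and local shape {4, 8}, the global shape is {8, 8}: each
  // dimension's global size is the local size times its shard count.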
  std::vector<int64_t> GlobalShapeFromLocalShape(
      const std::vector<int64_t>& local_shape) const;

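  // Computes the local shape from the layout and the provided global shape:
  // each dimension's global size is divided by the number of shards along it
  // (for shapes that divide evenly).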
  std::vector<int64_t> LocalShapeFromGlobalShape(
      absl::Span<const int64_t> global_shape) const;
  PartialTensorShape LocalShapeFromGlobalShape(
      const PartialTensorShape& global_shape) const;

  int64 rank() const { return sharding_specs_.size(); }
  size_t num_shards_for_dim(const ShardingSpec& dim) const;
  std::vector<int32> num_shards() const;

  const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; }
  absl::Span<const ShardingSpec> sharding_specs() const {
    return sharding_specs_;
  }

  // Computes the shard vector corresponding to this layout.
  ShardVector GetShardVector() const;

  // Returns sharding specs in string form.
  std::vector<std::string> sharding_spec_strs() const;

  int64 num_devices() const { return mesh_.num_devices(); }
  StatusOr<const DeviceLocation> device_location(int64 device_id) const {
    return mesh_.device_location(device_id);
  }
  // Maps hosts to shards.
  std::map<std::string, ShardVector> HostShardMap() const;

  const std::string& sharding_spec(int idx) const;

  // Two layouts are equivalent if they would result in the same sharding for
  // the tensor, e.g. if one is unsharded and the other is sharded over a mesh
  // dimension of size 1.
  bool IsEquivalent(const Layout& b) const;
  bool operator==(const Layout& b) const;
  bool operator!=(const Layout& b) const { return !((*this) == b); }
  bool operator<(const Layout& b) const {
    return this->ToString() < b.ToString();
  }

 private:
  std::vector<ShardingSpec> sharding_specs_;
  Mesh mesh_;
};

// Takes two layouts and concatenates their TensorDimensions. Returns an error
// rather than a layout if the meshes of the two layouts differ or if both
// layouts use the same mesh dimension.
StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b);

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_