/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
#define TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

// Definitions for DTensor mesh & layout.
//
// A mesh describes how a set of devices is partitioned.
// A layout describes how a distributed tensor is partitioned across a mesh (and
// thus across devices). Defining tensor layouts in terms of mesh dimensions
// allows us to efficiently determine the communication required when computing
// an operation with tensors of different layouts.
namespace tensorflow {
namespace dtensor {

// The location of a device in a mesh.
//
// Each device has a unique location in the mesh, which is indicated by its
// offset in each mesh dimension. e.g. a mesh:
//
//   [x:4, y:3, z:2]
//
// must consist of 24 devices placed densely into the corresponding 3D space.
using DeviceLocation = absl::InlinedVector<int64, 4>;
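
// e.g. in the [x:4, y:3, z:2] mesh above, the device at offsets x=1, y=2, z=0
// has DeviceLocation {1, 2, 0}.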

// A shard refers to a partition of a tensor. Shards are arranged in
// ShardVectors that contain a list of Shards and a list of integers
// representing the number of shards in each dimension.
//
// Example: layout = sharding_specs:x,y, mesh:|x=2,y=2|. This can be represented
// with a ShardVector:
// - shards = (1,1), (1,2), (2,1), (2,2)
// - num_shards_per_dim = (2,2).
//
// The number of elements in each shard matches the tensor rank.
using Shard = std::vector<int>;

struct ShardVector {
  bool operator==(const ShardVector& other) const;
  bool operator!=(const ShardVector& other) const { return !(*this == other); }
  std::string ToString() const;

  bool ContainsShard(const Shard& shard) const;

  std::vector<Shard> shards;
  std::vector<int> num_shards_per_dim;
};
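
// An illustrative sketch mirroring the example above (hypothetical usage, not
// part of the API):
//
//   ShardVector shard_vec;
//   shard_vec.shards = {{1, 1}, {1, 2}, {2, 1}, {2, 2}};
//   shard_vec.num_shards_per_dim = {2, 2};
//   DCHECK(shard_vec.ContainsShard({1, 2}));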

struct MeshDimension {
  MeshDimension(std::string name, int64 size)
      : name(std::move(name)), size(size) {}
  MeshDimension() = default;

  std::string name;
  int64 size;
};

class Mesh {
 public:
  // Failed serialized strings are represented with an empty string, therefore
  // we use this string representation of an empty mesh instead to avoid
  // confusion.
  static constexpr const char* kEmptyMeshString = "empty_mesh";
  static Mesh Empty();
  bool IsEmpty() const;
  Mesh() = default;

  // Parses from MeshProto.
  static StatusOr<Mesh> ParseFromProto(const MeshProto& proto);
  // Parses from a human-readable string version of the mesh, currently used
  // to represent meshes in MLIR:
  //  mesh = <name|List[MeshDim]|List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  mesh =
  //  <name|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
  static StatusOr<Mesh> FromString(const std::string& str);
  std::string ToString() const;
  MeshProto ToProto() const;
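
  // An illustrative sketch (hypothetical usage; the exact string syntax
  // follows the grammar above, and a failed parse surfaces via the StatusOr):
  //
  //   StatusOr<Mesh> mesh = Mesh::FromString(
  //       "<name|x=2,y=2|0,1,2,3|0,1,2,3|"
  //       "/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,"
  //       "/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>");
  //   if (mesh.ok()) { MeshProto proto = mesh->ToProto(); }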

  // Creates a mesh without specific devices associated with it (aka an
  // abstract mesh). This is an experimental API. Use only if strictly needed.
  static StatusOr<Mesh> GetAbstractMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims);
  // Creates a fully defined mesh.
  static StatusOr<Mesh> GetMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims,
      const std::vector<std::int64_t>& global_device_ids,
      const std::vector<std::int64_t>& local_device_ids,
      const std::vector<std::string>& local_devices,
      const std::vector<std::string>& global_devices);
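
  // An illustrative sketch (hypothetical values mirroring the string example
  // above):
  //
  //   StatusOr<Mesh> mesh = Mesh::GetMesh(
  //       /*name=*/"name",
  //       /*mesh_dims=*/{MeshDimension("x", 2), MeshDimension("y", 2)},
  //       /*global_device_ids=*/{0, 1, 2, 3},
  //       /*local_device_ids=*/{0, 1, 2, 3},
  //       /*local_devices=*/{"/job:localhost/task:0/device:CPU:0",
  //                          "/job:localhost/task:0/device:CPU:1",
  //                          "/job:localhost/task:0/device:CPU:2",
  //                          "/job:localhost/task:0/device:CPU:3"},
  //       /*global_devices=*/{});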

  bool is_cpu_mesh() const { return device_type() == "CPU"; }
  bool is_epu_mesh() const { return device_type() == "EPU"; }
  bool is_tpu_mesh() const { return device_type() == "TPU"; }
  // Returns whether the mesh is a remote mesh.
  bool is_remote() const {
    return local_device_ids_.empty() && !global_device_ids_.empty();
  }

  // Device information methods.
  std::string device_type() const;
  // Takes an index in the flattened list of devices and returns a location
  // in the mesh.
  StatusOr<const DeviceLocation> device_location(int offset) const;
  int64 num_devices() const;
  absl::Span<const std::string> local_devices() const { return local_devices_; }
  absl::Span<const int64_t> local_device_ids() const {
    return local_device_ids_;
  }
  // Parses names of local_devices according to TF's DeviceNameUtils.
  StatusOr<const std::vector<DeviceNameUtils::ParsedName>> ParsedDevices()
      const;
  // Converts to the given device type.
  StatusOr<Mesh> ToDeviceType(const std::string& device_type) const;
  std::vector<std::string> hosts() const;

  // Consumes a location in the mesh and returns its corresponding index in
  // the flattened list of devices.
  int64 GetFlattenedCoordinate(const DeviceLocation& loc) const;
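
  // An illustrative sketch (assumes a 2x2 mesh whose devices are laid out
  // densely, so the flattened offset 3 sits at location {1, 1}):
  //
  //   StatusOr<const DeviceLocation> loc = mesh.device_location(3);
  //   int64 offset = mesh.GetFlattenedCoordinate(*loc);  // Round-trips to 3.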

  const MeshDimension& dim(int64 index) const { return mesh_dims_[index]; }
  std::vector<MeshDimension> dims() const { return mesh_dims_; }
  // Returns the size of the mesh dimension.
  StatusOr<int64> dim_size(absl::string_view name) const;
  // Returns the list of mesh dimension sizes.
  std::vector<int64> dim_sizes() const;
  const std::string& dim_name(int64 index) const {
    return mesh_dims_[index].name;
  }
  int64_t min_global_device_id() const {
    DCHECK(!global_device_ids_.empty());
    return *std::min_element(global_device_ids_.begin(),
                             global_device_ids_.end());
  }

  absl::Span<const int64_t> global_device_ids() const {
    return global_device_ids_;
  }

  const std::vector<std::string>& global_devices() const {
    return global_devices_;
  }
  // Returns the index of the given dim_name in the mesh.
  StatusOr<int32> idx_for_dim(absl::string_view dim_name) const;

  // Returns the index of the MeshDimension in the mesh whose name is
  // `mesh_name`.
  int GetMeshDimIndexWithName(const std::string& mesh_name) const;
  bool IsMeshDim(const std::string& dim_name) const;

  int64 rank() const;
  int64 size() const;
  const std::string& name() const { return name_; }

  // Globally unique fingerprint. Same on different workers.
  uint64 GlobalFingerprint() const;

  bool operator==(const Mesh& b) const;
  bool operator!=(const Mesh& b) const { return !((*this) == b); }
  bool operator<(const Mesh& b) const {
    return this->ToString() < b.ToString();
  }

  template <typename H>
  friend H AbslHashValue(H h, const Mesh& m) {
    return H::combine(std::move(h), m.ToString());
  }

  // A map from mesh names to their corresponding core ID mappings. The core ID
  // mapping is stored as a vector. The i-th element in the vector is the ID of
  // the core represented by global device ID of i in this mesh.
  //
  // The entry stored under the empty name key (the so-called "default mapping"
  // in some comments) is special. It is always set at the end of TPU
  // initialization. It represents the mapping for any mesh whose global device
  // IDs follow TF task-device ordinals. Legacy and test meshes created without
  // using the `create_tpu_mesh` helper follow that rule and can use this entry.
  static std::map<std::string, std::vector<int>>& tpu_core_ids();

  // The host mesh associated with any user-defined TPU mesh.
  static std::string& tpu_host_mesh();

 private:
  std::string name_;
  std::vector<MeshDimension> mesh_dims_;
  std::vector<std::string> local_devices_;
  std::vector<int64_t> local_device_ids_;
  std::vector<int64_t> global_device_ids_;
  std::vector<std::string> global_devices_;
};

class Layout {
 public:
  static constexpr const char* kUnshardedDim = "unsharded";
  // This spec should only be used to express no preferred sharding in the
  // layout propagation algorithm.
  static constexpr const char* kAny = "any";
  // Failed serialized strings are represented with an empty string, therefore
  // we use this string representation of an empty layout instead to avoid
  // confusion.
  static constexpr const char* kEmptyLayoutString = "empty_layout";
  // Used for the relayout operation, to allow relayout to act as an identity
  // on the layout for the given dimension.
  static constexpr const char* kMatch = "match";

  // Returns an empty layout.
  static Layout Empty();

  // Parses from LayoutProto.
  static StatusOr<Layout> FromProto(const LayoutProto& proto);
  // Parses from a human-readable string version of the layout, currently used
  // to represent layouts in MLIR:
  //  layout = <sharding_specs:List[specs] mesh:name|List[MeshDim]|
  //           List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  layout = <sharding_specs:x,not_sharded mesh:name|x=2,y=2|0,1,2,3|0,1,2,3|
  //           /job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,
  //           /job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
  static StatusOr<Layout> FromString(std::string layout_str);
  // Creates a human-readable string version of a layout.
  std::string ToString() const;
  LayoutProto ToProto() const;
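
  // An illustrative sketch (hypothetical usage; the exact string syntax
  // follows the grammar above):
  //
  //   StatusOr<Layout> layout = Layout::FromString(
  //       "<sharding_specs:x,not_sharded mesh:name|x=2,y=2|0,1,2,3|0,1,2,3|"
  //       "/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,"
  //       "/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>");
  //   if (layout.ok()) { std::string str = layout->ToString(); }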

  const Mesh& mesh() const { return mesh_; }
  static Layout ReplicatedOnMesh(const Mesh& mesh, int rank);
  static Layout AnyOnMesh(const Mesh& mesh, int rank);
  // Creates a mesh of unique shards.
  Mesh ReducedMesh() const;
  void set_mesh(Mesh mesh) { mesh_ = mesh; }
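
  // An illustrative sketch: a rank-2 replicated layout is one whose sharding
  // specs are all kUnshardedDim.
  //
  //   Layout replicated = Layout::ReplicatedOnMesh(mesh, /*rank=*/2);
  //   DCHECK(replicated.IsFullyReplicated());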

  // Returns a layout for the transposed matrix of the given layout. This
  // assumes that only the last two dimensions are used for matrix computation
  // and all dimensions before are batch dimensions.
  static StatusOr<Layout> Transposed2D(const Layout& layout);
  static bool IsUnshardedDimension(const absl::string_view name) {
    return name == kUnshardedDim;
  }
  static bool IsShardedDimension(const absl::string_view name) {
    return !IsUnshardedDimension(name);
  }
  static bool IsUnshardedSpec(const ShardingSpec& spec) {
    return IsUnshardedDimension(spec.sharding_spec());
  }
  static bool IsShardedSpec(const ShardingSpec& spec) {
    return !IsUnshardedDimension(spec.sharding_spec());
  }
  static StatusOr<Layout> GetLayout(
      const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh);
  static StatusOr<Layout> GetLayout(
      const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh);
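
  // An illustrative sketch (assumes `mesh` has dimensions x and y): shards the
  // first tensor dimension over x and leaves the second one replicated.
  //
  //   StatusOr<Layout> layout = Layout::GetLayout(
  //       /*sharding_spec_strs=*/{"x", Layout::kUnshardedDim}, mesh);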

  // Makes a new layout from this one, dropping the given dimensions.
  // If keep_dims is true, the dimensions are replicated rather than
  // deleted.
  Layout GetLayoutWithReducedDims(const absl::flat_hash_set<int>& reduced_dims,
                                  bool keep_dims) const;

  // Truncates a layout at the front or back, depending on the value of end.
  // end = false returns the layout up to the split point;
  // end = true returns the layout from the split point.
  Layout Truncate(int64 split_point, bool end = false) const;

  // Left or right pads the layout to a max rank.
  Layout LeftPad(int64 rank) const;
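
  // An illustrative sketch (assumes a rank-2 layout with sharding specs x,y
  // and that LeftPad pads with unsharded dimensions on the left):
  //
  //   Layout front = layout.Truncate(/*split_point=*/1);              // specs: x
  //   Layout back = layout.Truncate(/*split_point=*/1, /*end=*/true); // specs: y
  //   Layout padded = back.LeftPad(/*rank=*/2);  // specs: unsharded,y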

  bool IsFullyReplicated() const;
  bool IsLastDimReplicated() const;
  // Checks that the last N-1 dimensions are replicated.
  bool IsBatchParallel() const;
  // Checks that the dimensions from [-non_batch_rank, end) are replicated.
  bool IsBatchParallel(int non_batch_rank) const;
  bool IsEmpty() const;

  // Computes the global shape using the layout and the provided local_shape.
  std::vector<int64_t> GlobalShapeFromLocalShape(
      const std::vector<int64_t>& local_shape) const;

  std::vector<int64_t> LocalShapeFromGlobalShape(
      absl::Span<const int64_t> global_shape) const;
  PartialTensorShape LocalShapeFromGlobalShape(
      const PartialTensorShape& global_shape) const;
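
  // An illustrative sketch: with sharding specs x,unsharded on a |x=2,y=2|
  // mesh, a global shape of {8, 4} has local shape {4, 4} on every device,
  // since the first dimension is split across the 2 shards of x.
  //
  //   std::vector<int64_t> local = layout.LocalShapeFromGlobalShape({8, 4});
  //   std::vector<int64_t> global = layout.GlobalShapeFromLocalShape(local);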

  int64 rank() const { return sharding_specs_.size(); }
  size_t num_shards_for_dim(const ShardingSpec& dim) const;
  std::vector<int32> num_shards() const;

  const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; }
  absl::Span<const ShardingSpec> sharding_specs() const {
    return sharding_specs_;
  }

  // Computes the corresponding shard vector for this layout.
  ShardVector GetShardVector() const;

  // Returns sharding specs in string form.
  std::vector<std::string> sharding_spec_strs() const;

  int64 num_devices() const { return mesh_.num_devices(); }
  StatusOr<const DeviceLocation> device_location(int64 device_id) const {
    return mesh_.device_location(device_id);
  }
  // Maps hosts to shards.
  std::map<std::string, ShardVector> HostShardMap() const;

  const std::string& sharding_spec(int idx) const;

  // Two layouts are equivalent if they would result in the same sharding for
  // the tensor, e.g. if one is unsharded and the other is sharded on a mesh
  // dimension of size 1.
  bool IsEquivalent(const Layout& b) const;
  bool operator==(const Layout& b) const;
  bool operator!=(const Layout& b) const { return !((*this) == b); }
  bool operator<(const Layout& b) const {
    return this->ToString() < b.ToString();
  }

 private:
  std::vector<ShardingSpec> sharding_specs_;
  Mesh mesh_;
};

// Takes two layouts and concatenates their sharding specs. If the meshes of
// the two layouts differ, or both layouts use the same mesh dimension, returns
// an error rather than a layout.
StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b);
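
// An illustrative sketch (assumes both layouts live on the same mesh and use
// disjoint mesh dimensions, e.g. sharding specs x and y):
//
//   StatusOr<Layout> combined = ConcatenateLayouts(layout_a, layout_b);
//   // On success, combined->rank() == layout_a.rank() + layout_b.rank().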

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_