/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/save_restore_util.h"

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {
// A map keyed by an index into save_v2.tensor_names.
// For example, {2 : <"spec_a", "spec_b">} means that
// save_v2.tensor_names[2] should have "spec_a" and "spec_b" saved.
// (Note: BuildSliceSpecDeviceMap below reuses this alias with device_id as
// the key instead.)
using SliceSpecByName = absl::flat_hash_map<int64_t, std::vector<std::string>>;

// Builds a map from saving device_id to the tensor slice specs that device
// needs to save, for the given Tensor shape and layout.
//
// For a sharded Tensor, each device holds a slice of the Tensor, but that
// slice is not necessarily unique to the device. For a Tensor sharded two
// ways on the first dimension of a (2, 4) mesh, devices 0-3 all hold one
// identical slice and devices 4-7 all hold the other. To avoid saving
// duplicated copies of a slice, the map only records the min(device_id)
// among the devices that hold it, and the save happens from there.
//
// Furthermore, saving a Tensor that isn't on a CPU mesh would require a
// send/recv from the saving device to its corresponding host (CPU) devices.
// Since we don't have multi-mesh execution yet, that case is unimplemented.
StatusOr<SliceSpecByName> BuildSliceSpecDeviceMap(
    absl::Span<const int64_t> global_shape, Layout layout) {
  if (!layout.mesh().is_cpu_mesh())
    return errors::Unimplemented(
        "Saving tensors on a non-CPU mesh needs explicit send/receive and "
        "isn't implemented yet");

  // Records the minimum device_id that holds each unique slice copy.
  // Note that llvm::SmallDenseMap won't accept std::string as a key.
  absl::flat_hash_map<std::string, int64_t> min_device_for_slice_spec;
  // Records, per device_id, the list of slice specs it needs to save.
  SliceSpecByName device_slices;

  const auto& mesh = layout.mesh();
  // Construct the slice spec for each device in the mesh.
  for (int device_id = 0; device_id < mesh.size(); ++device_id) {
    TF_ASSIGN_OR_RETURN(const DeviceLocation& coords,
                        mesh.device_location(device_id));
    // Per-dimension slice pieces, prefilled with the full (replicated) spec.
    TF_ASSIGN_OR_RETURN(std::vector<std::string> slice_specs,
                        SliceSpecOnDevice(layout, mesh, coords, global_shape));

    // Build the real slice_spec from the per-dimension pieces.
    std::string slice_spec = absl::StrJoin(slice_specs, ":");
    // Format the global shape as space-separated dimension sizes.
    std::string shape_spec = absl::StrJoin(global_shape, " ");
    // Concat the shape spec and slice spec to form a complete
    // shape_and_slice.
    std::string shape_and_slice = absl::StrCat(shape_spec, " ", slice_spec);
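    // The resulting string follows SaveV2's shape_and_slices format: the
    // global dimension sizes separated by spaces, then one "start,length"
    // entry (or "-" for an unsharded dimension) per dimension joined by ":",
    // e.g. "4 4 0,2:-" for the first half of a [4, 4] Tensor sharded two
    // ways on its first dimension.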

    // Only record the min device_id for each unique shape_and_slice of a
    // given Tensor.
    if (min_device_for_slice_spec.find(shape_and_slice) ==
            min_device_for_slice_spec.end() ||
        device_id < min_device_for_slice_spec[shape_and_slice]) {
      min_device_for_slice_spec[shape_and_slice] = device_id;
    }
  }

  // Construct the device_id-keyed map that the save operations will later be
  // conditioned on.
  for (const auto& spec_and_id : min_device_for_slice_spec) {
    device_slices[spec_and_id.second].push_back(spec_and_id.first);
  }

  return device_slices;
}

}  // namespace

// Returns a per-device suffix for the checkpoint prefix, e.g.
// DeviceSuffix(2, 16) returns "_dev-02-of-16". The device index is
// zero-padded to the width of total_devices.
std::string DeviceSuffix(int device_id, int total_devices) {
  return absl::StrFormat("_dev-%0*d-of-%d", absl::StrCat(total_devices).size(),
                         device_id, total_devices);
}

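// A sketch of the structure BuildSavingSpec returns, for a hypothetical
// input of two tensors where tensor 0 is fully replicated and tensor 1 is
// the sharded [4, 4] Tensor from the example above:
//   saving_specs[0][0] = {""}            // replicated: one save, device 0
//   saving_specs[0][1] = {"4 4 0,2:-"}   // device 0 saves the first slice
//   saving_specs[4][1] = {"4 4 2,2:-"}   // device 4 saves the second slice
// The outer key is the saving device_id; the inner key is the tensor index.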
StatusOr<absl::flat_hash_map<
    int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas) {
  absl::flat_hash_map<int64_t,
                      absl::flat_hash_map<int64_t, std::vector<std::string>>>
      saving_specs;
  for (const SavingTensorMetadata& tensor_metadata : tensor_metadatas) {
    // We use the index to select the tensor names and shape_and_slices from
    // the inputs. This is generic regardless of whether the inputs are
    // constants or just arguments.
    int index = tensor_metadata.tensor_index;
    const Layout& layout = tensor_metadata.layout;
    absl::Span<const int64_t> tensor_shape = tensor_metadata.shape;

    if (layout.IsFullyReplicated()) {
      // Push a fully replicated save on device 0, where the slice_spec is
      // simply an empty string.
      saving_specs[0][index].push_back("");
    } else {
      // Calculate the shape_and_slices for the sharded case here.
      TF_ASSIGN_OR_RETURN(const auto& slice_specs,
                          BuildSliceSpecDeviceMap(tensor_shape, layout));
      // Push the specs for each device into the global map.
      for (const auto& slice_spec : slice_specs) {
        int64_t saving_device_id = slice_spec.first;
        for (const std::string& slice : slice_spec.second) {
          saving_specs[saving_device_id][index].push_back(slice);
        }
      }
    }
  }

  return saving_specs;
}

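// To sketch the grouping done below with hypothetical values: if, for this
// device, tensor 3 has slice specs {"a", "b"} and tensor 5 has {"c"}, then
// the first SaveV2 op gets tensor_indices {3, 5} with specs {"a", "c"} and
// the second gets {3} with {"b"}, so no single op sees a duplicated name.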
SaveOpSpecs BuildPerDeviceSave(
    mlir::OpBuilder& builder,
    const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
    int device_id, mlir::Value prefix, int total_devices) {
  std::vector<mlir::Value> new_prefixes;
  std::vector<std::vector<int>> tensor_indices;
  std::vector<std::vector<std::string>> shape_and_slice_specs;
  for (const auto& tensor_name_index_and_slice_specs : saving_spec) {
    int tensor_index = tensor_name_index_and_slice_specs.first;
    const std::vector<std::string>& specs =
        tensor_name_index_and_slice_specs.second;
    // For each tensor_name, we save its first slice_spec in the first
    // save_op, its second slice_spec in the second save op, etc.
    // This allows us to group save ops together without running into
    // duplicated tensor_names (which the save_v2 op doesn't support).
    for (int save_op_index = 0; save_op_index < specs.size(); ++save_op_index) {
      if (save_op_index >= tensor_indices.size()) {
        tensor_indices.push_back({});
        shape_and_slice_specs.push_back({});

        // Generate a new prefix based on the device_id and save op index,
        // only when we need a new save_op.
        mlir::Value new_prefix =
            builder
                .create<mlir::TF::AddOp>(
                    prefix.getLoc(),
                    prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix,
                    StringScalarConst(builder, prefix.getLoc(),
                                      DeviceSuffix(device_id, total_devices)))
                .z();
        new_prefixes.push_back(new_prefix);
      }
      tensor_indices[save_op_index].push_back(tensor_index);
      shape_and_slice_specs[save_op_index].push_back(specs[save_op_index]);
    }
  }

  return SaveOpSpecs(new_prefixes, tensor_indices, shape_and_slice_specs);
}

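// As an illustration of the output below: for the [4, 4] Tensor with layout
// ["x", "unsharded"] on the {x=2, y=4} mesh used above, a device at coords
// (1, 0) gets {"2,2", "-"}: dimension 0 starts at offset 2 with length 2,
// and dimension 1 is unsharded.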
StatusOr<std::vector<std::string>> SliceSpecOnDevice(
    const Layout& layout, const Mesh& mesh, const DeviceLocation& device_coords,
    absl::Span<const int64_t> global_shape) {
  // Prefill every dimension with "-", the spec for an unsharded (replicated)
  // dimension.
  std::vector<std::string> slice_specs(global_shape.size(), "-");

  const std::vector<std::string>& sharding_spec_strs =
      layout.sharding_spec_strs();
  for (int tensor_dim_index = 0; tensor_dim_index < sharding_spec_strs.size();
       ++tensor_dim_index) {
    const std::string& mesh_dim = sharding_spec_strs[tensor_dim_index];
    if (layout.IsShardedDimension(mesh_dim)) {
      TF_ASSIGN_OR_RETURN(int mesh_dim_index, mesh.idx_for_dim(mesh_dim));
      TF_ASSIGN_OR_RETURN(int64_t dim_size, mesh.dim_size(mesh_dim));
      int64_t per_slice_size = global_shape[tensor_dim_index] / dim_size;
      int start = device_coords[mesh_dim_index] * per_slice_size;
      slice_specs[tensor_dim_index] = absl::StrCat(start, ",", per_slice_size);
    }
  }
  return slice_specs;
}

}  // namespace dtensor
}  // namespace tensorflow