1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/dtensor/cc/save_restore_util.h"

#include <string>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/value_utils.h"
22 | namespace tensorflow { |
23 | namespace dtensor { |
24 | |
25 | namespace { |
26 | // A map that is keyed by the index of tensor_name. |
27 | // For example, {2 : <"spec_a", "spec_b"> } means that the |
28 | // save_v2.tensor_names[2] should have "spec_a" and "spec_b" saved. |
29 | using SliceSpecByName = absl::flat_hash_map<int64_t, std::vector<std::string>>; |
30 | |
31 | // Builds a map from tensor slice spec to saving device_id for the given Tensor |
32 | // and layout. The output would record the saving device and the slices it needs |
33 | // to save. |
34 | // |
35 | // For each sharded Tensor, each device would hold a slice of the Tensor - but |
36 | // it isn't necessary a unique copy. For a 2 way sharded Tensor in a (2,4) mesh |
37 | // on the first dimension, device [0-3] and device [4-7] will hold the same |
38 | // slice data. To avoid saving duplicated copies of the Tensor slice, the map |
39 | // would only contain the min(device_id) that occupies the slice and save from |
40 | // there. |
41 | // |
42 | // Furthermore, to save a Tensor that isn't on CPU mesh, send/recv is necessary |
43 | // from saving device to its corresponding host(CPU) devices. Since we don't |
44 | // have multi-mesh execution yet, this isn't implemented yet. |
45 | StatusOr<SliceSpecByName> BuildSliceSpecDeviceMap( |
46 | absl::Span<const int64_t> global_shape, Layout layout) { |
47 | if (!layout.mesh().is_cpu_mesh()) |
48 | return errors::Unimplemented( |
49 | "Saving tensors on non CPU mesh needs explicit send/receive and isn't " |
50 | "implemented yet" ); |
51 | |
52 | // Result map that records the minimum device_id that occupies the unique |
53 | // copy. |
54 | // Note that llvm::SmallDenseMap won't accept std::string as a key. |
55 | absl::flat_hash_map<std::string, int64_t> min_device_for_slice_spec; |
56 | // Records the map of device_ids and a list of slice_spec that it needs to |
57 | // save. |
58 | SliceSpecByName device_slices; |
59 | |
60 | const auto& mesh = layout.mesh(); |
61 | // Construct SliceSpec for each device in the mesh. |
62 | for (int device_id = 0; device_id < mesh.size(); ++device_id) { |
63 | TF_ASSIGN_OR_RETURN(const DeviceLocation& coords, |
64 | mesh.device_location(device_id)); |
65 | // Prefill with full spec on each dim. |
66 | TF_ASSIGN_OR_RETURN(std::vector<std::string> slice_specs, |
67 | SliceSpecOnDevice(layout, mesh, coords, global_shape)); |
68 | |
69 | // Build the real slice_spec from string pieces. |
70 | std::string slice_spec = absl::StrJoin(slice_specs, ":" ); |
71 | // Get local shape from the global shape. |
72 | std::string shape_spec = absl::StrJoin(global_shape, " " ); |
73 | // Concat shape spec and slice spec to form a complete shape_and_slice. |
74 | std::string shape_and_slice = absl::StrCat(shape_spec, " " , slice_spec); |
75 | |
76 | // Only record the min device_id for the unique slice_spec on a given |
77 | // Tensor. |
78 | if (min_device_for_slice_spec.find(shape_and_slice) == |
79 | min_device_for_slice_spec.end() || |
80 | device_id < min_device_for_slice_spec[shape_and_slice]) { |
81 | min_device_for_slice_spec[shape_and_slice] = device_id; |
82 | } |
83 | } |
84 | |
85 | // Constructs device_id keyed map for future save operation conditioned on |
86 | // device_ids. |
87 | for (const auto& spec_and_id : min_device_for_slice_spec) { |
88 | device_slices[spec_and_id.second].push_back(spec_and_id.first); |
89 | } |
90 | |
91 | return device_slices; |
92 | } |
93 | |
94 | } // namespace |
95 | |
// Returns the per-device filename suffix, e.g. _dev-02-of-16: the device id
// is zero-padded to the decimal width of total_devices.
std::string DeviceSuffix(int device_id, int total_devices) {
  const std::string total = std::to_string(total_devices);
  std::string id = std::to_string(device_id);
  // Pad on the left with zeros so `id` is as wide as `total`.
  if (id.size() < total.size()) {
    id.insert(0, total.size() - id.size(), '0');
  }
  return "_dev-" + id + "-of-" + total;
}
101 | |
102 | StatusOr<absl::flat_hash_map< |
103 | int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>> |
104 | BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas) { |
105 | absl::flat_hash_map<int64_t, |
106 | absl::flat_hash_map<int64_t, std::vector<std::string>>> |
107 | saving_specs; |
108 | for (const SavingTensorMetadata& tensor_metadata : tensor_metadatas) { |
109 | // We use index to select the tensor names and shape_and_slices from the |
110 | // inputs. This is generic regardless whether the inputs are constants or |
111 | // just arguments. |
112 | int index = tensor_metadata.tensor_index; |
113 | const Layout& layout = tensor_metadata.layout; |
114 | absl::Span<const int64_t> tensor_shape = tensor_metadata.shape; |
115 | |
116 | if (layout.IsFullyReplicated()) { |
117 | // Push a fully replicated save on device 0, where slice_spec is simply |
118 | // empty string. |
119 | saving_specs[0][index].push_back("" ); |
120 | } else { |
121 | // Calculate shape_and_slices for sharded case here. |
122 | TF_ASSIGN_OR_RETURN(const auto& slice_specs, |
123 | BuildSliceSpecDeviceMap(tensor_shape, layout)); |
124 | // Push specs for each device into the global map. |
125 | for (const auto& slice_spec : slice_specs) { |
126 | int64_t saving_device_id = slice_spec.first; |
127 | for (const std::string& slice : slice_spec.second) { |
128 | saving_specs[saving_device_id][index].push_back(slice); |
129 | } |
130 | } |
131 | } |
132 | } |
133 | |
134 | return saving_specs; |
135 | } |
136 | |
137 | SaveOpSpecs BuildPerDeviceSave( |
138 | mlir::OpBuilder& builder, |
139 | const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec, |
140 | int device_id, mlir::Value prefix, int total_devices) { |
141 | std::vector<mlir::Value> new_prefixes; |
142 | std::vector<std::vector<int>> tensor_indices; |
143 | std::vector<std::vector<std::string>> shape_and_slice_specs; |
144 | for (const auto& tensor_name_index_and_slice_specs : saving_spec) { |
145 | int tensor_index = tensor_name_index_and_slice_specs.first; |
146 | const std::vector<std::string> specs = |
147 | tensor_name_index_and_slice_specs.second; |
148 | // For each tensor_name, we save its first slice_spec in the first |
149 | // save_op, second slice_spec in the second save op, etc. |
150 | // This allows us to group save ops together without running into |
151 | // duplicated tensor_names (which save_v2 op doesn't support). |
152 | for (int save_op_index = 0; save_op_index < specs.size(); ++save_op_index) { |
153 | if (save_op_index >= tensor_indices.size()) { |
154 | tensor_indices.push_back({}); |
155 | shape_and_slice_specs.push_back({}); |
156 | |
157 | mlir::Value new_prefix = |
158 | builder |
159 | .create<mlir::TF::AddOp>( |
160 | prefix.getLoc(), |
161 | prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix, |
162 | StringScalarConst(builder, prefix.getLoc(), |
163 | DeviceSuffix(device_id, total_devices))) |
164 | .z(); |
165 | // Generate new prefix based on device_id and save op index, only when |
166 | // we need a new save_op. |
167 | new_prefixes.push_back(new_prefix); |
168 | } |
169 | tensor_indices[save_op_index].push_back(tensor_index); |
170 | shape_and_slice_specs[save_op_index].push_back(specs[save_op_index]); |
171 | } |
172 | } |
173 | |
174 | return SaveOpSpecs(new_prefixes, tensor_indices, shape_and_slice_specs); |
175 | } |
176 | |
177 | StatusOr<std::vector<std::string>> SliceSpecOnDevice( |
178 | const Layout& layout, const Mesh& mesh, const DeviceLocation& device_coords, |
179 | absl::Span<const int64_t> global_shape) { |
180 | // Prefill the slice with replicated layouts. |
181 | std::vector<std::string> slice_specs(global_shape.size(), "-" ); |
182 | |
183 | const std::vector<std::string>& sharding_spec_strs = |
184 | layout.sharding_spec_strs(); |
185 | for (int tensor_dim_index = 0; tensor_dim_index < sharding_spec_strs.size(); |
186 | ++tensor_dim_index) { |
187 | const std::string& mesh_dim = sharding_spec_strs[tensor_dim_index]; |
188 | if (layout.IsShardedDimension(mesh_dim)) { |
189 | TF_ASSIGN_OR_RETURN(int mesh_dim_index, mesh.idx_for_dim(mesh_dim)); |
190 | TF_ASSIGN_OR_RETURN(int64_t dim_size, mesh.dim_size(mesh_dim)); |
191 | int64_t per_slice_size = global_shape[tensor_dim_index] / dim_size; |
192 | int start = device_coords[mesh_dim_index] * per_slice_size; |
193 | slice_specs[tensor_dim_index] = absl::StrCat(start, "," , per_slice_size); |
194 | } |
195 | } |
196 | return slice_specs; |
197 | } |
198 | |
199 | } // namespace dtensor |
200 | } // namespace tensorflow |
201 | |