1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_ |
17 | #define TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_ |
18 | |
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
27 | |
28 | namespace tensorflow { |
29 | namespace dtensor { |
30 | |
31 | // Defines an Metadata entry when saving a Tensor. |
32 | struct SavingTensorMetadata { |
33 | // Tracks index from the original save op. |
34 | int64_t tensor_index; |
35 | // The global shape of the saving tensor. |
36 | std::vector<int64_t> shape; |
37 | // The layout of the saving tensor. |
38 | Layout layout; |
39 | |
40 | SavingTensorMetadata(int64_t index, std::vector<int64_t> global_shape, |
41 | Layout tensor_layout) |
42 | : tensor_index(index), |
43 | shape(std::move(global_shape)), |
44 | layout(std::move(tensor_layout)) {} |
45 | }; |
46 | |
47 | // Tracks a complete specification for a particular save op. |
48 | // The users would build out multiple save ops using the following manner for |
49 | // the given fields: |
50 | // |
51 | // save_op[i] = tf.SaveV2( |
52 | // prefix = new_prefixes[i], |
53 | // tensor_indices = tensor_indies[i], |
54 | // shape_and_slices = shape_and_slice_spec[i]) |
55 | struct SaveOpSpecs { |
56 | std::vector<mlir::Value> new_prefixes; |
57 | std::vector<std::vector<int>> tensor_indices; |
58 | std::vector<std::vector<std::string>> shape_and_slice_spec; |
59 | |
60 | SaveOpSpecs(std::vector<mlir::Value> prefixes, |
61 | std::vector<std::vector<int>> indices, |
62 | std::vector<std::vector<std::string>> specs) |
63 | : new_prefixes(std::move(prefixes)), |
64 | tensor_indices(std::move(indices)), |
65 | shape_and_slice_spec(std::move(specs)) {} |
66 | }; |
67 | |
// Returns a printf-formatted suffix string identifying `device_id` among
// `total_devices`; presumably appended to a save prefix to make per-device
// save paths unique — confirm against the implementation.
std::string DeviceSuffix(int device_id, int total_devices);
70 | |
// Builds a complete saving specification for each device on the mesh.
//
// The returned map contains a map of <device_id, SavingSpec>.
// Device_id is where the saving should happen, and SavingSpec is a
// mapping of <tensor_index -> list of shape_and_slices strings>. e.g.,
//
// A map of {device_id : 0 -> {
//            0 : ["2 0,1"],
//            1 : [""]
//          }
// }
//
// Means that device_0 is responsible for saving tensor 0 and 1 from the passed
// in tensors list. For tensor[0], it saves only the first element of that
// 1-d vector with 2 elements. For tensor[1], it saves all elements.
//
// Each entry of `tensor_metadatas` records, for one tensor, its index in the
// save op together with its global shape and layout.
//
// (tensor_global_shape, tensor_layout & tensor_layout.mesh) defines which
// device saves what slices of the Tensor.
//
// For a complete definition of shape_and_slices field, please see:
// third_party/tensorflow/core/framework/tensor_slice.h
StatusOr<absl::flat_hash_map<
    int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas);
98 | |
// For a given per device saving spec, find out the counts of SaveV2 ops
// needed and their corresponding inputs.
//
// Current SaveV2 op requires tensor_names to be unique in the list, which is a
// contract that distributed saving would break. For example, if the saving spec
// decides that device 0 is responsible for saving two slices of tensor[a], then
// a single SaveV2 op can't fulfill the request. The setup is very likely to
// happen when saving on TPU - where 8 cores map to 1 host. In that case, the
// CPU host will be responsible for saving slices on the same tensor across 8
// TPU cores.
// TODO(b/179126981): Investigate whether we can make TF core API run with
// different slice spec on a same tensor key.
//
// That said, building one SaveV2 op for each save is wasteful, when a single
// SaveV2 op is capable of saving different tensors. Instead, we simply need to
// break the SaveV2 op to be able to track the longest saving specs for a single
// tensor happening on the device, e.g.,
//
// For given saving specs:
//
// { 'tensor_name_a' : <"spec_a", "spec_a_2"> }
// { 'tensor_name_b' : <"spec_b"> }
//
// would result into two save ops, where:
//
// SaveOp1 (tensor_names = <"tensor_name_a", "tensor_name_b">,
//          slice_spec = <"spec_a", "spec_b">)
//
// SaveOp2 (tensor_names = <"tensor_name_a">, slice_spec = <"spec_a_2">).
//
// The output vectors track the new SaveV2 op parameters and they must agree on
// size and indexing for saving tensors.
//
// tensor_indices tracks a list of indices of tensors that are being saved for
// each Save op, e.g.,
//
// tensor_indices[0] is a list of tensors (in index form) that needs to be saved
// on the first SaveV2 op.
//
// shape_and_slice_specs tracks a list of shape_and_slice_specs being saved for
// each Save op, e.g.,
//
// shape_and_slice_spec[0] is a list of shape_and_slices parameters for SaveV2
// op.
SaveOpSpecs BuildPerDeviceSave(
    mlir::OpBuilder& builder,
    const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
    int device_id, mlir::Value prefix, int total_devices);
146 | |
// Figures out the tensor slice_spec for a given layout and mesh device
// location (`device_coords`), based on the tensor's `global_shape`.
// NOTE(review): returns a vector of spec strings; presumably one entry per
// slice this device is responsible for — confirm against the implementation.
StatusOr<std::vector<std::string>> SliceSpecOnDevice(
    const Layout& layout, const Mesh& mesh, const DeviceLocation& device_coords,
    absl::Span<const int64_t> global_shape);
152 | } // namespace dtensor |
153 | |
154 | } // namespace tensorflow |
155 | |
156 | #endif // TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_ |
157 | |