/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
#define TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

// Defines a metadata entry used when saving a Tensor.
struct SavingTensorMetadata {
  // Tracks the tensor's index from the original save op.
  int64_t tensor_index;
  // The global shape of the tensor being saved.
  std::vector<int64_t> shape;
  // The layout of the tensor being saved.
  Layout layout;

  SavingTensorMetadata(int64_t index, std::vector<int64_t> global_shape,
                       Layout tensor_layout)
      : tensor_index(index),
        shape(std::move(global_shape)),
        layout(std::move(tensor_layout)) {}
};
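
// Example usage (a minimal sketch; the index, shape, and layout values are
// hypothetical, and `mesh` is assumed to be a valid Mesh):
//
//   SavingTensorMetadata metadata(
//       /*index=*/0, /*global_shape=*/{8, 4},
//       /*tensor_layout=*/Layout::ReplicatedOnMesh(mesh, /*rank=*/2));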

// Tracks a complete specification for a particular save op.
// Users build multiple save ops from the given fields in the following
// manner:
//
//   save_op[i] = tf.SaveV2(
//       prefix = new_prefixes[i],
//       tensor_indices = tensor_indices[i],
//       shape_and_slices = shape_and_slice_spec[i])
struct SaveOpSpecs {
  std::vector<mlir::Value> new_prefixes;
  std::vector<std::vector<int>> tensor_indices;
  std::vector<std::vector<std::string>> shape_and_slice_spec;

  SaveOpSpecs(std::vector<mlir::Value> prefixes,
              std::vector<std::vector<int>> indices,
              std::vector<std::vector<std::string>> specs)
      : new_prefixes(std::move(prefixes)),
        tensor_indices(std::move(indices)),
        shape_and_slice_spec(std::move(specs)) {}
};

// Returns a device suffix string (printf-style formatted) for the given
// device id and total device count.
std::string DeviceSuffix(int device_id, int total_devices);
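
// Example usage (a sketch; the exact suffix format is defined by the
// implementation, and `checkpoint_prefix` is a hypothetical std::string):
//
//   std::string suffix = DeviceSuffix(/*device_id=*/0, /*total_devices=*/8);
//   std::string device_prefix = absl::StrCat(checkpoint_prefix, suffix);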

// Builds a complete saving specification for each device on the mesh.
//
// The returned map is keyed by device_id and maps to a SavingSpec, where
// device_id is the device on which the saving should happen, and SavingSpec
// is a mapping of <tensor_index -> shape_and_slices>. For example, the map
//
//   { device_id 0 -> {
//       0 : "2 0,1",
//       1 : ""
//     }
//   }
//
// means that device 0 is responsible for saving tensors 0 and 1 from the
// passed-in tensor list. For tensor[0], it saves only the first element of
// that 1-D vector of 2 elements. For tensor[1], it saves all elements.
//
// The input is a list of SavingTensorMetadata entries, each recording a
// tensor's index, global shape, and layout.
//
// Together, the global shape, the layout, and the layout's mesh define which
// device saves which slices of the tensor.
//
// For a complete definition of the shape_and_slices format, please see:
// third_party/tensorflow/core/framework/tensor_slice.h
StatusOr<absl::flat_hash_map<
    int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas);
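
// Example usage (a minimal sketch, assuming `metadata` is a
// std::vector<SavingTensorMetadata> populated with one entry per tensor being
// saved, and that the surrounding function returns a Status or StatusOr):
//
//   TF_ASSIGN_OR_RETURN(auto saving_spec, BuildSavingSpec(metadata));
//   // saving_spec[device_id][tensor_index] now holds the shape_and_slices
//   // strings that `device_id` should pass to SaveV2 for that tensor.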

// For a given per-device saving spec, determines how many SaveV2 ops are
// needed and their corresponding inputs.
//
// The current SaveV2 op requires tensor_names to be unique within the list,
// which is a contract that distributed saving breaks. For example, if the
// saving spec decides that device 0 is responsible for saving two slices of
// tensor[a], then a single SaveV2 op cannot fulfill that. This setup is very
// likely when saving on TPU, where 8 cores map to 1 host. In that case, the
// CPU host is responsible for saving slices of the same tensor across the 8
// TPU cores.
// TODO(b/179126981): Investigate whether we can make TF core API run with
// different slice specs on the same tensor key.
//
// That said, building one SaveV2 op for each save is wasteful, since a single
// SaveV2 op is capable of saving different tensors. Instead, we only need as
// many SaveV2 ops as the longest list of slice specs for any single tensor
// saved on the device, e.g.,
//
// for the given saving specs:
//
//   { 'tensor_name_a' : <"spec_a", "spec_a_2"> }
//   { 'tensor_name_b' : <"spec_b"> }
//
// this would result in two save ops, where:
//
//   SaveOp1 (tensor_names = <"tensor_name_a", "tensor_name_b">,
//            slice_spec = <"spec_a", "spec_b">)
//
//   SaveOp2 (tensor_names = <"tensor_name_a">, slice_spec = <"spec_a_2">)
//
// The output vectors track the new SaveV2 op parameters; they must agree on
// size and indexing for the tensors being saved.
//
// tensor_indices tracks, for each save op, the list of indices of the tensors
// being saved, e.g., tensor_indices[0] is the list of tensors (by index) that
// the first SaveV2 op needs to save.
//
// shape_and_slice_spec tracks, for each save op, the list of shape_and_slices
// parameters, e.g., shape_and_slice_spec[0] is the list of shape_and_slices
// parameters for the first SaveV2 op.
SaveOpSpecs BuildPerDeviceSave(
    mlir::OpBuilder& builder,
    const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
    int device_id, mlir::Value prefix, int total_devices);
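
// Example usage (a sketch, assuming `builder`, `prefix`, `device_id`, and the
// `saving_spec` map produced by BuildSavingSpec above are in scope, with a
// hypothetical 8-device mesh):
//
//   SaveOpSpecs specs = BuildPerDeviceSave(
//       builder, saving_spec[device_id], device_id, prefix,
//       /*total_devices=*/8);
//   // specs.new_prefixes[i], specs.tensor_indices[i], and
//   // specs.shape_and_slice_spec[i] together parameterize the i-th SaveV2 op.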

// Figures out the tensor slice_spec for a given layout and mesh device
// location.
StatusOr<std::vector<std::string>> SliceSpecOnDevice(
    const Layout& layout, const Mesh& mesh,
    const DeviceLocation& device_coords,
    absl::Span<const int64_t> global_shape);
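
// Example usage (a sketch, assuming `layout`, `mesh`, and `device_coords`
// describe one device of the mesh, with a hypothetical global shape, and that
// the surrounding function returns a Status or StatusOr):
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::string> slice_specs,
//       SliceSpecOnDevice(layout, mesh, device_coords,
//                         /*global_shape=*/{8, 4}));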

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_