/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/checkpoint_callback_manager.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different number of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shapes_and_slices, but got ",
                              context->input(2).NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Got ", num_tensors, " tensor names but ",
                    context->num_inputs() - kFixedInputs, " tensors."));
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ", context->num_inputs(),
                    " inputs"));
  }
}
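// For orientation, a hedged sketch of inputs that pass these checks (the
// prefix, names, and specs below are made-up examples, not required values):
//
//   prefix:           scalar string tensor, e.g. "/tmp/ckpt/model"
//   tensor_names:     1-D string tensor, e.g. ["dense/kernel", "dense/bias"]
//   shape_and_slices: 1-D string tensor of equal length, e.g. ["", ""]
//
// SaveV2 additionally expects one tensor input per name, appended after the
// three fixed inputs above.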

}  // namespace

// Saves a list of named tensors using the tensor bundle library.
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<tstring>()();
    const auto& tensor_names_flat = tensor_names.flat<tstring>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);
      VLOG(2) << "Starting save of " << tensor_name;

      if (!shape_and_slices_flat(i).empty()) {
        const string& shape_spec = shape_and_slices_flat(i);
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }

      if (VLOG_IS_ON(5)) {
        if (tensor.dtype() == DT_FLOAT) {
          const float* t_data = tensor.flat<float>().data();
          float min = std::numeric_limits<float>::infinity();
          float max = -std::numeric_limits<float>::infinity();
          double avg = 0.0;
          for (int i = 0; i < tensor.NumElements(); ++i) {
            if (t_data[i] < min) min = t_data[i];
            if (t_data[i] > max) max = t_data[i];
            avg += t_data[i];
          }
          VLOG(5) << " min " << min << " max " << max << " avg "
                  << avg / tensor.NumElements() << " total elts "
                  << tensor.NumElements();
        }
      }

      VLOG(2) << "Done save of " << tensor_name;
    }
    OP_REQUIRES_OK(context, writer.Finish());
    VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;

    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Save(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
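// Illustrative sketch (not part of this kernel): a bundle written above with
// prefix_string "/tmp/ckpt/model" can be read back with BundleReader from
// tensor_bundle.h, the same library RestoreV2 uses below. The tensor name and
// prefix are made-up examples.
//
//   BundleReader reader(Env::Default(), "/tmp/ckpt/model");
//   TF_CHECK_OK(reader.status());
//   Tensor restored;
//   TF_CHECK_OK(reader.Lookup("dense/kernel", &restored));
//
// A non-empty shape_and_slices entry requests a partial (sliced) save; the
// spec is expected to look roughly like "4 5 0,2:-", i.e. the full shape
// followed by per-dimension "start,length" ranges ("-" keeps the whole
// dimension), as parsed by checkpoint::ParseShapeAndSlice above.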

// Restores a list of named tensors from a tensor bundle (V2 checkpoint
// format).
class RestoreV2 : public OpKernel {
 public:
  explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
                errors::InvalidArgument("Got ", tensor_names.NumElements(),
                                        " tensor names, but ", dtypes_.size(),
                                        " expected dtypes."));
    ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const string& prefix_string = prefix.scalar<tstring>()();
    // Intention: we plan to use the RestoreV2 op as a backward-compatible
    // reader as we upgrade to the V2 format. This allows a transparent
    // upgrade. Here we attempt to read a V1 checkpoint if "prefix_string"
    // does not refer to a V2 checkpoint.
    Env* env = Env::Default();
    std::vector<string> paths;
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint. Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));

    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Restore(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }

 private:
  // Expected dtypes of the to-restore tensors.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);
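// Illustrative sketch (not part of this kernel): the V2-vs-V1 dispatch above
// keys off the presence of the bundle's metadata file for the given prefix.
// Assuming MetaFilename("/tmp/ckpt/model") resolves to a path such as
// "/tmp/ckpt/model.index", the same probe can be written as:
//
//   std::vector<string> paths;
//   const bool is_v2_checkpoint =
//       Env::Default()
//           ->GetMatchingPaths(MetaFilename("/tmp/ckpt/model"), &paths)
//           .ok() &&
//       !paths.empty();
//
// The prefix is a made-up example; if no metadata file matches, RestoreV2
// falls back to the V1 TensorSliceReader path.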

// The final step in saving sharded V2 checkpoints: merges metadata files.
class MergeV2Checkpoints : public OpKernel {
 public:
  explicit MergeV2Checkpoints(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("delete_old_dirs", &delete_old_dirs_));
    OP_REQUIRES_OK(context, context->GetAttr("allow_missing_files",
                                             &allow_missing_files_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& checkpoint_prefixes = context->input(0);
    const Tensor& destination_prefix = context->input(1);
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
                errors::InvalidArgument(
                    "Input checkpoint_prefixes should be a 1-D tensor, got ",
                    checkpoint_prefixes.shape().DebugString(), " instead."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
                errors::InvalidArgument(
                    "Input destination_prefix should be a scalar tensor, got ",
                    destination_prefix.shape().DebugString(), " instead."));

    const gtl::ArraySlice<tstring> input_prefixes =
        gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<tstring>()();
    OP_REQUIRES_OK(context,
                   tensorflow::MergeBundles(env, input_prefixes, merged_prefix,
                                            allow_missing_files_));

    if (delete_old_dirs_) {
      const string merged_dir(io::Dirname(merged_prefix));
      for (const string& input_prefix : input_prefixes) {
        const string dirname(io::Dirname(input_prefix));
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For sharded save, only the first delete will go through and all
        // others will hit NotFound. Use vlog to be less verbose.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // On merge, whether or not to delete the input (temporary) directories.
  bool delete_old_dirs_;

  // On merge, whether or not to relax the condition that all input prefix
  // filenames must exist.
  bool allow_missing_files_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);
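// Illustrative sketch (not part of this kernel): MergeV2Checkpoints is a thin
// wrapper over tensorflow::MergeBundles. A sharded save typically writes one
// bundle per shard under temporary prefixes and then merges them into the
// user-visible prefix; the shard paths below are made-up examples.
//
//   std::vector<tstring> shard_prefixes = {
//       "/tmp/ckpt/tmp/part-00000-of-00002",
//       "/tmp/ckpt/tmp/part-00001-of-00002"};
//   TF_CHECK_OK(tensorflow::MergeBundles(Env::Default(), shard_prefixes,
//                                        "/tmp/ckpt/model",
//                                        /*allow_missing_files=*/false));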

}  // namespace tensorflow