/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/register_types.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/framework/types.pb.h"
29#include "tensorflow/core/lib/core/threadpool.h"
30#include "tensorflow/core/lib/gtl/array_slice.h"
31#include "tensorflow/core/lib/strings/str_util.h"
32#include "tensorflow/core/lib/strings/strcat.h"
33#include "tensorflow/core/lib/strings/stringprintf.h"
34#include "tensorflow/core/platform/logging.h"
35#include "tensorflow/core/platform/types.h"
36#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
37#include "tensorflow/core/util/tensor_slice_reader.h"
38#include "tensorflow/core/util/tensor_slice_reader_cache.h"
39#include "tensorflow/core/util/tensor_slice_writer.h"
40
41namespace tensorflow {
42
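// Writes each of the remaining inputs to the file named by input 0 using a
// TensorSliceWriter. Input 1 holds the tensor names; when `save_slices` is
// true, input 2 holds an optional shape-and-slice spec per tensor (an empty
// spec saves the full tensor).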
void SaveTensors(
    OpKernelContext* context,
    checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
    bool save_slices) {
  const Tensor& filename_t = context->input(0);
  {
    const int64_t size = filename_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (filename) must be a string scalar; got a tensor of ",
            size, " elements"));
  }

  // Path, names, and slices if save_slices is true.
  const int kFixedInputs = save_slices ? 3 : 2;
  const Tensor& tensor_names_t = context->input(1);
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to SaveTensors"));
  const int N = static_cast<int>(tensor_names_t.NumElements());
  const tstring* tensor_shapes_and_slices_ptr = nullptr;
  if (save_slices) {
    const Tensor& tensor_shapes_and_slices_t = context->input(2);
    OP_REQUIRES(
        context,
        tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N),
        errors::InvalidArgument("Expected ", N,
                                " elements for the tensor "
                                "shapes and slices but got ",
                                tensor_shapes_and_slices_t.NumElements()));
    tensor_shapes_and_slices_ptr =
        tensor_shapes_and_slices_t.flat<tstring>().data();
  }
  OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
              errors::InvalidArgument("Expected a total of ", N + kFixedInputs,
                                      " inputs as input #1 (which is a string "
                                      "tensor of saved names) contains ",
                                      N, " names, but received ",
                                      context->num_inputs(), " inputs"));

  VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
          << "...";
  checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
                                       std::move(builder_func));

  Status s;
  auto tensor_names_flat = tensor_names_t.flat<tstring>();

  // Process tensors in sorted name order. This allows us to avoid seeking
  // during restoration in the common case where we are restoring a full
  // checkpoint.
  // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't
  // strictly necessary anymore. However, restores with TF version <= 2.7 will
  // still benefit.
  std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  for (const size_t i : sorted_name_idx) {
    const string& name = tensor_names_flat(i);
    const Tensor& input = context->input(i + kFixedInputs);
    TensorShape shape(input.shape());
    TensorSlice slice(input.dims());
    if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
      const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
      TensorShape slice_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &shape, &slice, &slice_shape));
      OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
                  errors::InvalidArgument(
                      "Slice in shape_and_slice "
                      "specification does not match the "
                      "shape of the tensor to save: ",
                      shape_spec, ", tensor: ", input.shape().DebugString()));
    }

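// WRITER_ADD expands to one switch case per data type; the
// TF_CALL_SAVE_RESTORE_TYPES macro below instantiates it for every type the
// save/restore ops support.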
#define WRITER_ADD(T)                                           \
  case DataTypeToEnum<T>::value:                                \
    s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    break;

    switch (input.dtype()) {
      TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
      default:
        context->SetStatus(errors::Unimplemented("Saving data type ",
                                                 DataTypeString(input.dtype()),
                                                 " not yet supported"));
        return;
    }
#undef WRITER_ADD
    if (!s.ok()) {
      context->SetStatus(s);
      return;
    }
  }

  s = writer.Finish();
  if (!s.ok()) {
    context->SetStatus(s);
  }
}

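// Restores the tensor named by element `restore_index` of input 1 from the
// checkpoint files matching input 0 (the file pattern) into output
// `restore_index`. When `restore_slice` is true, input 2 supplies a
// shape-and-slice spec selecting which slice of the saved tensor to load.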
void RestoreTensor(OpKernelContext* context,
                   checkpoint::TensorSliceReader::OpenTableFunction open_func,
                   int preferred_shard, bool restore_slice, int restore_index) {
  const Tensor& file_pattern_t = context->input(0);
  {
    const int64_t size = file_pattern_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
            size, " elements"));
  }
  const string& file_pattern = file_pattern_t.flat<tstring>()(0);

  const Tensor& tensor_name_t = context->input(1);
  {
    const int64_t size = tensor_name_t.NumElements();
    OP_REQUIRES(context, size > restore_index,
                errors::InvalidArgument(
                    "Input 1 (tensor_name) must have at least ",
                    restore_index + 1, " elements"));
  }
172 const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);
173
174 // If we cannot find a cached reader we will allocate our own.
175 std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
176
177 const checkpoint::TensorSliceReader* reader = nullptr;
178
179 if (context->slice_reader_cache()) {
180 reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
181 preferred_shard);
182 }
183 if (!reader) {
184 allocated_reader.reset(new checkpoint::TensorSliceReader(
185 file_pattern, open_func, preferred_shard));
186 reader = allocated_reader.get();
187 }
188 OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
189
190 // Get the shape and type from the save file.
191 DataType type;
192 TensorShape saved_shape;
193 OP_REQUIRES(
194 context, reader->HasTensor(tensor_name, &saved_shape, &type),
195 errors::NotFound("Tensor name \"", tensor_name,
196 "\" not found in checkpoint files ", file_pattern));
197 OP_REQUIRES(
198 context, type == context->expected_output_dtype(restore_index),
199 errors::InvalidArgument("Expected to restore a tensor of type ",
200 DataTypeString(context->expected_output_dtype(0)),
201 ", got a tensor of type ", DataTypeString(type),
202 " instead: tensor_name = ", tensor_name));

  // Shape of the output and slice to load.
  TensorShape output_shape(saved_shape);
  TensorSlice slice_to_load(saved_shape.dims());
  if (restore_slice) {
    const tstring& shape_spec =
        context->input(2).flat<tstring>()(restore_index);
    if (!shape_spec.empty()) {
      TensorShape parsed_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &parsed_shape, &slice_to_load,
                                  &output_shape));
      OP_REQUIRES(
          context, parsed_shape.IsSameSize(saved_shape),
          errors::InvalidArgument(
              "Shape in shape_and_slice spec does not match the shape in the "
              "save file: ",
              parsed_shape.DebugString(),
              ", save file shape: ", saved_shape.DebugString()));
    }
  }

  Tensor* t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(restore_index, output_shape, &t));

  if (output_shape.num_elements() == 0) return;

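// READER_COPY expands to one switch case per data type, copying the requested
// slice into the output buffer; TF_CALL_SAVE_RESTORE_TYPES instantiates it for
// every supported type.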
#define READER_COPY(T)                                                \
  case DataTypeToEnum<T>::value:                                      \
    OP_REQUIRES(context,                                              \
                reader->CopySliceData(tensor_name, slice_to_load,     \
                                      t->flat<T>().data()),           \
                errors::InvalidArgument("Error copying slice data")); \
    break;

  switch (type) {
    TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
    default:
      context->SetStatus(errors::Unimplemented(
          "Restoring data type ", DataTypeString(type), " not yet supported"));
  }
#undef READER_COPY
}

namespace {

// Tensors larger than this threshold will be restored from a thread-pool.
const int64_t kLargeShapeThreshold = 16 << 20;  // 16M elements

// A restore operation for a single tensor. Small tensors may be restored
// directly from the op thread to improve read locality. Large tensors can be
// restored from a thread pool: this requires creating a separate BundleReader
// for each restore.
struct RestoreOp {
  RestoreOp(OpKernelContext* context, int idx, const string& tensor_name,
            const string& shape_and_slice, const string& reader_prefix,
            DataType dtype)
      : context(context),
        idx(idx),
        tensor_name(tensor_name),
        shape_and_slice(shape_and_slice),
        reader_prefix(reader_prefix),
        dtype(dtype) {}

  // Move-only. It does not make sense to "run()" a copied RestoreOp.
  RestoreOp(const RestoreOp&) = delete;
  RestoreOp& operator=(const RestoreOp&) = delete;
  RestoreOp(RestoreOp&&) = default;
  RestoreOp& operator=(RestoreOp&&) = default;

  bool should_run_in_pool(BundleReader* reader) const {
    TensorShape restored_full_shape;

    // Ignore status here; we'll catch the error later.
    if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
      return false;
    }

    return restored_full_shape.num_elements() > kLargeShapeThreshold;
  }

  // Run this restore operation using a new BundleReader.
  void run_with_new_reader() {
    BundleReader reader(Env::Default(), reader_prefix);
    if (!reader.status().ok()) {
      status = reader.status();
      return;
    }

    status = run(&reader);
  }

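  // Restores this op's tensor from `reader` into output `idx`: the full tensor
  // when `shape_and_slice` is empty, otherwise the slice it describes.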
  Status run(BundleReader* reader) {
    TensorShape restored_full_shape;
    TF_RETURN_IF_ERROR(
        reader->LookupTensorShape(tensor_name, &restored_full_shape));

    VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    Tensor* restored_tensor;
    if (shape_and_slice.empty()) {
      // Lookup the full tensor.
      TF_RETURN_IF_ERROR(context->allocate_output(idx, restored_full_shape,
                                                  &restored_tensor));
      TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
    } else {
      // Lookup the slice.
      TensorShape parsed_full_shape;
      TensorSlice parsed_slice;
      TensorShape parsed_slice_shape;

      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));

      if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
        return errors::InvalidArgument(
            "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
            parsed_full_shape.DebugString(),
            " does not match the shape stored in checkpoint: ",
            restored_full_shape.DebugString());
      }
      TF_RETURN_IF_ERROR(context->allocate_output(idx, parsed_slice_shape,
                                                  &restored_tensor));
      TF_RETURN_IF_ERROR(
          reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
    }
    if (VLOG_IS_ON(5)) {
      if (restored_tensor->dtype() == DT_FLOAT) {
        const float* t_data = restored_tensor->flat<float>().data();
        float min = std::numeric_limits<float>::infinity();
        float max = -std::numeric_limits<float>::infinity();
        double avg = 0.0;
        for (int i = 0; i < restored_tensor->NumElements(); ++i) {
          if (t_data[i] < min) min = t_data[i];
          if (t_data[i] > max) max = t_data[i];
          avg += t_data[i];
        }
        VLOG(5) << " min " << min << " max " << max << " avg "
                << avg / restored_tensor->NumElements() << " total elts "
                << restored_tensor->NumElements();
      }
    }
    VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    return OkStatus();
  }

  OpKernelContext* context;
  int idx;
  string tensor_name;
  string shape_and_slice;
  string reader_prefix;
  DataType dtype;

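  // Result of run_with_new_reader(); the caller checks it after the thread
  // pool has drained.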
  ::tensorflow::Status status;
};

}  // namespace

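// Restores every tensor named in `tensor_names` from the bundle identified by
// `prefix`, allocating one output of `context` per tensor. Tensors whose
// element count exceeds kLargeShapeThreshold are read concurrently from a
// thread pool, each with its own BundleReader; the rest are read inline on the
// op thread.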
Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
                        const Tensor& tensor_names,
                        const Tensor& shape_and_slices,
                        gtl::ArraySlice<DataType> dtypes) {
  const string& prefix_string = prefix.scalar<tstring>()();

  const auto& tensor_names_flat = tensor_names.flat<tstring>();
  const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

  std::vector<RestoreOp> restore_ops;
  restore_ops.reserve(tensor_names_flat.size());
  for (int i = 0; i < tensor_names_flat.size(); ++i) {
    restore_ops.push_back({context, i, tensor_names_flat(i),
                           shape_and_slices_flat(i), prefix_string, dtypes[i]});
  }

  BundleReader default_reader(Env::Default(), prefix_string);
  TF_RETURN_IF_ERROR(default_reader.status());

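  // Order the restore ops by their position in the bundle file so that the
  // lookups below read sequentially instead of seeking back and forth.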
  TF_RETURN_IF_ERROR(default_reader.SortForSequentialAccess<RestoreOp>(
      restore_ops, [](const RestoreOp& op) { return op.tensor_name; }));

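  // Before allocating any outputs, check that every requested dtype matches
  // the dtype recorded in the checkpoint, and report all mismatches at once.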
  std::vector<string> mismatched_errors;
  for (const RestoreOp& restore_op : restore_ops) {
    TensorShape restored_full_shape;
    DataType original_dtype;
    TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
        restore_op.tensor_name, &original_dtype, &restored_full_shape));
    if (restore_op.dtype != original_dtype) {
      string error_msg = strings::StrCat(
          "tensor_name = ", restore_op.tensor_name, "; expected dtype ",
          DataTypeString(restore_op.dtype), " does not equal original dtype ",
          DataTypeString(original_dtype));
      mismatched_errors.emplace_back(error_msg);
    }
  }
  if (!mismatched_errors.empty()) {
    const string error_msg = absl::StrJoin(mismatched_errors, "\n");
    return errors::InvalidArgument(error_msg);
  }

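  // Partition the ops by size: tensors above kLargeShapeThreshold are restored
  // from a thread pool, everything else directly on the op thread.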
  std::vector<RestoreOp*> pool_restore_ops;
  std::vector<RestoreOp*> direct_restore_ops;
  for (RestoreOp& restore_op : restore_ops) {
    if (restore_op.should_run_in_pool(&default_reader)) {
      pool_restore_ops.push_back(&restore_op);
    } else {
      direct_restore_ops.push_back(&restore_op);
    }
  }

  {
    // Schedule any threaded operations first, skipping thread pool creation if
    // we don't have any expensive operations.
    std::unique_ptr<thread::ThreadPool> reader_pool;
    if (!pool_restore_ops.empty()) {
      reader_pool.reset(
          new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
      for (auto* op : pool_restore_ops) {
        reader_pool->Schedule([op]() { op->run_with_new_reader(); });
      }
    }

    // Read small tensors from the op thread while the pool restores large
    // ones in parallel.
    for (auto* op : direct_restore_ops) {
      TF_RETURN_IF_ERROR(op->run(&default_reader));
    }
  }

  // Check status of pool ops; this must come after the pool shuts down.
  for (auto* op : pool_restore_ops) {
    TF_RETURN_IF_ERROR(op->status);
  }

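  // Final consistency check: every allocated output must have the dtype the
  // caller requested.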
438 for (const RestoreOp& restore_op : restore_ops) {
439 if (restore_op.dtype != context->mutable_output(restore_op.idx)->dtype()) {
440 return errors::InvalidArgument(
441 "tensor_name = ", restore_op.tensor_name, "; expected dtype ",
442 DataTypeString(restore_op.dtype), " does not equal restored dtype ",
443 DataTypeString(context->mutable_output(restore_op.idx)->dtype()));
444 }
445 }
446
447 return OkStatus();
448}
449
450} // namespace tensorflow
451