1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
17#define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
18
19#include "tensorflow/core/util/tensor_slice_reader.h"
20#include "tensorflow/core/util/tensor_slice_writer.h"
21
22namespace tensorflow {
23
// Forward declaration: the declarations in this header only refer to
// OpKernelContext through pointers, so the complete type is not needed here.
24class OpKernelContext;
25
26// Legacy / V1 checkpoint format.
27
28// Saves the input tensors held by *context to a writer built from
// builder_func().
//
29// context must have the following inputs:
30//   0: a single element string tensor that contains the file name.
31//   1: names for the remaining tensors.
32// If save_slices is true:
33//   2: shape and slice specifications.
34//   rest: tensors to save.
//
// Parameters:
//   context      - kernel context whose inputs are laid out as listed above.
//   builder_func - used to build the writer that the tensors are saved to.
//   save_slices  - when true, input 2 carries the shape and slice
//                  specifications and the tensors to save follow it.
//
// NOTE(review): when save_slices is false the tensors to save presumably
// start at input 2 -- confirm against the definition. The function returns
// void, so failures are presumably reported through *context -- confirm.
35void SaveTensors(
36 OpKernelContext* context,
37 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
38 bool save_slices);
39
40// Reads a single tensor from the reader built from open_func() and produces
41// it as context->output(restore_index). "preferred_shard" is the same as the
42// TensorSliceReader preferred_shard parameter.
43//
44// context must have the following inputs:
45//   0: a single element string tensor that contains the file name.
46//   1: string tensor that names the outputs to be restored.
47// If restore_slice is true:
48//   2: shape and slice specification of the tensors to restore.
49//
50// restore_index indicates which variable name (in input 1) and which slice
51// (in input 2, when restore_slice is true) to look up.
//
// NOTE(review): the function returns void, so failures are presumably
// reported through *context -- confirm against the definition.
52void RestoreTensor(OpKernelContext* context,
53 checkpoint::TensorSliceReader::OpenTableFunction open_func,
54 int preferred_shard, bool restore_slice, int restore_index);
55
56// V2 checkpoint format.
57
58// Invokes the V2 checkpoint read path to read tensors.
59//
60// "context" is only used for allocating outputs. In particular, the inputs are
61// explicitly provided as arguments and not accessed via the "input(i)" methods.
62// REQUIRES:
63//   * "prefix" has 1 element, DT_STRING.
64//   * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING.
65//   * "dtypes" has N elements, the datatypes of the to-restore tensors.
//
// Returns a Status. NOTE(review): the conditions that produce a non-OK status
// are not visible in this header -- see the definition.
66Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
67 const Tensor& tensor_names,
68 const Tensor& shape_and_slices,
69 gtl::ArraySlice<DataType> dtypes);
70
71} // namespace tensorflow
72
73#endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
74