1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ |
18 | |
19 | #include "tensorflow/core/util/tensor_slice_reader.h" |
20 | #include "tensorflow/core/util/tensor_slice_writer.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | class OpKernelContext; |
25 | |
26 | // Legacy / V1 checkpoint format. |
27 | |
28 | // Save input tensors in *context to a writer built from builder_func(). |
29 | // context must have the following inputs: |
30 | // 0: a single element string tensor that contains the file name. |
31 | // 1: names for the remaining tensors |
32 | // If save_slices is true: |
33 | // 2: shape and slice specifications. |
34 | // rest: tensors to save |
35 | void SaveTensors( |
36 | OpKernelContext* context, |
37 | checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, |
38 | bool save_slices); |
39 | |
40 | // Reads a single tensor from the reader built from open_func() and produces |
41 | // it as context->output(restore_index). "preferred_shard" is the same the |
42 | // TensorSliceReader preferred_shard parameter. |
43 | // |
44 | // context must have the following inputs: |
45 | // 0: a single element string tensor that contains the file name. |
46 | // 1: string tensor that names the outputs to be restored. |
47 | // If restore_slice is true: |
48 | // 2: shape and slice specification of the tensors to restore. |
49 | // |
50 | // restore_index indicates the variable name and slice to lookup |
51 | // in context(1) and (2). |
52 | void RestoreTensor(OpKernelContext* context, |
53 | checkpoint::TensorSliceReader::OpenTableFunction open_func, |
54 | int preferred_shard, bool restore_slice, int restore_index); |
55 | |
56 | // V2 checkpoint format. |
57 | |
58 | // Invokes the V2 checkpoint read path to read tensors. |
59 | // |
60 | // "context" is only used for allocating outputs. In particular, the inputs are |
61 | // explicitly provided and not accessed via the "input(i)" methods. |
62 | // REQUIRES: |
63 | // * "prefix" has 1 element, DT_STRING. |
64 | // * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING. |
65 | // * "dtypes" has N elements, the datatypes of the to-restore tensors. |
66 | Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, |
67 | const Tensor& tensor_names, |
68 | const Tensor& shape_and_slices, |
69 | gtl::ArraySlice<DataType> dtypes); |
70 | |
71 | } // namespace tensorflow |
72 | |
73 | #endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ |
74 | |