1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/checkpoint_reader.h"
17
18#include <unordered_set>
19#include <utility>
20
21#include "tensorflow/core/platform/env.h"
22#include "tensorflow/core/platform/status.h"
23#include "tensorflow/core/platform/stringpiece.h"
24#include "tensorflow/core/platform/types.h"
25#include "tensorflow/core/util/saved_tensor_slice_util.h"
26
27namespace tensorflow {
28namespace checkpoint {
29
30class TensorSliceReader;
31
32CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
33 : reader_(nullptr),
34 v2_reader_(nullptr),
35 var_to_shape_map_(nullptr),
36 var_to_data_type_map_(nullptr) {
37 // Depending on whether this is a V2 ckpt, initializes "reader_" or
38 // "v2_reader_".
39 std::vector<string> v2_path;
40 if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
41 !v2_path.empty()) {
42 v2_reader_.reset(
43 new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
44 if (!v2_reader_->status().ok()) {
45 Set_TF_Status_from_Status(status, v2_reader_->status());
46 return;
47 }
48 auto result = BuildV2VarMaps();
49 var_to_shape_map_.swap(result.first);
50 var_to_data_type_map_.swap(result.second);
51 } else {
52 reader_.reset(new TensorSliceReader(filename));
53 if (!reader_->status().ok()) {
54 Set_TF_Status_from_Status(status, reader_->status());
55 return;
56 }
57 var_to_shape_map_.reset(
58 new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
59 var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
60 reader_->GetVariableToDataTypeMap()));
61 }
62}
63
64bool CheckpointReader::HasTensor(const string& name) const {
65 if (reader_ != nullptr) {
66 return reader_->HasTensor(name, nullptr, nullptr);
67 }
68 return v2_reader_->Contains(name);
69}
70
71const TensorSliceReader::VarToShapeMap&
72CheckpointReader::GetVariableToShapeMap() const {
73 CHECK(var_to_shape_map_);
74 return *var_to_shape_map_;
75}
76
77const TensorSliceReader::VarToDataTypeMap&
78CheckpointReader::GetVariableToDataTypeMap() const {
79 CHECK(var_to_data_type_map_);
80 return *var_to_data_type_map_;
81}
82
83const string CheckpointReader::DebugString() const {
84 if (reader_ != nullptr) return reader_->DebugString();
85 return v2_reader_->DebugString();
86}
87
88void CheckpointReader::GetTensor(
89 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
90 TF_Status* out_status) const {
91 Status status;
92 if (reader_ != nullptr) {
93 status = reader_->GetTensor(name, out_tensor);
94 } else {
95 tensorflow::DataType dtype;
96 tensorflow::TensorShape shape;
97 status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
98 if (status.ok()) {
99 out_tensor->reset(new Tensor(dtype, shape));
100 status = v2_reader_->Lookup(name, out_tensor->get());
101 if (!status.ok()) out_tensor->reset();
102 }
103 }
104 if (!status.ok()) {
105 Set_TF_Status_from_Status(out_status, status);
106 }
107}
108
109std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
110 std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
111CheckpointReader::BuildV2VarMaps() {
112 CHECK(v2_reader_ != nullptr);
113 CHECK(v2_reader_->status().ok());
114
115 // First pass: filters out the entries of the slices.
116 std::unordered_set<string> filtered_keys;
117 BundleEntryProto entry;
118 v2_reader_->Seek(kHeaderEntryKey);
119 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
120 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
121 v2_reader_->value().size()))
122 << entry.InitializationErrorString();
123 for (int i = 0; i < entry.slices_size(); ++i) {
124 const auto& slice_proto = entry.slices(i);
125 CHECK(filtered_keys
126 .insert(EncodeTensorNameSlice(
127 string(v2_reader_->key()) /* full var's name */,
128 TensorSlice(slice_proto)))
129 .second);
130 }
131 }
132
133 // Second pass: adds the entries, ignoring the filtered keys.
134 std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
135 new TensorSliceReader::VarToShapeMap);
136 std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
137 new TensorSliceReader::VarToDataTypeMap);
138 v2_reader_->Seek(kHeaderEntryKey);
139 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
140 if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
141 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
142 v2_reader_->value().size()))
143 << entry.InitializationErrorString();
144 string key(v2_reader_->key());
145 (*var_to_shape_map)[key] = TensorShape(entry.shape());
146 (*var_to_data_type_map)[key] = DataType(entry.dtype());
147 }
148 // The returned pointers are owned by the caller.
149 return std::make_pair(std::move(var_to_shape_map),
150 std::move(var_to_data_type_map));
151}
152
153} // namespace checkpoint
154} // namespace tensorflow
155