1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/checkpoint_reader.h" |
17 | |
18 | #include <unordered_set> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/platform/env.h" |
22 | #include "tensorflow/core/platform/status.h" |
23 | #include "tensorflow/core/platform/stringpiece.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | #include "tensorflow/core/util/saved_tensor_slice_util.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace checkpoint { |
29 | |
30 | class TensorSliceReader; |
31 | |
32 | CheckpointReader::CheckpointReader(const string& filename, TF_Status* status) |
33 | : reader_(nullptr), |
34 | v2_reader_(nullptr), |
35 | var_to_shape_map_(nullptr), |
36 | var_to_data_type_map_(nullptr) { |
37 | // Depending on whether this is a V2 ckpt, initializes "reader_" or |
38 | // "v2_reader_". |
39 | std::vector<string> v2_path; |
40 | if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && |
41 | !v2_path.empty()) { |
42 | v2_reader_.reset( |
43 | new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); |
44 | if (!v2_reader_->status().ok()) { |
45 | Set_TF_Status_from_Status(status, v2_reader_->status()); |
46 | return; |
47 | } |
48 | auto result = BuildV2VarMaps(); |
49 | var_to_shape_map_.swap(result.first); |
50 | var_to_data_type_map_.swap(result.second); |
51 | } else { |
52 | reader_.reset(new TensorSliceReader(filename)); |
53 | if (!reader_->status().ok()) { |
54 | Set_TF_Status_from_Status(status, reader_->status()); |
55 | return; |
56 | } |
57 | var_to_shape_map_.reset( |
58 | new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); |
59 | var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( |
60 | reader_->GetVariableToDataTypeMap())); |
61 | } |
62 | } |
63 | |
64 | bool CheckpointReader::HasTensor(const string& name) const { |
65 | if (reader_ != nullptr) { |
66 | return reader_->HasTensor(name, nullptr, nullptr); |
67 | } |
68 | return v2_reader_->Contains(name); |
69 | } |
70 | |
71 | const TensorSliceReader::VarToShapeMap& |
72 | CheckpointReader::GetVariableToShapeMap() const { |
73 | CHECK(var_to_shape_map_); |
74 | return *var_to_shape_map_; |
75 | } |
76 | |
77 | const TensorSliceReader::VarToDataTypeMap& |
78 | CheckpointReader::GetVariableToDataTypeMap() const { |
79 | CHECK(var_to_data_type_map_); |
80 | return *var_to_data_type_map_; |
81 | } |
82 | |
83 | const string CheckpointReader::DebugString() const { |
84 | if (reader_ != nullptr) return reader_->DebugString(); |
85 | return v2_reader_->DebugString(); |
86 | } |
87 | |
88 | void CheckpointReader::GetTensor( |
89 | const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor, |
90 | TF_Status* out_status) const { |
91 | Status status; |
92 | if (reader_ != nullptr) { |
93 | status = reader_->GetTensor(name, out_tensor); |
94 | } else { |
95 | tensorflow::DataType dtype; |
96 | tensorflow::TensorShape shape; |
97 | status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape); |
98 | if (status.ok()) { |
99 | out_tensor->reset(new Tensor(dtype, shape)); |
100 | status = v2_reader_->Lookup(name, out_tensor->get()); |
101 | if (!status.ok()) out_tensor->reset(); |
102 | } |
103 | } |
104 | if (!status.ok()) { |
105 | Set_TF_Status_from_Status(out_status, status); |
106 | } |
107 | } |
108 | |
109 | std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>, |
110 | std::unique_ptr<TensorSliceReader::VarToDataTypeMap>> |
111 | CheckpointReader::BuildV2VarMaps() { |
112 | CHECK(v2_reader_ != nullptr); |
113 | CHECK(v2_reader_->status().ok()); |
114 | |
115 | // First pass: filters out the entries of the slices. |
116 | std::unordered_set<string> filtered_keys; |
117 | BundleEntryProto entry; |
118 | v2_reader_->Seek(kHeaderEntryKey); |
119 | for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { |
120 | CHECK(entry.ParseFromArray(v2_reader_->value().data(), |
121 | v2_reader_->value().size())) |
122 | << entry.InitializationErrorString(); |
123 | for (int i = 0; i < entry.slices_size(); ++i) { |
124 | const auto& slice_proto = entry.slices(i); |
125 | CHECK(filtered_keys |
126 | .insert(EncodeTensorNameSlice( |
127 | string(v2_reader_->key()) /* full var's name */, |
128 | TensorSlice(slice_proto))) |
129 | .second); |
130 | } |
131 | } |
132 | |
133 | // Second pass: adds the entries, ignoring the filtered keys. |
134 | std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map( |
135 | new TensorSliceReader::VarToShapeMap); |
136 | std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map( |
137 | new TensorSliceReader::VarToDataTypeMap); |
138 | v2_reader_->Seek(kHeaderEntryKey); |
139 | for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { |
140 | if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; |
141 | CHECK(entry.ParseFromArray(v2_reader_->value().data(), |
142 | v2_reader_->value().size())) |
143 | << entry.InitializationErrorString(); |
144 | string key(v2_reader_->key()); |
145 | (*var_to_shape_map)[key] = TensorShape(entry.shape()); |
146 | (*var_to_data_type_map)[key] = DataType(entry.dtype()); |
147 | } |
148 | // The returned pointers are owned by the caller. |
149 | return std::make_pair(std::move(var_to_shape_map), |
150 | std::move(var_to_data_type_map)); |
151 | } |
152 | |
153 | } // namespace checkpoint |
154 | } // namespace tensorflow |
155 | |