1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/tensor_slice_writer.h" |
17 | |
18 | #include <utility> |
19 | |
20 | #include "tensorflow/core/framework/versions.pb.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/io/table_builder.h" |
23 | #include "tensorflow/core/lib/random/random.h" |
24 | #include "tensorflow/core/lib/strings/strcat.h" |
25 | #include "tensorflow/core/platform/env.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | #include "tensorflow/core/public/version.h" |
28 | #include "tensorflow/core/util/saved_tensor_slice_util.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | namespace checkpoint { |
33 | |
34 | namespace { |
35 | |
36 | class TableBuilder : public TensorSliceWriter::Builder { |
37 | public: |
38 | TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) { |
39 | table::Options option; |
40 | option.compression = table::kNoCompression; |
41 | builder_.reset(new table::TableBuilder(option, f)); |
42 | } |
43 | void Add(StringPiece key, StringPiece val) override { |
44 | builder_->Add(key, val); |
45 | } |
46 | Status Finish(int64_t* file_size) override { |
47 | *file_size = -1; |
48 | Status s = builder_->Finish(); |
49 | if (s.ok()) { |
50 | s = file_->Close(); |
51 | if (s.ok()) { |
52 | *file_size = builder_->FileSize(); |
53 | } |
54 | } |
55 | if (!s.ok()) { |
56 | s = errors::Internal("Error writing (tmp) checkpoint file: " , name_, ": " , |
57 | s.error_message()); |
58 | } |
59 | builder_.reset(); |
60 | file_.reset(); |
61 | return s; |
62 | } |
63 | |
64 | private: |
65 | string name_; |
66 | std::unique_ptr<WritableFile> file_; |
67 | std::unique_ptr<table::TableBuilder> builder_; |
68 | }; |
69 | } // anonymous namespace |
70 | |
71 | Status CreateTableTensorSliceBuilder(const string& name, |
72 | TensorSliceWriter::Builder** builder) { |
73 | *builder = nullptr; |
74 | std::unique_ptr<WritableFile> f; |
75 | Status s = Env::Default()->NewWritableFile(name, &f); |
76 | if (s.ok()) { |
77 | *builder = new TableBuilder(name, f.release()); |
78 | return OkStatus(); |
79 | } else { |
80 | return s; |
81 | } |
82 | } |
83 | |
84 | TensorSliceWriter::TensorSliceWriter(const string& filename, |
85 | CreateBuilderFunction create_builder) |
86 | : filename_(filename), |
87 | create_builder_(std::move(create_builder)), |
88 | tmpname_(strings::StrCat(filename, ".tempstate" , random::New64())), |
89 | slices_(0) { |
90 | VersionDef* versions = sts_.mutable_meta()->mutable_versions(); |
91 | versions->set_producer(TF_CHECKPOINT_VERSION); |
92 | versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER); |
93 | } |
94 | |
95 | Status TensorSliceWriter::Finish() { |
96 | Builder* b; |
97 | Status s = create_builder_(tmpname_, &b); |
98 | if (!s.ok()) { |
99 | delete b; |
100 | return s; |
101 | } |
102 | std::unique_ptr<Builder> builder(b); |
103 | |
104 | // We save the saved tensor slice metadata as the first element. |
105 | string meta; |
106 | sts_.AppendToString(&meta); |
107 | builder->Add(kSavedTensorSlicesKey, meta); |
108 | |
109 | // Go through all the data and add them |
110 | for (const auto& x : data_) { |
111 | builder->Add(x.first, x.second); |
112 | } |
113 | |
114 | int64_t file_size; |
115 | s = builder->Finish(&file_size); |
116 | // We need to rename the file to the proper name |
117 | if (s.ok()) { |
118 | s = Env::Default()->RenameFile(tmpname_, filename_); |
119 | if (s.ok()) { |
120 | VLOG(1) << "Written " << slices_ << " slices for " |
121 | << sts_.meta().tensor_size() << " tensors (" << file_size |
122 | << " bytes) to " << filename_; |
123 | } else { |
124 | LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_; |
125 | } |
126 | } else { |
127 | Env::Default()->DeleteFile(tmpname_).IgnoreError(); |
128 | } |
129 | return s; |
130 | } |
131 | |
132 | /* static */ |
133 | size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) { |
134 | size_t max_bytes_per_element = |
135 | TensorSliceWriter::MaxBytesPerElementOrZero(dt); |
136 | if (max_bytes_per_element == 0) { |
137 | LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt; |
138 | } |
139 | return max_bytes_per_element; |
140 | } |
141 | |
142 | /* static */ |
143 | size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) { |
144 | switch (dt) { |
145 | case DT_FLOAT: |
146 | return 4; |
147 | case DT_DOUBLE: |
148 | return 8; |
149 | case DT_INT32: |
150 | return 10; |
151 | case DT_UINT8: |
152 | return 2; |
153 | case DT_INT16: |
154 | return 10; |
155 | case DT_INT8: |
156 | return 10; |
157 | case DT_COMPLEX64: |
158 | return 8; |
159 | case DT_INT64: |
160 | return 10; |
161 | case DT_BOOL: |
162 | return 1; |
163 | case DT_QINT8: |
164 | return 10; |
165 | case DT_QUINT8: |
166 | return 2; |
167 | case DT_QINT32: |
168 | return 10; |
169 | case DT_QINT16: |
170 | return 10; |
171 | case DT_QUINT16: |
172 | return 3; |
173 | case DT_UINT16: |
174 | return 3; |
175 | case DT_COMPLEX128: |
176 | return 16; |
177 | case DT_HALF: |
178 | return 3; |
179 | case DT_INVALID: |
180 | case DT_STRING: |
181 | case DT_BFLOAT16: |
182 | default: |
183 | return 0; |
184 | } |
185 | } |
186 | |
187 | template <> |
188 | Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements, |
189 | SavedSlice* ss) { |
190 | size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes + |
191 | (num_elements * MaxBytesPerElement(DT_INT32)); |
192 | for (int64_t i = 0; i < num_elements; ++i) { |
193 | size_bound += data[i].size(); |
194 | } |
195 | if (size_bound > kMaxMessageBytes) { |
196 | return errors::InvalidArgument( |
197 | "Tensor slice is too large to serialize (conservative estimate: " , |
198 | size_bound, " bytes)" ); |
199 | } |
200 | Fill(data, num_elements, ss->mutable_data()); |
201 | DCHECK_GE(ss->ByteSize(), 0); |
202 | DCHECK_LE(ss->ByteSize(), size_bound); |
203 | return OkStatus(); |
204 | } |
205 | |
206 | } // namespace checkpoint |
207 | |
208 | } // namespace tensorflow |
209 | |