1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/util/tensor_slice_writer.h"
17
18#include <utility>
19
20#include "tensorflow/core/framework/versions.pb.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/io/table_builder.h"
23#include "tensorflow/core/lib/random/random.h"
24#include "tensorflow/core/lib/strings/strcat.h"
25#include "tensorflow/core/platform/env.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/public/version.h"
28#include "tensorflow/core/util/saved_tensor_slice_util.h"
29
30namespace tensorflow {
31
32namespace checkpoint {
33
34namespace {
35
36class TableBuilder : public TensorSliceWriter::Builder {
37 public:
38 TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) {
39 table::Options option;
40 option.compression = table::kNoCompression;
41 builder_.reset(new table::TableBuilder(option, f));
42 }
43 void Add(StringPiece key, StringPiece val) override {
44 builder_->Add(key, val);
45 }
46 Status Finish(int64_t* file_size) override {
47 *file_size = -1;
48 Status s = builder_->Finish();
49 if (s.ok()) {
50 s = file_->Close();
51 if (s.ok()) {
52 *file_size = builder_->FileSize();
53 }
54 }
55 if (!s.ok()) {
56 s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
57 s.error_message());
58 }
59 builder_.reset();
60 file_.reset();
61 return s;
62 }
63
64 private:
65 string name_;
66 std::unique_ptr<WritableFile> file_;
67 std::unique_ptr<table::TableBuilder> builder_;
68};
69} // anonymous namespace
70
71Status CreateTableTensorSliceBuilder(const string& name,
72 TensorSliceWriter::Builder** builder) {
73 *builder = nullptr;
74 std::unique_ptr<WritableFile> f;
75 Status s = Env::Default()->NewWritableFile(name, &f);
76 if (s.ok()) {
77 *builder = new TableBuilder(name, f.release());
78 return OkStatus();
79 } else {
80 return s;
81 }
82}
83
84TensorSliceWriter::TensorSliceWriter(const string& filename,
85 CreateBuilderFunction create_builder)
86 : filename_(filename),
87 create_builder_(std::move(create_builder)),
88 tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
89 slices_(0) {
90 VersionDef* versions = sts_.mutable_meta()->mutable_versions();
91 versions->set_producer(TF_CHECKPOINT_VERSION);
92 versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
93}
94
95Status TensorSliceWriter::Finish() {
96 Builder* b;
97 Status s = create_builder_(tmpname_, &b);
98 if (!s.ok()) {
99 delete b;
100 return s;
101 }
102 std::unique_ptr<Builder> builder(b);
103
104 // We save the saved tensor slice metadata as the first element.
105 string meta;
106 sts_.AppendToString(&meta);
107 builder->Add(kSavedTensorSlicesKey, meta);
108
109 // Go through all the data and add them
110 for (const auto& x : data_) {
111 builder->Add(x.first, x.second);
112 }
113
114 int64_t file_size;
115 s = builder->Finish(&file_size);
116 // We need to rename the file to the proper name
117 if (s.ok()) {
118 s = Env::Default()->RenameFile(tmpname_, filename_);
119 if (s.ok()) {
120 VLOG(1) << "Written " << slices_ << " slices for "
121 << sts_.meta().tensor_size() << " tensors (" << file_size
122 << " bytes) to " << filename_;
123 } else {
124 LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
125 }
126 } else {
127 Env::Default()->DeleteFile(tmpname_).IgnoreError();
128 }
129 return s;
130}
131
132/* static */
133size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
134 size_t max_bytes_per_element =
135 TensorSliceWriter::MaxBytesPerElementOrZero(dt);
136 if (max_bytes_per_element == 0) {
137 LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
138 }
139 return max_bytes_per_element;
140}
141
// Looks up the per-element byte bound for `dt`. Returns 0 for DT_INVALID,
// DT_STRING, DT_BFLOAT16 and any dtype that has no entry below. Cases are
// grouped by the value they return.
/* static */
size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) {
  switch (dt) {
    case DT_BOOL:
      return 1;

    case DT_UINT8:
    case DT_QUINT8:
      return 2;

    case DT_UINT16:
    case DT_QUINT16:
    case DT_HALF:
      return 3;

    case DT_FLOAT:
      return 4;

    case DT_DOUBLE:
    case DT_COMPLEX64:
      return 8;

    // NOTE(review): 10 is presumably the worst-case size of a
    // variable-length-encoded integer -- confirm against the header.
    case DT_INT8:
    case DT_INT16:
    case DT_INT32:
    case DT_INT64:
    case DT_QINT8:
    case DT_QINT16:
    case DT_QINT32:
      return 10;

    case DT_COMPLEX128:
      return 16;

    case DT_INVALID:
    case DT_STRING:
    case DT_BFLOAT16:
    default:
      return 0;
  }
}
186
187template <>
188Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements,
189 SavedSlice* ss) {
190 size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
191 (num_elements * MaxBytesPerElement(DT_INT32));
192 for (int64_t i = 0; i < num_elements; ++i) {
193 size_bound += data[i].size();
194 }
195 if (size_bound > kMaxMessageBytes) {
196 return errors::InvalidArgument(
197 "Tensor slice is too large to serialize (conservative estimate: ",
198 size_bound, " bytes)");
199 }
200 Fill(data, num_elements, ss->mutable_data());
201 DCHECK_GE(ss->ByteSize(), 0);
202 DCHECK_LE(ss->ByteSize(), size_bound);
203 return OkStatus();
204}
205
206} // namespace checkpoint
207
208} // namespace tensorflow
209