/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/core/data/serialization_utils.h"
17
18#include <string>
19#include <utility>
20
21#include "tensorflow/core/common_runtime/graph_constructor.h"
22#include "tensorflow/core/common_runtime/graph_runner.h"
23#include "tensorflow/core/data/dataset_utils.h"
24#include "tensorflow/core/framework/dataset.h"
25#include "tensorflow/core/framework/function.h"
26#include "tensorflow/core/graph/graph_def_builder.h"
27
28namespace tensorflow {
29namespace data {
30namespace {
31
// Metadata/key constants used when (de)serializing iterator checkpoints.
constexpr char kDelimiter[] = "@@";
constexpr char kComponent[] = "component";
constexpr char kNumComponents[] = "num_components";
constexpr char kNumElements[] = "num_elements";
// Key suffixes marking a serialized value as a dataset graph and naming the
// graph's output node, respectively.
constexpr char kIsDataset[] = ".is_dataset";
constexpr char kOutputNode[] = ".output_node";
38
39// We assume that all keys are of the form <iterator_prefix>:<name>. We extract
40// the iterator name by getting rid of everything post the final colon.
41Status GetIteratorName(StringPiece key, string* name) {
42 if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
43 return errors::InvalidArgument("Save key: ", key,
44 " not generated using full_name.");
45 }
46 std::vector<string> split_keys = str_util::Split(key, data::kPipe);
47 if (split_keys.size() != 2) {
48 return errors::InvalidArgument("Save key: ", key,
49 " not generated using full_name.");
50 }
51 string real_key = split_keys[1];
52 const int pos = real_key.rfind(kColon);
53 *name = real_key.substr(0, pos);
54 return OkStatus();
55}
56
// Executes `graph_def` (with `input_list` fed as inputs) and stores the
// tensor produced by `output_node` in `*result`.
Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def,
                    const std::vector<std::pair<string, Tensor>>& input_list,
                    const string& output_node, Tensor* result) {
  FunctionLibraryRuntime* cloned_flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  // Clone the runtime so the graph's function library can be merged in
  // without mutating the caller's runtime.
  TF_RETURN_IF_ERROR(flr->Clone(&lib_def, &pflr, &cloned_flr, true));
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(cloned_flr->device());
  // Exactly one fetch is requested, so on success `outputs` holds exactly
  // one tensor.
  TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list,
                                      {output_node}, &outputs));
  *result = outputs[0];
  return OkStatus();
}
74
75// FindStatefulOps searches `graph_def` for all of its stateful ops storing
76// their names in `stateful_op_names`.
77Status FindStatefulOps(const GraphDef& graph_def,
78 std::vector<string>* stateful_op_names) {
79 FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library());
80
81 // Iterate over all nodes in the graph.
82 for (const auto& node : graph_def.node()) {
83 // Each Dataset graph has a _Retval op in the end which is marked stateful
84 if (node.op() == FunctionLibraryDefinition::kRetOp) continue;
85 if (!IsNodeStateful(lib_def, node).ok()) {
86 stateful_op_names->push_back(node.op());
87 }
88 }
89
90 // Iterate over all functions.
91 for (const auto& fdef : graph_def.library().function()) {
92 if (!fdef.signature().is_stateful()) continue;
93 for (const auto& node : fdef.node_def()) {
94 if (!IsNodeStateful(lib_def, node).ok()) {
95 stateful_op_names->push_back(
96 absl::StrCat(node.op(), " in function: ", fdef.signature().name()));
97 }
98 }
99 }
100 return OkStatus();
101}
102
103} // namespace
104
105Status ReadElementsFromCheckpoint(IteratorContext* ctx,
106 IteratorStateReader* reader,
107 StringPiece key_prefix,
108 std::vector<std::vector<Tensor>>* elements) {
109 int64_t num_elements;
110 TF_RETURN_IF_ERROR(
111 reader->ReadScalar(key_prefix, kNumElements, &num_elements));
112 DCHECK(elements->empty());
113 elements->reserve(num_elements);
114 for (int i = 0; i < num_elements; ++i) {
115 std::string element_prefix = absl::StrCat(key_prefix, "::", i);
116 int64_t num_components;
117 TF_RETURN_IF_ERROR(
118 reader->ReadScalar(element_prefix, kNumComponents, &num_components));
119 elements->emplace_back();
120 std::vector<Tensor>& element = elements->at(i);
121 element.reserve(num_components);
122 for (int j = 0; j < num_components; ++j) {
123 element.emplace_back();
124 TF_RETURN_IF_ERROR(reader->ReadTensor(
125 ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
126 &element.back()));
127 }
128 }
129 return OkStatus();
130}
131
132Status WriteElementsToCheckpoint(
133 IteratorStateWriter* writer, StringPiece key_prefix,
134 const std::vector<std::vector<Tensor>>& elements) {
135 TF_RETURN_IF_ERROR(
136 writer->WriteScalar(key_prefix, kNumElements, elements.size()));
137 for (int i = 0; i < elements.size(); ++i) {
138 const std::vector<Tensor>& element = elements[i];
139 std::string element_prefix = absl::StrCat(key_prefix, "::", i);
140 TF_RETURN_IF_ERROR(
141 writer->WriteScalar(element_prefix, kNumComponents, element.size()));
142 for (int j = 0; j < elements[i].size(); ++j) {
143 TF_RETURN_IF_ERROR(writer->WriteTensor(
144 element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
145 }
146 }
147 return OkStatus();
148}
149
150VariantTensorDataReader::VariantTensorDataReader(
151 const std::vector<const tensorflow::VariantTensorData*>& data) {
152 for (const auto& d : data) {
153 string metadata;
154 d->get_metadata(&metadata);
155 auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
156 const string name = keys[0];
157 data_[name] = d;
158 map_[name] = std::map<string, size_t>();
159 for (size_t i = 1; i < keys.size(); ++i) {
160 map_[name][keys[i]] = i - 1;
161 }
162 }
163}
164
165Status VariantTensorDataReader::ReadScalar(StringPiece key,
166 int64_t* val) const {
167 string name;
168 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
169 return ReadScalar(name, key, val);
170}
171
172Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
173 int64_t* val) const {
174 return ReadScalarInternal(name, key, val);
175}
176
177Status VariantTensorDataReader::ReadScalar(StringPiece key,
178 tstring* val) const {
179 string name;
180 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
181 return ReadScalar(name, key, val);
182}
183
184Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
185 tstring* val) const {
186 return ReadScalarInternal(name, key, val);
187}
188
189Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
190 string name;
191 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
192 return ReadTensor(name, key, val);
193}
194
195Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
196 StringPiece key, Tensor* val) const {
197 string name;
198 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
199 return ReadTensorInternal(flr, name, key, val);
200}
201
202Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
203 Tensor* val) const {
204 return ReadTensor(/*flr=*/nullptr, name, key, val);
205}
206
207Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
208 StringPiece name, StringPiece key,
209 Tensor* val) const {
210 return ReadTensorInternal(flr, name, key, val);
211}
212
213bool VariantTensorDataReader::Contains(StringPiece key) const {
214 string name;
215 if (!GetIteratorName(key, &name).ok()) {
216 return false;
217 }
218 return Contains(name, key);
219}
220
221bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
222 string name(n);
223 auto it = map_.find(name);
224 if (it == map_.end()) {
225 return false;
226 }
227 const auto& bucket = it->second;
228 return bucket.find(string(key)) != bucket.end();
229}
230
231template <typename T>
232Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
233 StringPiece key,
234 T* val) const {
235 string name(n);
236 auto it = map_.find(name);
237 if (it == map_.end()) {
238 return errors::NotFound(name);
239 }
240 const auto& bucket = it->second;
241 auto key_it = bucket.find(string(key));
242 if (key_it == bucket.end()) {
243 return errors::NotFound(key);
244 }
245 *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
246 return OkStatus();
247}
248
249Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr,
250 StringPiece n,
251 StringPiece key,
252 Tensor* val) const {
253 if (Contains(n, strings::StrCat(key, kIsDataset))) {
254 return ReadDatasetInternal(flr, n, key, val);
255 }
256 string name(n);
257 auto it = map_.find(name);
258 if (it == map_.end()) {
259 return errors::NotFound(name);
260 }
261 const auto& bucket = it->second;
262 auto key_it = bucket.find(string(key));
263 if (key_it == bucket.end()) {
264 return errors::NotFound(key);
265 }
266 *val = data_.at(name)->tensors(key_it->second);
267 return OkStatus();
268}
269
270Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr,
271 StringPiece n,
272 StringPiece key,
273 Tensor* val) const {
274 if (flr == nullptr) {
275 return errors::Internal(
276 "Function library runtime is needed to restore a dataset.");
277 }
278 tstring output_node, serialized_graph_def;
279 TF_RETURN_IF_ERROR(
280 ReadScalar(n, strings::StrCat(key, kOutputNode), &output_node));
281 TF_RETURN_IF_ERROR(
282 ReadScalar(n, strings::StrCat(key), &serialized_graph_def));
283 GraphDef graph_def;
284 graph_def.ParseFromString(serialized_graph_def);
285 TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val));
286 return OkStatus();
287}
288
289Status VariantTensorDataWriter::WriteScalar(StringPiece key,
290 const int64_t val) {
291 string name;
292 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
293 return WriteScalar(name, key, val);
294}
295
296Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
297 const int64_t val) {
298 return WriteScalarInternal(name, key, val);
299}
300
301Status VariantTensorDataWriter::WriteScalar(StringPiece key,
302 const tstring& val) {
303 string name;
304 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
305 return WriteScalar(name, key, val);
306}
307
308Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
309 const tstring& val) {
310 return WriteScalarInternal(name, key, val);
311}
312
313Status VariantTensorDataWriter::WriteTensor(StringPiece key,
314 const Tensor& val) {
315 string name;
316 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
317 return WriteTensor(name, key, val);
318}
319
320Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
321 const Tensor& val) {
322 return WriteTensorInternal(name, key, val);
323}
324
325void VariantTensorDataWriter::MaybeFlush() {
326 if (is_flushed_) return;
327 for (auto& keys : keys_) {
328 const string name = keys.first;
329 string metadata = name;
330 for (size_t i = 0; i < keys_[name].size(); ++i) {
331 strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
332 }
333 data_[name]->set_metadata(metadata);
334 }
335 is_flushed_ = true;
336}
337
338void VariantTensorDataWriter::Reset() {
339 is_flushed_ = false;
340 data_.clear();
341 keys_.clear();
342}
343
344void VariantTensorDataWriter::ReleaseData(
345 std::vector<std::unique_ptr<VariantTensorData>>* variants) {
346 MaybeFlush();
347 for (auto& it : data_) {
348 variants->push_back(std::move(it.second));
349 }
350 Reset();
351}
352
353void VariantTensorDataWriter::GetData(
354 std::vector<const VariantTensorData*>* variants) {
355 MaybeFlush();
356 for (auto& it : data_) {
357 variants->push_back(it.second.get());
358 }
359}
360
361template <typename T>
362Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
363 StringPiece key,
364 const T& val) {
365 if (is_flushed_) {
366 return errors::FailedPrecondition(
367 "Cannot call WriteScalar after GetData or ReleaseData is called");
368 }
369 Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
370 val_t.scalar<T>()() = val;
371 return WriteTensorInternal(name, key, val_t);
372}
373
374Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
375 StringPiece key,
376 const Tensor& val) {
377 DatasetBase* dataset;
378 if (GetDatasetFromVariantTensor(val, &dataset).ok()) {
379 return WriteDatasetInternal(n, key, dataset);
380 }
381 if (is_flushed_) {
382 return errors::FailedPrecondition(
383 "Cannot call WriteTensor after GetData or ReleaseData is called");
384 }
385 DCHECK_EQ(key.find(kDelimiter), string::npos);
386 string name(n);
387 if (keys_.count(name) == 0) {
388 keys_[name] = std::vector<string>();
389 }
390 keys_[name].push_back(string(key));
391 if (data_.count(name) == 0) {
392 data_[name] = std::make_unique<VariantTensorData>();
393 data_[name]->set_type_name("tensorflow::Iterator");
394 }
395 *(data_[name]->add_tensors()) = val;
396 return OkStatus();
397}
398
399Status VariantTensorDataWriter::WriteDatasetInternal(
400 StringPiece n, StringPiece key, const DatasetBase* dataset) {
401 GraphDef graph_def;
402 SerializationContext ctx((SerializationContext::Params()));
403 TF_RETURN_IF_ERROR(AsGraphDef(dataset, std::move(ctx), &graph_def));
404 string output_node;
405 for (const auto& node : graph_def.node()) {
406 if (node.op() == "_Retval") {
407 output_node = node.input(0);
408 break;
409 }
410 }
411 string result;
412 graph_def.SerializeToString(&result);
413 TF_RETURN_IF_ERROR(WriteScalar(n, strings::StrCat(key, kIsDataset), ""));
414 TF_RETURN_IF_ERROR(
415 WriteScalar(n, strings::StrCat(key, kOutputNode), output_node));
416 TF_RETURN_IF_ERROR(WriteScalar(n, key, result));
417 return OkStatus();
418}
419
// Serializes `input` into `result` for a graph rewrite, recording which node
// in the resulting graph corresponds to the dataset in `*dataset_node`.
Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input,
                            std::vector<std::pair<string, Tensor>>* input_list,
                            GraphDef* result, string* dataset_node) {
  SerializationContext::Params params(ctx);
  params.input_list = input_list;
  // External state is deliberately ignored: a rewrite needs the graph
  // structure, not a faithful checkpoint of iterator state.
  params.external_state_policy =
      SerializationContext::ExternalStatePolicy::kIgnore;
  params.is_graph_rewrite = true;
  SerializationContext serialization_ctx(params);
  TF_RETURN_IF_ERROR(AsGraphDef(input, std::move(serialization_ctx), result));

  // Symbolic `_Retval` node indicates which node corresponds to the dataset.
  // NOTE(review): there is no `break`, so if several _Retval nodes existed
  // the last would win — presumably AsGraphDef emits exactly one; confirm.
  for (const auto& node : result->node()) {
    if (node.op() == "_Retval") {
      *dataset_node = node.input(0);
    }
  }
  return OkStatus();
}
439
440Status AsGraphDef(const DatasetBase* dataset,
441 SerializationContext&& serialization_ctx,
442 GraphDef* graph_def) {
443 if (serialization_ctx.external_state_policy() ==
444 SerializationContext::ExternalStatePolicy::kFail) {
445 TF_RETURN_IF_ERROR(dataset->CheckExternalState());
446 }
447 if (serialization_ctx.external_state_policy() ==
448 SerializationContext::ExternalStatePolicy::kWarn) {
449 std::vector<string> stateful_op_names;
450 TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names));
451 if (!stateful_op_names.empty()) {
452 LOG(WARNING) << "We found the following stateful ops in the dataset "
453 "construction graph whose state would not be "
454 "serialized and might "
455 "cause subtle bugs: "
456 << absl::StrJoin(stateful_op_names, ", ");
457 }
458 }
459 GraphDefBuilder b;
460 DatasetBase::DatasetGraphDefBuilder db(&b);
461 Node* output_node = nullptr;
462 TF_RETURN_IF_ERROR(
463 db.AddInputDataset(&serialization_ctx, dataset, &output_node));
464 // Insert a purely symbolic _Retval node to indicate to consumers which node
465 // represents `dataset`.
466 ops::UnaryOp("_Retval", output_node,
467 b.opts()
468 .WithName("dataset")
469 .WithAttr("T", DT_VARIANT)
470 .WithAttr("index", 0));
471 TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
472 return OkStatus();
473}
474
475} // namespace data
476} // namespace tensorflow
477