1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
17#define TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
18
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/lib/core/status.h"
23
24namespace tensorflow {
25namespace data {
26
27// Reads dataset elements from the checkpoint reader using the given key prefix.
28Status ReadElementsFromCheckpoint(IteratorContext* ctx,
29 IteratorStateReader* reader,
30 StringPiece key_prefix,
31 std::vector<std::vector<Tensor>>* elements);
32
33// Writes dataset elements to the checkpoint writer using the given key prefix.
34// The elements can be read back by passing the same key prefix to
35// ReadElementsFromCheckpoint. Only one list of elements can be written under
36// the same key_prefix.
37Status WriteElementsToCheckpoint(
38 IteratorStateWriter* writer, StringPiece key_prefix,
39 const std::vector<std::vector<Tensor>>& elements);
40
41// Helper class for reading data from a vector of VariantTensorData objects.
42class VariantTensorDataReader : public IteratorStateReader {
43 public:
44 explicit VariantTensorDataReader(
45 const std::vector<const VariantTensorData*>& data);
46
47 bool Contains(StringPiece key) const override;
48 bool Contains(StringPiece name, StringPiece key) const override;
49
50 Status ReadScalar(StringPiece key, int64_t* val) const override;
51 Status ReadScalar(StringPiece name, StringPiece key,
52 int64_t* val) const override;
53 Status ReadScalar(StringPiece key, tstring* val) const override;
54 Status ReadScalar(StringPiece name, StringPiece key,
55 tstring* val) const override;
56 Status ReadTensor(StringPiece key, Tensor* val) const override;
57 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
58 Tensor* val) const override;
59 Status ReadTensor(StringPiece name, StringPiece key,
60 Tensor* val) const override;
61 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
62 StringPiece key, Tensor* val) const override;
63
64 private:
65 template <typename T>
66 Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const;
67 Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name,
68 StringPiece key, Tensor* val) const;
69 Status ReadDatasetInternal(FunctionLibraryRuntime* flr, StringPiece name,
70 StringPiece key, Tensor* val) const;
71
72 std::map<string, std::map<string, size_t>> map_;
73 std::map<string, const VariantTensorData*> data_; // Not owned.
74};
75
76// Helper class used to build a list of VariantTensorData objects, one for each
77// iterator which is determined from the key supplied from the Write* calls.
78// Sample usage:
79// VariantTensorDataWriter writer;
80// writer.WriteScalar(full_name("buffer_size"), buffer_.size());
81// writer.WriteScalar(full_name("num_threads"), threadpool_.size());
82// ....
83// std::vector<std::unique_ptr<VariantTensorData>> variants;
84// writer.ReleaseData(&variants);
85// Now the VariantTensorData objects can be used to serialize.
86class VariantTensorDataWriter : public IteratorStateWriter {
87 public:
88 Status WriteScalar(StringPiece key, const int64_t val) override;
89 Status WriteScalar(StringPiece name, StringPiece key,
90 const int64_t val) override;
91
92 Status WriteScalar(StringPiece key, const tstring& val) override;
93 Status WriteScalar(StringPiece name, StringPiece key,
94 const tstring& val) override;
95
96 Status WriteTensor(StringPiece key, const Tensor& val) override;
97 Status WriteTensor(StringPiece name, StringPiece key,
98 const Tensor& val) override;
99
100 // Releases the built VariantTensorData's to `variants`. Clears out all
101 // class state.
102 void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants);
103
104 // Obtains a read-only version of the VariantTensorData's built.
105 void GetData(std::vector<const VariantTensorData*>* variants);
106
107 private:
108 void MaybeFlush();
109 void Reset();
110
111 template <typename T>
112 Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val);
113 Status WriteTensorInternal(StringPiece name, StringPiece key,
114 const Tensor& val);
115 Status WriteDatasetInternal(StringPiece name, StringPiece key,
116 const DatasetBase* dataset);
117
118 bool is_flushed_ = false;
119 std::map<string, std::unique_ptr<VariantTensorData>> data_;
120 std::map<string, std::vector<string>> keys_;
121};
122
123// Returns a GraphDef representation of the given dataset.
124Status AsGraphDef(const DatasetBase* dataset,
125 SerializationContext&& serialization_ctx,
126 GraphDef* graph_def);
127
128// Returns a GraphDef representation of the given dataset suitable for
129// optimization rewrites. It sets serialization parameters to export a minimum
130// graph with additional information for optimization (i.e. ignoring external
131// state, not serializing data tensors, not failing if there are datasets which
132// do not have AsGraphDef implemented). Sets the `dataset_node` parameter to the
133// dataset's node name in the resulting GraphDef.
134Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input,
135 std::vector<std::pair<string, Tensor>>* input_list,
136 GraphDef* result, string* dataset_node);
137
138} // namespace data
139} // namespace tensorflow
140
#endif  // TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
142