1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/data/serialization_utils.h" |
17 | |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
22 | #include "tensorflow/core/common_runtime/graph_runner.h" |
23 | #include "tensorflow/core/data/dataset_utils.h" |
24 | #include "tensorflow/core/framework/dataset.h" |
25 | #include "tensorflow/core/framework/function.h" |
26 | #include "tensorflow/core/graph/graph_def_builder.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace data { |
30 | namespace { |
31 | |
// Delimiter used to join an iterator bucket's keys into one metadata string.
constexpr char kDelimiter[] = "@@";
// Checkpoint key fragments used by Read/WriteElementsToCheckpoint.
constexpr char kComponent[] = "component";
constexpr char kNumComponents[] = "num_components";
constexpr char kNumElements[] = "num_elements";
// Key suffixes that mark a stored value as a serialized dataset and record
// the graph node that produces it.
constexpr char kIsDataset[] = ".is_dataset";
constexpr char kOutputNode[] = ".output_node";
38 | |
39 | // We assume that all keys are of the form <iterator_prefix>:<name>. We extract |
40 | // the iterator name by getting rid of everything post the final colon. |
41 | Status GetIteratorName(StringPiece key, string* name) { |
42 | if (!str_util::StartsWith(key, data::kFullNameRandomHex)) { |
43 | return errors::InvalidArgument("Save key: " , key, |
44 | " not generated using full_name." ); |
45 | } |
46 | std::vector<string> split_keys = str_util::Split(key, data::kPipe); |
47 | if (split_keys.size() != 2) { |
48 | return errors::InvalidArgument("Save key: " , key, |
49 | " not generated using full_name." ); |
50 | } |
51 | string real_key = split_keys[1]; |
52 | const int pos = real_key.rfind(kColon); |
53 | *name = real_key.substr(0, pos); |
54 | return OkStatus(); |
55 | } |
56 | |
57 | Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def, |
58 | const std::vector<std::pair<string, Tensor>>& input_list, |
59 | const string& output_node, Tensor* result) { |
60 | FunctionLibraryRuntime* cloned_flr = nullptr; |
61 | std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr; |
62 | std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr; |
63 | TF_RETURN_IF_ERROR(flr->Clone(&lib_def, &pflr, &cloned_flr, true)); |
64 | TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library())); |
65 | Graph graph(OpRegistry::Global()); |
66 | TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); |
67 | std::vector<Tensor> outputs; |
68 | GraphRunner graph_runner(cloned_flr->device()); |
69 | TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list, |
70 | {output_node}, &outputs)); |
71 | *result = outputs[0]; |
72 | return OkStatus(); |
73 | } |
74 | |
75 | // FindStatefulOps searches `graph_def` for all of its stateful ops storing |
76 | // their names in `stateful_op_names`. |
77 | Status FindStatefulOps(const GraphDef& graph_def, |
78 | std::vector<string>* stateful_op_names) { |
79 | FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library()); |
80 | |
81 | // Iterate over all nodes in the graph. |
82 | for (const auto& node : graph_def.node()) { |
83 | // Each Dataset graph has a _Retval op in the end which is marked stateful |
84 | if (node.op() == FunctionLibraryDefinition::kRetOp) continue; |
85 | if (!IsNodeStateful(lib_def, node).ok()) { |
86 | stateful_op_names->push_back(node.op()); |
87 | } |
88 | } |
89 | |
90 | // Iterate over all functions. |
91 | for (const auto& fdef : graph_def.library().function()) { |
92 | if (!fdef.signature().is_stateful()) continue; |
93 | for (const auto& node : fdef.node_def()) { |
94 | if (!IsNodeStateful(lib_def, node).ok()) { |
95 | stateful_op_names->push_back( |
96 | absl::StrCat(node.op(), " in function: " , fdef.signature().name())); |
97 | } |
98 | } |
99 | } |
100 | return OkStatus(); |
101 | } |
102 | |
103 | } // namespace |
104 | |
105 | Status ReadElementsFromCheckpoint(IteratorContext* ctx, |
106 | IteratorStateReader* reader, |
107 | StringPiece key_prefix, |
108 | std::vector<std::vector<Tensor>>* elements) { |
109 | int64_t num_elements; |
110 | TF_RETURN_IF_ERROR( |
111 | reader->ReadScalar(key_prefix, kNumElements, &num_elements)); |
112 | DCHECK(elements->empty()); |
113 | elements->reserve(num_elements); |
114 | for (int i = 0; i < num_elements; ++i) { |
115 | std::string element_prefix = absl::StrCat(key_prefix, "::" , i); |
116 | int64_t num_components; |
117 | TF_RETURN_IF_ERROR( |
118 | reader->ReadScalar(element_prefix, kNumComponents, &num_components)); |
119 | elements->emplace_back(); |
120 | std::vector<Tensor>& element = elements->at(i); |
121 | element.reserve(num_components); |
122 | for (int j = 0; j < num_components; ++j) { |
123 | element.emplace_back(); |
124 | TF_RETURN_IF_ERROR(reader->ReadTensor( |
125 | ctx->flr(), element_prefix, absl::StrCat(kComponent, "[" , j, "]" ), |
126 | &element.back())); |
127 | } |
128 | } |
129 | return OkStatus(); |
130 | } |
131 | |
132 | Status WriteElementsToCheckpoint( |
133 | IteratorStateWriter* writer, StringPiece key_prefix, |
134 | const std::vector<std::vector<Tensor>>& elements) { |
135 | TF_RETURN_IF_ERROR( |
136 | writer->WriteScalar(key_prefix, kNumElements, elements.size())); |
137 | for (int i = 0; i < elements.size(); ++i) { |
138 | const std::vector<Tensor>& element = elements[i]; |
139 | std::string element_prefix = absl::StrCat(key_prefix, "::" , i); |
140 | TF_RETURN_IF_ERROR( |
141 | writer->WriteScalar(element_prefix, kNumComponents, element.size())); |
142 | for (int j = 0; j < elements[i].size(); ++j) { |
143 | TF_RETURN_IF_ERROR(writer->WriteTensor( |
144 | element_prefix, absl::StrCat(kComponent, "[" , j, "]" ), element[j])); |
145 | } |
146 | } |
147 | return OkStatus(); |
148 | } |
149 | |
150 | VariantTensorDataReader::VariantTensorDataReader( |
151 | const std::vector<const tensorflow::VariantTensorData*>& data) { |
152 | for (const auto& d : data) { |
153 | string metadata; |
154 | d->get_metadata(&metadata); |
155 | auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty()); |
156 | const string name = keys[0]; |
157 | data_[name] = d; |
158 | map_[name] = std::map<string, size_t>(); |
159 | for (size_t i = 1; i < keys.size(); ++i) { |
160 | map_[name][keys[i]] = i - 1; |
161 | } |
162 | } |
163 | } |
164 | |
165 | Status VariantTensorDataReader::ReadScalar(StringPiece key, |
166 | int64_t* val) const { |
167 | string name; |
168 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
169 | return ReadScalar(name, key, val); |
170 | } |
171 | |
// Reads the int64 scalar recorded for `key` within the iterator bucket `name`.
Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
                                           int64_t* val) const {
  return ReadScalarInternal(name, key, val);
}
176 | |
177 | Status VariantTensorDataReader::ReadScalar(StringPiece key, |
178 | tstring* val) const { |
179 | string name; |
180 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
181 | return ReadScalar(name, key, val); |
182 | } |
183 | |
// Reads the string scalar recorded for `key` within the iterator bucket
// `name`.
Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
                                           tstring* val) const {
  return ReadScalarInternal(name, key, val);
}
188 | |
189 | Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const { |
190 | string name; |
191 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
192 | return ReadTensor(name, key, val); |
193 | } |
194 | |
195 | Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, |
196 | StringPiece key, Tensor* val) const { |
197 | string name; |
198 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
199 | return ReadTensorInternal(flr, name, key, val); |
200 | } |
201 | |
// Reads the tensor recorded for `key` in bucket `name`. No
// FunctionLibraryRuntime is supplied, so values stored as serialized datasets
// cannot be restored through this overload.
Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
                                           Tensor* val) const {
  return ReadTensor(/*flr=*/nullptr, name, key, val);
}
206 | |
// Reads the tensor recorded for `key` in bucket `name`; `flr` is used when
// the stored value is a serialized dataset.
Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
                                           StringPiece name, StringPiece key,
                                           Tensor* val) const {
  return ReadTensorInternal(flr, name, key, val);
}
212 | |
213 | bool VariantTensorDataReader::Contains(StringPiece key) const { |
214 | string name; |
215 | if (!GetIteratorName(key, &name).ok()) { |
216 | return false; |
217 | } |
218 | return Contains(name, key); |
219 | } |
220 | |
221 | bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const { |
222 | string name(n); |
223 | auto it = map_.find(name); |
224 | if (it == map_.end()) { |
225 | return false; |
226 | } |
227 | const auto& bucket = it->second; |
228 | return bucket.find(string(key)) != bucket.end(); |
229 | } |
230 | |
231 | template <typename T> |
232 | Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, |
233 | StringPiece key, |
234 | T* val) const { |
235 | string name(n); |
236 | auto it = map_.find(name); |
237 | if (it == map_.end()) { |
238 | return errors::NotFound(name); |
239 | } |
240 | const auto& bucket = it->second; |
241 | auto key_it = bucket.find(string(key)); |
242 | if (key_it == bucket.end()) { |
243 | return errors::NotFound(key); |
244 | } |
245 | *val = data_.at(name)->tensors(key_it->second).scalar<T>()(); |
246 | return OkStatus(); |
247 | } |
248 | |
249 | Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr, |
250 | StringPiece n, |
251 | StringPiece key, |
252 | Tensor* val) const { |
253 | if (Contains(n, strings::StrCat(key, kIsDataset))) { |
254 | return ReadDatasetInternal(flr, n, key, val); |
255 | } |
256 | string name(n); |
257 | auto it = map_.find(name); |
258 | if (it == map_.end()) { |
259 | return errors::NotFound(name); |
260 | } |
261 | const auto& bucket = it->second; |
262 | auto key_it = bucket.find(string(key)); |
263 | if (key_it == bucket.end()) { |
264 | return errors::NotFound(key); |
265 | } |
266 | *val = data_.at(name)->tensors(key_it->second); |
267 | return OkStatus(); |
268 | } |
269 | |
270 | Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr, |
271 | StringPiece n, |
272 | StringPiece key, |
273 | Tensor* val) const { |
274 | if (flr == nullptr) { |
275 | return errors::Internal( |
276 | "Function library runtime is needed to restore a dataset." ); |
277 | } |
278 | tstring output_node, serialized_graph_def; |
279 | TF_RETURN_IF_ERROR( |
280 | ReadScalar(n, strings::StrCat(key, kOutputNode), &output_node)); |
281 | TF_RETURN_IF_ERROR( |
282 | ReadScalar(n, strings::StrCat(key), &serialized_graph_def)); |
283 | GraphDef graph_def; |
284 | graph_def.ParseFromString(serialized_graph_def); |
285 | TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val)); |
286 | return OkStatus(); |
287 | } |
288 | |
289 | Status VariantTensorDataWriter::WriteScalar(StringPiece key, |
290 | const int64_t val) { |
291 | string name; |
292 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
293 | return WriteScalar(name, key, val); |
294 | } |
295 | |
// Writes an int64 scalar for `key` into the iterator bucket `name`.
Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
                                            const int64_t val) {
  return WriteScalarInternal(name, key, val);
}
300 | |
301 | Status VariantTensorDataWriter::WriteScalar(StringPiece key, |
302 | const tstring& val) { |
303 | string name; |
304 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
305 | return WriteScalar(name, key, val); |
306 | } |
307 | |
// Writes a string scalar for `key` into the iterator bucket `name`.
Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
                                            const tstring& val) {
  return WriteScalarInternal(name, key, val);
}
312 | |
313 | Status VariantTensorDataWriter::WriteTensor(StringPiece key, |
314 | const Tensor& val) { |
315 | string name; |
316 | TF_RETURN_IF_ERROR(GetIteratorName(key, &name)); |
317 | return WriteTensor(name, key, val); |
318 | } |
319 | |
// Writes a tensor for `key` into the iterator bucket `name`.
Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(name, key, val);
}
324 | |
325 | void VariantTensorDataWriter::MaybeFlush() { |
326 | if (is_flushed_) return; |
327 | for (auto& keys : keys_) { |
328 | const string name = keys.first; |
329 | string metadata = name; |
330 | for (size_t i = 0; i < keys_[name].size(); ++i) { |
331 | strings::StrAppend(&metadata, kDelimiter, keys_[name][i]); |
332 | } |
333 | data_[name]->set_metadata(metadata); |
334 | } |
335 | is_flushed_ = true; |
336 | } |
337 | |
338 | void VariantTensorDataWriter::Reset() { |
339 | is_flushed_ = false; |
340 | data_.clear(); |
341 | keys_.clear(); |
342 | } |
343 | |
344 | void VariantTensorDataWriter::ReleaseData( |
345 | std::vector<std::unique_ptr<VariantTensorData>>* variants) { |
346 | MaybeFlush(); |
347 | for (auto& it : data_) { |
348 | variants->push_back(std::move(it.second)); |
349 | } |
350 | Reset(); |
351 | } |
352 | |
353 | void VariantTensorDataWriter::GetData( |
354 | std::vector<const VariantTensorData*>* variants) { |
355 | MaybeFlush(); |
356 | for (auto& it : data_) { |
357 | variants->push_back(it.second.get()); |
358 | } |
359 | } |
360 | |
361 | template <typename T> |
362 | Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, |
363 | StringPiece key, |
364 | const T& val) { |
365 | if (is_flushed_) { |
366 | return errors::FailedPrecondition( |
367 | "Cannot call WriteScalar after GetData or ReleaseData is called" ); |
368 | } |
369 | Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({})); |
370 | val_t.scalar<T>()() = val; |
371 | return WriteTensorInternal(name, key, val_t); |
372 | } |
373 | |
374 | Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, |
375 | StringPiece key, |
376 | const Tensor& val) { |
377 | DatasetBase* dataset; |
378 | if (GetDatasetFromVariantTensor(val, &dataset).ok()) { |
379 | return WriteDatasetInternal(n, key, dataset); |
380 | } |
381 | if (is_flushed_) { |
382 | return errors::FailedPrecondition( |
383 | "Cannot call WriteTensor after GetData or ReleaseData is called" ); |
384 | } |
385 | DCHECK_EQ(key.find(kDelimiter), string::npos); |
386 | string name(n); |
387 | if (keys_.count(name) == 0) { |
388 | keys_[name] = std::vector<string>(); |
389 | } |
390 | keys_[name].push_back(string(key)); |
391 | if (data_.count(name) == 0) { |
392 | data_[name] = std::make_unique<VariantTensorData>(); |
393 | data_[name]->set_type_name("tensorflow::Iterator" ); |
394 | } |
395 | *(data_[name]->add_tensors()) = val; |
396 | return OkStatus(); |
397 | } |
398 | |
399 | Status VariantTensorDataWriter::WriteDatasetInternal( |
400 | StringPiece n, StringPiece key, const DatasetBase* dataset) { |
401 | GraphDef graph_def; |
402 | SerializationContext ctx((SerializationContext::Params())); |
403 | TF_RETURN_IF_ERROR(AsGraphDef(dataset, std::move(ctx), &graph_def)); |
404 | string output_node; |
405 | for (const auto& node : graph_def.node()) { |
406 | if (node.op() == "_Retval" ) { |
407 | output_node = node.input(0); |
408 | break; |
409 | } |
410 | } |
411 | string result; |
412 | graph_def.SerializeToString(&result); |
413 | TF_RETURN_IF_ERROR(WriteScalar(n, strings::StrCat(key, kIsDataset), "" )); |
414 | TF_RETURN_IF_ERROR( |
415 | WriteScalar(n, strings::StrCat(key, kOutputNode), output_node)); |
416 | TF_RETURN_IF_ERROR(WriteScalar(n, key, result)); |
417 | return OkStatus(); |
418 | } |
419 | |
// Serializes `input` into `*result` for graph rewriting: external state is
// ignored and input datasets are captured into `*input_list`. On success,
// `*dataset_node` names the graph node that produces the dataset.
Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input,
                            std::vector<std::pair<string, Tensor>>* input_list,
                            GraphDef* result, string* dataset_node) {
  SerializationContext::Params params(ctx);
  params.input_list = input_list;
  params.external_state_policy =
      SerializationContext::ExternalStatePolicy::kIgnore;
  params.is_graph_rewrite = true;
  SerializationContext serialization_ctx(params);
  TF_RETURN_IF_ERROR(AsGraphDef(input, std::move(serialization_ctx), result));

  // Symbolic `_Retval` node indicates which node corresponds to the dataset.
  for (const auto& node : result->node()) {
    if (node.op() == "_Retval") {
      *dataset_node = node.input(0);
    }
  }
  return OkStatus();
}
439 | |
440 | Status AsGraphDef(const DatasetBase* dataset, |
441 | SerializationContext&& serialization_ctx, |
442 | GraphDef* graph_def) { |
443 | if (serialization_ctx.external_state_policy() == |
444 | SerializationContext::ExternalStatePolicy::kFail) { |
445 | TF_RETURN_IF_ERROR(dataset->CheckExternalState()); |
446 | } |
447 | if (serialization_ctx.external_state_policy() == |
448 | SerializationContext::ExternalStatePolicy::kWarn) { |
449 | std::vector<string> stateful_op_names; |
450 | TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names)); |
451 | if (!stateful_op_names.empty()) { |
452 | LOG(WARNING) << "We found the following stateful ops in the dataset " |
453 | "construction graph whose state would not be " |
454 | "serialized and might " |
455 | "cause subtle bugs: " |
456 | << absl::StrJoin(stateful_op_names, ", " ); |
457 | } |
458 | } |
459 | GraphDefBuilder b; |
460 | DatasetBase::DatasetGraphDefBuilder db(&b); |
461 | Node* output_node = nullptr; |
462 | TF_RETURN_IF_ERROR( |
463 | db.AddInputDataset(&serialization_ctx, dataset, &output_node)); |
464 | // Insert a purely symbolic _Retval node to indicate to consumers which node |
465 | // represents `dataset`. |
466 | ops::UnaryOp("_Retval" , output_node, |
467 | b.opts() |
468 | .WithName("dataset" ) |
469 | .WithAttr("T" , DT_VARIANT) |
470 | .WithAttr("index" , 0)); |
471 | TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def)); |
472 | return OkStatus(); |
473 | } |
474 | |
475 | } // namespace data |
476 | } // namespace tensorflow |
477 | |