1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_
17#define TENSORFLOW_CORE_DATA_STANDALONE_H_
18
19#include <functional>
20#include <memory>
21
22#include "tensorflow/core/common_runtime/device_mgr.h"
23#include "tensorflow/core/data/unbounded_thread_pool.h"
24#include "tensorflow/core/framework/dataset.h"
25#include "tensorflow/core/framework/function_handle_cache.h"
26#include "tensorflow/core/lib/core/threadpool.h"
27#include "tensorflow/core/public/session_options.h"
28
29namespace tensorflow {
30namespace data {
31namespace standalone {
32
33// The purpose of the API in this file is to facilitate standalone execution of
34// a tf.data input pipeline graph.
35//
// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
// encapsulate the TensorFlow runtime.
38//
39// The `Dataset` abstraction represents an input pipeline as a collection
40// of data sources and a logical plan of transformations that operate over the
41// data.
42//
43// The `Iterator` abstraction represents an execution of an input pipeline that
44// can be used to enumerate its elements.
45//
46// Example usage:
47//
48// // Create a `Dataset` by running the `graph_def` graph.
// tensorflow::data::standalone::Dataset::Params params;
50// std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
51// Status s = tensorflow::data::standalone::Dataset::FromGraph(
52// params, graph_def, &dataset);
53// if (!s.ok()) { /* error handling */ }
54//
55// std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
56// s = dataset->MakeIterator(&iterator);
57// if (!s.ok()) { /* error handling */ }
58//
59// bool end_of_input = false;
60// while (!end_of_input) {
61// std::vector<tensorflow::Tensor> outputs;
62// s = iterator->GetNext(&outputs, &end_of_input);
63// if (!s.ok()) { /* error handling */ }
64// if (!end_of_input) { /* output handling */ }
65// }
66
67class Dataset;
68
69// Represents an execution of an input pipeline that can be used to enumerate
70// its elements.
71class Iterator {
72 public:
73 // Returns the next element of the input pipeline (if there is one) and an
74 // indication of whether the end of the input pipeline has been reached.
75 Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input);
76
77 private:
78 friend class Dataset;
79
80 Iterator(IteratorBase* iterator, IteratorContext* ctx);
81
82 std::unique_ptr<IteratorBase> iterator_;
83 std::unique_ptr<IteratorContext> ctx_;
84};
85
86// Represents an input pipeline as a collection of data sources and a logical
87// plan of transformations that operate over the data.
88class Dataset {
89 public:
90 // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
91 struct Params {
92 SessionOptions session_options;
93 };
94
95 // Creates a new `Dataset` instance by running the given dataset graph.
96 static Status FromGraph(Params params, const GraphDef& graph_def,
97 std::unique_ptr<Dataset>* result);
98
99 ~Dataset();
100
101 // Creates an iterator for this dataset.
102 Status MakeIterator(std::unique_ptr<Iterator>* result);
103 // Creates an iterator, optionally with a split provider.
104 Status MakeIterator(
105 std::vector<std::unique_ptr<SplitProvider>> split_providers,
106 std::unique_ptr<Iterator>* result);
107
108 // Creates split providers for this dataset.
109 Status MakeSplitProviders(
110 std::vector<std::unique_ptr<SplitProvider>>* result);
111 // Returns a pointer to the underlying dataset.
112 const DatasetBase* Get() const;
113
114 private:
115 Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
116 DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
117 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
118 std::function<void(std::function<void()>)> runner);
119
120 DatasetBase* finalized_dataset_; // owned
121 DatasetBase* original_dataset_; // owned
122 std::unique_ptr<DeviceMgr> device_mgr_;
123 std::unique_ptr<FunctionLibraryDefinition> flib_def_;
124 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
125 std::unique_ptr<thread::ThreadPool> interop_threadpool_;
126 std::unique_ptr<FunctionHandleCache> function_handle_cache_;
127 std::function<void(std::function<void()>)> runner_;
128 ResourceMgr resource_mgr_;
129 CancellationManager cancellation_manager_;
130 UnboundedThreadPool unbounded_thread_pool_;
131};
132
133} // namespace standalone
134} // namespace data
135} // namespace tensorflow
136
137#endif // TENSORFLOW_CORE_DATA_STANDALONE_H_
138