/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/standalone.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {
namespace standalone {

namespace {

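// Assembles the minimal subset of `OpKernelContext` parameters needed here:
// the CPU function library runtime, the (single) CPU device, and a runner
// used to schedule closures.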
OpKernelContext::Params CreateParams(
    ProcessFunctionLibraryRuntime* pflr, DeviceMgr* device_mgr,
    std::function<void(std::function<void()>)>* runner) {
  OpKernelContext::Params params;
  params.function_library = pflr->GetFLR("/device:CPU:0");
  params.device = device_mgr->ListDevices()[0];
  params.runner = runner;
  return params;
}

}  // namespace

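// Returns the next element by delegating to the wrapped `IteratorBase`, using
// the `IteratorContext` captured when this `Iterator` was created.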
Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
  return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
}

Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
    : iterator_(iterator), ctx_(ctx) {}

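// Builds a standalone `Dataset` from a graph that defines a tf.data pipeline.
// A minimal usage sketch of the API implemented in this file (error handling
// via TF_CHECK_OK for brevity; `graph_def` is assumed to already contain a
// serialized dataset graph):
//
//   std::unique_ptr<Dataset> dataset;
//   TF_CHECK_OK(Dataset::FromGraph({}, graph_def, &dataset));
//   std::unique_ptr<Iterator> iterator;
//   TF_CHECK_OK(dataset->MakeIterator(&iterator));
//   std::vector<Tensor> outputs;
//   bool end_of_input = false;
//   while (!end_of_input) {
//     TF_CHECK_OK(iterator->GetNext(&outputs, &end_of_input));
//     // ... consume `outputs` ...
//   }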
Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
                          std::unique_ptr<Dataset>* result) {
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));

  // Instantiate enough of the TF runtime to run `graph` on a single CPU
  // device.
  auto device_mgr = std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
      "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
  Device* device = device_mgr->ListDevices()[0];
  // Create a copy of the `FunctionLibraryDefinition` to extend lifetime beyond
  // the lifetime of `graph`.
  auto flib_def = std::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());
  auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
      /*thread_pool=*/nullptr, /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }});

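  // The exported dataset graph feeds the dataset tensor into a `_Retval` op;
  // the input of that op names the node whose output we need to fetch.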
  string fetch_node = "";
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      fetch_node = node.input(0);
    }
  }
  if (fetch_node.empty()) {
    return errors::NotFound("Failed to find a _Retval op in the given dataset");
  }

  // Run the graph up to `fetch_node` and extract the `DatasetBase` stored in
  // the DT_VARIANT output tensor.
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(device);
  TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"), {},
                                      {fetch_node}, &outputs));
  data::DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

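  // `FinalizeDataset` needs an `OpKernelContext`, so assemble a minimal one
  // from the runtime pieces created above.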
  data::DatasetBase* finalized_dataset;
  std::unique_ptr<thread::ThreadPool> pool(
      NewThreadPoolFromSessionOptions(params.session_options));
  // Schedule closures through a raw pointer to the pool rather than capturing
  // the local `pool` unique_ptr by reference: the runner is moved into the
  // `Dataset` below and must not refer to this stack frame after `FromGraph`
  // returns. The pool itself is released to the `Dataset`, which keeps it
  // alive for as long as the runner can be invoked.
  thread::ThreadPool* pool_ptr = pool.get();
  std::function<void(std::function<void()>)> runner =
      [pool_ptr](std::function<void()> c) {
        pool_ptr->Schedule(std::move(c));
      };
  OpKernelContext::Params op_params =
      CreateParams(pflr.get(), device_mgr.get(), &runner);
  OpKernelContext ctx(&op_params, /*num_outputs=*/0);
  TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
  core::ScopedUnref unref(finalized_dataset);
  *result = WrapUnique(new Dataset(
      finalized_dataset, dataset, device_mgr.release(), pflr.release(),
      flib_def.release(), pool.release(), std::move(runner)));
  return OkStatus();
}  // static

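// Creates an iterator for this dataset. The split providers, if any, are
// forwarded into the `IteratorContext` so that the resulting iterator draws
// its input splits from them (e.g., when the input is partitioned across
// multiple consumers).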
Status Dataset::MakeIterator(
    std::vector<std::unique_ptr<SplitProvider>> split_providers,
    std::unique_ptr<Iterator>* result) {
  // Create an `IteratorContext`, which bundles together the necessary runtime
  // support to create and get elements from an iterator.
  std::unique_ptr<IteratorContext> ctx;
  // NOTE(mrry): In the current API, an `IteratorContext` is always initially
  // created from an `OpKernelContext*`, so we need to create an
  // `OpKernelContext` with a valid subset of parameters.
  OpKernelContext::Params op_params =
      CreateParams(pflr_.get(), device_mgr_.get(), &runner_);
  OpKernelContext op_ctx(&op_params, /*num_outputs=*/0);
  IteratorContext::Params params(&op_ctx);
  params.cancellation_manager = &cancellation_manager_;
  params.function_handle_cache = function_handle_cache_.get();
  params.resource_mgr = &resource_mgr_;
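  // Transfer ownership of the split providers into the context.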
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(params.split_providers));
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  ctx = std::make_unique<IteratorContext>(std::move(params));

  // Create the iterator from the dataset.
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(finalized_dataset_->MakeIterator(
      ctx.get(), /*parent=*/nullptr, "Iterator", &iterator));
  *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));

  return OkStatus();
}

Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
  return MakeIterator(/*split_providers=*/{}, result);
}

Status Dataset::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* result) {
  return finalized_dataset_->MakeSplitProviders(result);
}

const DatasetBase* Dataset::Get() const { return finalized_dataset_; }

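// The `Dataset` takes ownership of the runtime objects backing it and holds a
// reference on both the original and the finalized dataset; the references
// are dropped in the destructor.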
Dataset::Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
                 DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
                 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
                 std::function<void(std::function<void()>)> runner)
    : finalized_dataset_(finalized_dataset),
      original_dataset_(original_dataset),
      device_mgr_(device_mgr),
      flib_def_(flib_def),
      pflr_(pflr),
      interop_threadpool_(pool),
      runner_(std::move(runner)),
      unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
  finalized_dataset_->Ref();
  original_dataset_->Ref();
  function_handle_cache_ =
      std::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}

Dataset::~Dataset() {
  finalized_dataset_->Unref();
  original_dataset_->Unref();
}

}  // namespace standalone
}  // namespace data
}  // namespace tensorflow