1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/data/standalone.h" |
17 | |
18 | #include <algorithm> |
19 | #include <functional> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <utility> |
23 | |
24 | #include "absl/memory/memory.h" |
25 | #include "tensorflow/core/common_runtime/device_factory.h" |
26 | #include "tensorflow/core/common_runtime/device_mgr.h" |
27 | #include "tensorflow/core/common_runtime/function.h" |
28 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
29 | #include "tensorflow/core/common_runtime/graph_runner.h" |
30 | #include "tensorflow/core/common_runtime/process_util.h" |
31 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
32 | #include "tensorflow/core/data/root_dataset.h" |
33 | #include "tensorflow/core/framework/dataset.h" |
34 | #include "tensorflow/core/framework/op_kernel.h" |
35 | #include "tensorflow/core/graph/graph.h" |
36 | #include "tensorflow/core/lib/core/errors.h" |
37 | #include "tensorflow/core/platform/refcount.h" |
38 | #include "tensorflow/core/public/version.h" |
39 | #include "tensorflow/core/util/ptr_util.h" |
40 | |
41 | namespace tensorflow { |
42 | namespace data { |
43 | namespace standalone { |
44 | |
45 | namespace { |
46 | |
47 | OpKernelContext::Params CreateParams( |
48 | ProcessFunctionLibraryRuntime* pflr, DeviceMgr* device_mgr, |
49 | std::function<void(std::function<void()>)>* runner) { |
50 | OpKernelContext::Params params; |
51 | params.function_library = pflr->GetFLR("/device:CPU:0" ); |
52 | params.device = device_mgr->ListDevices()[0]; |
53 | params.runner = runner; |
54 | return params; |
55 | } |
56 | |
57 | } // namespace |
58 | |
59 | Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) { |
60 | return iterator_->GetNext(ctx_.get(), outputs, end_of_input); |
61 | } |
62 | |
63 | Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx) |
64 | : iterator_(iterator), ctx_(ctx) {} |
65 | |
// Builds a standalone `Dataset` from `graph_def`, which must contain a
// `_Retval` op fed by a DT_VARIANT tensor holding the dataset. Creates a
// private single-CPU runtime (device manager, function library, PFLR) that
// the resulting `Dataset` owns.
Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
                          std::unique_ptr<Dataset>* result) {
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));

  // Instantiate enough of the TF runtime to run `graph` on a single CPU device.
  auto device_mgr = std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
      "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
  Device* device = device_mgr->ListDevices()[0];
  // Create a copy of the `FunctionLibraryDefinition` to extend lifetime beyond
  // the lifetime of `graph`.
  auto flib_def = std::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());
  auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
      /*thread_pool=*/nullptr, /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }});

  // Find the tensor feeding the `_Retval` op; that tensor carries the dataset
  // variant. If multiple `_Retval` nodes exist, the last one in `graph_def`
  // wins (no break below).
  string fetch_node = "";
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      fetch_node = node.input(0);
    }
  }
  if (fetch_node.empty()) {
    return errors::NotFound("Failed to find a _Retval op in the given dataset");
  }

  // Run graph up to `output_node` and extract the `DatasetBase` stored in the
  // DT_VARIANT output tensor. `dataset` is borrowed from the variant held in
  // `outputs[0]` at this point.
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(device);
  TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"), {},
                                      {fetch_node}, &outputs));
  data::DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

  // Finalize the dataset, which requires a kernel context backed by a thread
  // pool (`runner` schedules onto `pool`).
  data::DatasetBase* finalized_dataset;
  std::unique_ptr<thread::ThreadPool> pool(
      NewThreadPoolFromSessionOptions(params.session_options));
  std::function<void(std::function<void()>)> runner =
      [&pool](std::function<void()> c) { pool->Schedule(std::move(c)); };
  OpKernelContext::Params op_params =
      CreateParams(pflr.get(), device_mgr.get(), &runner);
  OpKernelContext ctx(&op_params, /*num_outputs=*/0);
  TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
  // NOTE(review): the `ScopedUnref` implies `FinalizeDataset` hands back a
  // referenced dataset; the `Dataset` constructor below takes its own Ref()
  // before this reference is dropped at scope exit.
  core::ScopedUnref unref(finalized_dataset);
  // The new `Dataset` assumes ownership of the runtime objects built above
  // (hence the `release()` calls).
  *result = WrapUnique(new Dataset(
      finalized_dataset, dataset, device_mgr.release(), pflr.release(),
      flib_def.release(), pool.release(), std::move(runner)));
  return OkStatus();
}  // static
124 | |
// Creates an iterator over the finalized dataset, routing the given
// `split_providers` into the iterator context. On success `*result` owns
// both the new iterator and its context.
Status Dataset::MakeIterator(
    std::vector<std::unique_ptr<SplitProvider>> split_providers,
    std::unique_ptr<Iterator>* result) {
  // Create an `IteratorContext`, which bundles together the necessary runtime
  // support to create and get elements from an iterator.
  std::unique_ptr<IteratorContext> ctx;
  // NOTE(mrry): In the current API, an `IteratorContext` is always initially
  // created from an `OpKernelContext*`, so we need to create `OpKernelContext`
  // with a valid subset of parameters.
  OpKernelContext::Params op_params =
      CreateParams(pflr_.get(), device_mgr_.get(), &runner_);
  OpKernelContext op_ctx(&op_params, /*num_outputs=*/0);
  // NOTE(review): `op_ctx`/`op_params` are stack locals that die at return;
  // this assumes `IteratorContext` copies what it needs from them rather than
  // retaining pointers -- confirm against IteratorContext's constructor.
  IteratorContext::Params params(&op_ctx);
  // Point the context at this Dataset's long-lived services so cancellation,
  // function instantiation, and resource lookup work across GetNext() calls.
  params.cancellation_manager = &cancellation_manager_;
  params.function_handle_cache = function_handle_cache_.get();
  params.resource_mgr = &resource_mgr_;
  // Transfer ownership of the split providers into the context params.
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(params.split_providers));
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  ctx = std::make_unique<IteratorContext>(std::move(params));

  // Create the iterator from the dataset.
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(finalized_dataset_->MakeIterator(
      ctx.get(), /*parent=*/nullptr, "Iterator", &iterator));
  // The returned `Iterator` wrapper takes ownership of both raw pointers.
  *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));

  return OkStatus();
}
155 | |
156 | Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) { |
157 | return MakeIterator(/*split_providers=*/{}, result); |
158 | } |
159 | |
// Produces split providers for this dataset by delegating to the finalized
// dataset; the providers can later be passed to `MakeIterator`.
Status Dataset::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* result) {
  return finalized_dataset_->MakeSplitProviders(result);
}
164 | |
// Returns the finalized dataset; the pointer remains owned by this object.
const DatasetBase* Dataset::Get() const { return finalized_dataset_; }
166 | |
// Takes ownership of every raw pointer argument, and Refs both datasets so
// they outlive this object (balanced by the Unrefs in ~Dataset).
Dataset::Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
                 DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
                 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
                 std::function<void(std::function<void()>)> runner)
    : finalized_dataset_(finalized_dataset),
      original_dataset_(original_dataset),
      device_mgr_(device_mgr),
      flib_def_(flib_def),
      pflr_(pflr),
      interop_threadpool_(pool),
      runner_(std::move(runner)),
      unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
  finalized_dataset_->Ref();
  original_dataset_->Ref();
  // Built in the body (after `pflr_` is initialized) from the CPU function
  // library runtime that `pflr_` owns.
  function_handle_cache_ =
      std::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}
184 | |
Dataset::~Dataset() {
  // Drop the references taken in the constructor; the datasets are destroyed
  // here if no other owners remain.
  finalized_dataset_->Unref();
  original_dataset_->Unref();
}
189 | |
190 | } // namespace standalone |
191 | } // namespace data |
192 | } // namespace tensorflow |
193 | |