1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ |
16 | #define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ |
17 | |
18 | #include <memory> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/cancellation.h" |
22 | #include "tensorflow/core/framework/dataset.h" |
23 | #include "tensorflow/core/framework/function.h" |
24 | #include "tensorflow/core/framework/model.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/lib/core/status.h" |
28 | #include "tensorflow/core/lib/gtl/array_slice.h" |
29 | #include "tensorflow/core/lib/random/random.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | class Device; |
35 | class OpKernelContext; |
36 | class ResourceMgr; |
37 | |
38 | namespace data { |
39 | |
40 | class CapturedFunction; |
41 | class InstantiatedCapturedFunction; |
42 | |
// Creates an iterator for a dataset which is created by applying the given
// function (`inst_captured_func`) to the given input element.
//
// `thread_index` and `prefix` presumably distinguish this nested iterator's
// state from that of sibling iterators (e.g. parallel interleave threads) --
// confirm against the implementation in the corresponding .cc file.
// On success, `*out_iterator` owns the newly created iterator.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator);
50 | |
// Creates an iterator for a dataset which is created by applying the given
// function (`inst_captured_func`) to the given input element. Pass a non-null
// `node` to record processing time for modeling the Iterator's `GetNext()`
// resource usage; pass nullptr to skip such recording.
// On success, `*out_iterator` owns the newly created iterator.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node);
60 | |
// Creates an iterator context appropriate for a nested dataset's iterator. A
// nested dataset is a dataset created within another dataset, e.g. by the
// function passed to `interleave` or `flat_map`. The returned context is
// derived from `ctx`; `ctx` must outlive uses of the returned context.
IteratorContext MakeNestedIteratorContext(IteratorContext* ctx);
65 | |
// Describes a "short-circuit" function: one in which every output corresponds
// to one of the inputs, so the outputs can be produced by forwarding inputs
// instead of actually executing the function.
struct ShortCircuitInfo {
  // For each function output, the index of the input it forwards. Empty if
  // the function cannot be short-circuited.
  std::vector<int> indices;
  // For each output, whether the forwarded input tensor can be moved into the
  // output rather than copied -- presumably true when the input is not also
  // forwarded to another output; confirm against the implementation.
  std::vector<bool> can_move;
};
70 | |
71 | // Metadata shared across all captures of the same function. |
// Metadata shared across all captures of the same function.
class FunctionMetadata {
 public:
  // Options controlling how the function is executed. Both default to true.
  struct Params {
    bool use_inter_op_parallelism = true;
    bool use_default_device = true;
  };

  // Creates a new instance of the `FunctionMetadata` class, fetching function
  // from a context argument. `func_name` names the kernel attribute from
  // which the function is read.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       const string& func_name, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Creates a new instance of the `FunctionMetadata` class, using the provided
  // function. Takes ownership of `func`.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       NameAttrList&& func, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return func_; }

  // Returns a borrowed pointer to the function library that contains the
  // transitive closure of definitions used by the function. May be null if
  // the library has not been populated (see `lib_def_` below).
  const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }

  // Returns short-circuit information (see `ShortCircuitInfo`).
  const ShortCircuitInfo& short_circuit_info() const {
    return short_circuit_info_;
  }

  // Indicates whether a default device should be used for executing function
  // ops.
  bool use_default_device() const { return use_default_device_; }

  // Indicates whether to use inter-op parallelism for execution of the
  // function.
  bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }

  // Indicates whether the function should use a multi-device function backend.
  bool use_multi_device_function() const { return use_multi_device_function_; }

 private:
  // Private: instances are created via the static `Create` factories, which
  // can fail and therefore return a Status.
  FunctionMetadata(NameAttrList&& func, Params params)
      : func_(std::move(func)),
        use_default_device_(params.use_default_device),
        use_inter_op_parallelism_(params.use_inter_op_parallelism) {}

  NameAttrList func_;
  // Populated outside the constructor (presumably by `Create`); remains null
  // until then.
  std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
  ShortCircuitInfo short_circuit_info_;
  bool use_default_device_ = true;
  bool use_inter_op_parallelism_ = true;
  bool use_multi_device_function_ = true;
};
127 | |
128 | // Constructs and stores the parameters for the CapturedFunction Instantiate |
129 | // function. |
// Constructs and stores the parameters for the CapturedFunction Instantiate
// function. All pointers are borrowed from the given context, which must
// outlive this struct.
struct InstantiateCapturedFunctionParams {
  explicit InstantiateCapturedFunctionParams(IteratorContext* ctx) {
    flr = ctx->flr();
    function_handle_cache = ctx->function_handle_cache();
    runner = ctx->runner();
  }

  explicit InstantiateCapturedFunctionParams(OpKernelContext* ctx) {
    flr = ctx->function_library();
    // OpKernelContext does not provide a handle cache.
    function_handle_cache = nullptr;
    runner = ctx->runner();
  }

  // Runtime used to instantiate and run the function. Not owned.
  FunctionLibraryRuntime* flr;
  // Optional cache of instantiated function handles; may be null (see the
  // OpKernelContext constructor above). Not owned.
  FunctionHandleCache* function_handle_cache;
  // Used to schedule closures for execution. Not owned.
  std::function<void(std::function<void()>)>* runner;
};
147 | |
148 | // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured" |
149 | // arguments that it closed over in the user program. |
// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
// arguments that it closed over in the user program.
class CapturedFunction {
 public:
  // Creates a new instance using a list of named attributes, fetching captured
  // inputs from a context argument. `argument_name` names the op input from
  // which the captured inputs are read.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       const string& argument_name,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Creates a new instance using a list of named attributes, using provided
  // captured inputs. Takes ownership of `captured_inputs`.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       std::vector<Tensor>&& captured_inputs,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Adds the definition of this captured function into the given graph,
  // returning its captured inputs and types through the respective output
  // arguments.
  Status AddToGraph(SerializationContext* ctx,
                    DatasetBase::DatasetGraphDefBuilder* b,
                    std::vector<Node*>* other_arguments,
                    DataTypeVector* other_arguments_types) const;

  // Instantiates this function for use in the given context, providing an
  // InstantiatedCapturedFunction that can be used to execute functions.
  Status Instantiate(IteratorContext* ctx,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  // Like the overload above, but takes explicitly-assembled parameters
  // instead of deriving them from an `IteratorContext`.
  Status Instantiate(InstantiateCapturedFunctionParams params,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  // Determines whether the captured function is stateful. Returns OK iff the
  // function has no external state.
  Status CheckExternalState() const;

  // Returns the additional captured inputs that will be passed to the function.
  const std::vector<Tensor>& captured_inputs() const {
    return captured_inputs_;
  }

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return metadata_->func(); }

  // Returns the transitive set of function definition required to instantiate
  // this function.
  const FunctionLibraryDefinition* lib_def() const {
    return metadata_->lib_def();
  }

  // If every function output corresponds to one of its inputs, the method
  // returns the mapping from output indices to input indices. Otherwise, it
  // returns an empty list.
  const ShortCircuitInfo& short_circuit_info() const {
    return metadata_->short_circuit_info();
  }

  // Indicates whether the function should use inter op parallelism.
  bool use_inter_op_parallelism() const {
    return metadata_->use_inter_op_parallelism();
  }

 private:
  // Private: instances are created via the static `Create` factories.
  CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
                   std::vector<Tensor> captured_inputs);

  // Determines whether the function requires the multi-device function
  // backend, writing the result to `*is_multi_device`.
  Status IsMultiDevice(FunctionLibraryRuntime* flr,
                       bool* is_multi_device) const;

  // Metadata shared with other captures of the same function.
  const std::shared_ptr<const FunctionMetadata> metadata_;
  // Tensors the function closed over; passed as trailing arguments at run
  // time.
  const std::vector<Tensor> captured_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
225 | |
226 | // `InstantiatedCapturedFunction` encapsulates all the runtime support needed |
227 | // to execute a tensorflow function. |
228 | // |
229 | // While `CapturedFunction` encapsulates constant attributes of the function, |
230 | // such as its name and captured arguments, `InstantiatedCapturedFunction` |
231 | // encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function |
232 | // handle. |
233 | // |
234 | // The `Iterator` related classes use `InstantiatedCapturedFunction` to execute |
235 | // functions outside of the normal `OpKernel::Compute()` context. |
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
 public:
  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets) const;

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`. Pass non-null `node` to record processing time
  // for modeling Iterator's GetNext() resource usage. When non-null node is
  // provided, the pre-requisite is that the calling thread has previously
  // called `DatasetBaseIterator::RecordStart()`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets,
             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. Pass non-null `node` to record processing time for modeling
  // Iterator's GetNext() resource usage. When non-null node is provided, the
  // pre-requisite is that the calling thread has previously called
  // `DatasetBaseIterator::RecordStart()`.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets,
                             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. This can be useful for calling a captured function in cases where
  // an `IteratorContext*` is not available (such as a destructor).
  //
  // TODO(b/144278100): Avoid running functions without IteratorContext.
  Status RunInstantiated(const std::vector<Tensor>& args,
                         std::vector<Tensor>* rets);

  // Asynchronously runs the captured function on the given `args`, stores the
  // results in `*rets`, and calls the given `done` callback when the function
  // returns. This method takes ownership of the tensors in `args`, in order to
  // be able to deallocate them as early as possible. Pass non-null `node` to
  // record processing time for modeling Iterator's GetNext() resource usage.
  // When non-null node is provided, the pre-requisite is that the calling
  // thread has previously called `DatasetBaseIterator::RecordStart()`.
  void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                std::vector<Tensor>* rets,
                FunctionLibraryRuntime::DoneCallback done,
                const std::shared_ptr<model::Node>& node) const;

 private:
  // Only `CapturedFunction::Instantiate` creates instances of this class.
  friend class CapturedFunction;

  // `lib` and `captured_func` are borrowed and must outlive this object (see
  // the member comments below).
  InstantiatedCapturedFunction(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device);

  // Determines whether a rendezvous object should be created when running the
  // instantiated function.
  bool ShouldCreateRendezvous() const;

  FunctionLibraryRuntime* const lib_;  // Not owned.
  // Handle identifying the instantiated function within `lib_`.
  const FunctionLibraryRuntime::Handle f_handle_;
  // Types of the function's return values.
  const DataTypeVector ret_types_;
  // Note: We capture the runner at function instantiation time to be able to
  // run the function without `IteratorContext` via `RunInstantiated`.
  std::function<void(std::function<void()>)> captured_runner_;
  CapturedFunction* const captured_func_;  // Not owned.
  // Whether the function uses the multi-device function backend.
  const bool is_multi_device_;

  TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};
319 | |
320 | } // namespace data |
321 | } // namespace tensorflow |
322 | |
323 | #endif // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ |
324 | |