1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
16#define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
17
18#include <memory>
19#include <vector>
20
21#include "tensorflow/core/framework/cancellation.h"
22#include "tensorflow/core/framework/dataset.h"
23#include "tensorflow/core/framework/function.h"
24#include "tensorflow/core/framework/model.h"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/lib/core/status.h"
28#include "tensorflow/core/lib/gtl/array_slice.h"
29#include "tensorflow/core/lib/random/random.h"
30#include "tensorflow/core/platform/macros.h"
31
32namespace tensorflow {
33
34class Device;
35class OpKernelContext;
36class ResourceMgr;
37
38namespace data {
39
40class CapturedFunction;
41class InstantiatedCapturedFunction;
42
43// Creates an iterator for a dataset which is created by applying the given
44// function to the given input element.
45Status MakeIteratorFromInputElement(
46 IteratorContext* ctx, const IteratorBase* parent,
47 const std::vector<Tensor>& input_element, int64_t thread_index,
48 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
49 std::unique_ptr<IteratorBase>* out_iterator);
50
51// Creates an iterator for a dataset which is created by applying the given
52// function to the given input element. Pass non-null `node` to record
53// processing time for modeling Iterator's GetNext() resource usage.
54Status MakeIteratorFromInputElement(
55 IteratorContext* ctx, const IteratorBase* parent,
56 const std::vector<Tensor>& input_element, int64_t thread_index,
57 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
58 std::unique_ptr<IteratorBase>* out_iterator,
59 const std::shared_ptr<model::Node>& node);
60
61// Creates an iterator context appropriate for a nested dataset's iterator. A
62// nested dataset is a dataset created within another dataset, e.g. by the
63// function passed to `interleave` or `flat_map`.
64IteratorContext MakeNestedIteratorContext(IteratorContext* ctx);
65
// Short-circuit information of a function, as exposed by
// `FunctionMetadata::short_circuit_info()`.
struct ShortCircuitInfo {
  // Presumably the mapping from function output indices to input indices that
  // `CapturedFunction::short_circuit_info()` documents (empty when the function
  // cannot be short-circuited) -- TODO confirm against the implementation.
  std::vector<int> indices;

  // NOTE(review): presumably one flag per entry of `indices` telling whether
  // the corresponding argument may be moved rather than copied -- confirm.
  std::vector<bool> can_move;
};
70
71// Metadata shared across all captures of the same function.
72class FunctionMetadata {
73 public:
74 struct Params {
75 bool use_inter_op_parallelism = true;
76 bool use_default_device = true;
77 };
78
79 // Creates a new instance of the `FunctionMetadata` class, fetching function
80 // from a context argument.
81 static Status Create(tensorflow::OpKernelConstruction* ctx,
82 const string& func_name, Params params,
83 std::shared_ptr<FunctionMetadata>* out_metadata);
84
85 // Creates a new instance of the `FunctionMetadata` class, using the provided
86 // function.
87 static Status Create(tensorflow::OpKernelConstruction* ctx,
88 NameAttrList&& func, Params params,
89 std::shared_ptr<FunctionMetadata>* out_metadata);
90
91 // Returns the named list of function arguments.
92 const NameAttrList& func() const { return func_; }
93
94 // Returns a borrowed pointer to the function library that contains the
95 // transitive closure of definitions used by the function.
96 const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }
97
98 // Returns short-circuit information.
99 const ShortCircuitInfo& short_circuit_info() const {
100 return short_circuit_info_;
101 }
102
103 // Indicates whether a default device should be used for executing function
104 // ops.
105 bool use_default_device() const { return use_default_device_; }
106
107 // Indicates whether to use inter-op parallelism for execution of the
108 // function.
109 bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }
110
111 // Indicates whether the function should a multi-device function backend.
112 bool use_multi_device_function() const { return use_multi_device_function_; }
113
114 private:
115 FunctionMetadata(NameAttrList&& func, Params params)
116 : func_(std::move(func)),
117 use_default_device_(params.use_default_device),
118 use_inter_op_parallelism_(params.use_inter_op_parallelism) {}
119
120 NameAttrList func_;
121 std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
122 ShortCircuitInfo short_circuit_info_;
123 bool use_default_device_ = true;
124 bool use_inter_op_parallelism_ = true;
125 bool use_multi_device_function_ = true;
126};
127
128// Constructs and stores the parameters for the CapturedFunction Instantiate
129// function.
130struct InstantiateCapturedFunctionParams {
131 explicit InstantiateCapturedFunctionParams(IteratorContext* ctx) {
132 flr = ctx->flr();
133 function_handle_cache = ctx->function_handle_cache();
134 runner = ctx->runner();
135 }
136
137 explicit InstantiateCapturedFunctionParams(OpKernelContext* ctx) {
138 flr = ctx->function_library();
139 function_handle_cache = nullptr;
140 runner = ctx->runner();
141 }
142
143 FunctionLibraryRuntime* flr;
144 FunctionHandleCache* function_handle_cache;
145 std::function<void(std::function<void()>)>* runner;
146};
147
148// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
149// arguments that it closed over in the user program.
150class CapturedFunction {
151 public:
152 // Creates a new instance using a list of named attributes, fetching captured
153 // inputs from a context argument.
154 static Status Create(OpKernelContext* ctx,
155 std::shared_ptr<const FunctionMetadata> metadata,
156 const string& argument_name,
157 std::unique_ptr<CapturedFunction>* out_function);
158
159 // Creates a new instance using a list of named attributes, using provided
160 // captured inputs.
161 static Status Create(OpKernelContext* ctx,
162 std::shared_ptr<const FunctionMetadata> metadata,
163 std::vector<Tensor>&& captured_inputs,
164 std::unique_ptr<CapturedFunction>* out_function);
165
166 // Adds the definition of this captured function into the given graph,
167 // returning its captured inputs and types through the respective output
168 // arguments.
169 Status AddToGraph(SerializationContext* ctx,
170 DatasetBase::DatasetGraphDefBuilder* b,
171 std::vector<Node*>* other_arguments,
172 DataTypeVector* other_arguments_types) const;
173
174 // Instantiates this function for use in the given context, providing an
175 // InstantiatedCapturedFunction that can be used to execute functions.
176 Status Instantiate(IteratorContext* ctx,
177 std::unique_ptr<InstantiatedCapturedFunction>*
178 instantiated_captured_function);
179
180 Status Instantiate(InstantiateCapturedFunctionParams params,
181 std::unique_ptr<InstantiatedCapturedFunction>*
182 instantiated_captured_function);
183
184 // Determines whether the captured function is stateful.
185 Status CheckExternalState() const;
186
187 // Returns the additional captured inputs that will be passed to the function.
188 const std::vector<Tensor>& captured_inputs() const {
189 return captured_inputs_;
190 }
191
192 // Returns the named list of function arguments.
193 const NameAttrList& func() const { return metadata_->func(); }
194
195 // Returns the transitive set of function definition required to instantiate
196 // this function.
197 const FunctionLibraryDefinition* lib_def() const {
198 return metadata_->lib_def();
199 }
200
201 // If every function output corresponds to one of its inputs, the method
202 // returns the mapping from output indices to input indices. Otherwise, it
203 // returns an empty list.
204 const ShortCircuitInfo& short_circuit_info() const {
205 return metadata_->short_circuit_info();
206 }
207
208 // Indicates whether the function should use inter op parallelism.
209 bool use_inter_op_parallelism() const {
210 return metadata_->use_inter_op_parallelism();
211 }
212
213 private:
214 CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
215 std::vector<Tensor> captured_inputs);
216
217 Status IsMultiDevice(FunctionLibraryRuntime* flr,
218 bool* is_multi_device) const;
219
220 const std::shared_ptr<const FunctionMetadata> metadata_;
221 const std::vector<Tensor> captured_inputs_;
222
223 TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
224};
225
226// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
227// to execute a tensorflow function.
228//
229// While `CapturedFunction` encapsulates constant attributes of the function,
230// such as its name and captured arguments, `InstantiatedCapturedFunction`
231// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
232// handle.
233//
234// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
235// functions outside of the normal `OpKernel::Compute()` context.
236class InstantiatedCapturedFunction {
237 public:
238 // Runs the instantiated captured function. This method takes ownership of
239 // the tensors in `args`, in order to be able to deallocate them as early as
240 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
241 // ownership of the `args`.
242 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
243 std::vector<Tensor>* rets) const;
244
245 // Runs the instantiated captured function. This method takes ownership of
246 // the tensors in `args`, in order to be able to deallocate them as early as
247 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
248 // ownership of the `args`. Pass non-null `node` to record processing time
249 // for modeling Iterator's GetNext() resource usage. When non-null node is
250 // provided, the pre-requisite is that the calling thread has previously
251 // called `DatasetBaseIterator::RecordStart().
252 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
253 std::vector<Tensor>* rets,
254 const std::shared_ptr<model::Node>& node) const;
255
256 // Synchronously runs the captured function on the given `args`, and stores
257 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
258 // possible.
259 Status RunWithBorrowedArgs(IteratorContext* ctx,
260 const std::vector<Tensor>& args,
261 std::vector<Tensor>* rets) const;
262
263 // Synchronously runs the captured function on the given `args`, and stores
264 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
265 // possible. Pass non-null `node` to record processing time for modeling
266 // Iterator's GetNext() resource usage. When non-null node is provided, the
267 // pre-requisite is that the calling thread has previously called
268 // `DatasetBaseIterator::RecordStart().
269 Status RunWithBorrowedArgs(IteratorContext* ctx,
270 const std::vector<Tensor>& args,
271 std::vector<Tensor>* rets,
272 const std::shared_ptr<model::Node>& node) const;
273
274 // Synchronously runs the captured function on the given `args`, and stores
275 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
276 // possible. This can be useful for calling a captured function in cases where
277 // an `IteratorContext*` is not available (such as a destructor).
278 //
279 // TODO(b/144278100): Avoid running functions without IteratorContext.
280 Status RunInstantiated(const std::vector<Tensor>& args,
281 std::vector<Tensor>* rets);
282
283 // Asynchronously runs the captured function on the given `args`, stores the
284 // results in `*rets`, and calls the given `done` callback when the function
285 // returns. This method takes ownership of the tensors in `args`, in order to
286 // be able to deallocate them as early as possible. Pass non-null `node` to
287 // record processing time for modeling Iterator's GetNext() resource usage.
288 // When non-null node is provided, the pre-requisite is that the calling
289 // thread has previously called `DatasetBaseIterator::RecordStart().
290 void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
291 std::vector<Tensor>* rets,
292 FunctionLibraryRuntime::DoneCallback done,
293 const std::shared_ptr<model::Node>& node) const;
294
295 private:
296 friend class CapturedFunction;
297
298 InstantiatedCapturedFunction(
299 FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
300 DataTypeVector ret_types,
301 std::function<void(std::function<void()>)> runner,
302 CapturedFunction* captured_func, bool is_multi_device);
303
304 // Determines whether a rendezvous object should be created when running the
305 // instantiated function.
306 bool ShouldCreateRendezvous() const;
307
308 FunctionLibraryRuntime* const lib_; // Not owned.
309 const FunctionLibraryRuntime::Handle f_handle_;
310 const DataTypeVector ret_types_;
311 // Note: We capture the runner at function instantiation time to be able to
312 // run the function without `IteratorContext` via `RunInstantiated`.
313 std::function<void(std::function<void()>)> captured_runner_;
314 CapturedFunction* const captured_func_; // Not owned.
315 const bool is_multi_device_;
316
317 TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
318};
319
320} // namespace data
321} // namespace tensorflow
322
323#endif // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
324