1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |
17 | |
18 | #include <deque> |
19 | #include <memory> |
20 | #include <unordered_map> |
21 | |
22 | #include "absl/memory/memory.h" |
23 | #include "tensorflow/core/framework/attr_value.pb.h" |
24 | #include "tensorflow/core/framework/attr_value_util.h" |
25 | #include "tensorflow/core/framework/cancellation.h" |
26 | #include "tensorflow/core/framework/collective.h" |
27 | #include "tensorflow/core/framework/dataset_metadata.pb.h" |
28 | #include "tensorflow/core/framework/dataset_options.pb.h" |
29 | #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h" |
30 | #include "tensorflow/core/framework/function.h" |
31 | #include "tensorflow/core/framework/function_handle_cache.h" |
32 | #include "tensorflow/core/framework/graph.pb.h" |
33 | #include "tensorflow/core/framework/model.h" |
34 | #include "tensorflow/core/framework/node_def.pb.h" |
35 | #include "tensorflow/core/framework/op_kernel.h" |
36 | #include "tensorflow/core/framework/register_types.h" |
37 | #include "tensorflow/core/framework/thread_factory.h" |
38 | #include "tensorflow/core/framework/types.pb.h" |
39 | #include "tensorflow/core/framework/variant_encode_decode.h" |
40 | #include "tensorflow/core/framework/variant_tensor_data.h" |
41 | #include "tensorflow/core/lib/core/errors.h" |
42 | #include "tensorflow/core/lib/core/threadpool.h" |
43 | #include "tensorflow/core/lib/core/threadpool_interface.h" |
44 | #include "tensorflow/core/lib/strings/str_util.h" |
45 | #include "tensorflow/core/lib/strings/strcat.h" |
46 | #include "tensorflow/core/platform/cpu_info.h" |
47 | #include "tensorflow/core/platform/env.h" |
48 | #include "tensorflow/core/platform/refcount.h" |
49 | #include "tensorflow/core/platform/tracing.h" |
50 | |
// Expands `m(T)` once for each primitive TensorFlow type `T`, i.e. for every
// type covered by `TF_CALL_ALL_TYPES` followed by every type covered by
// `TF_CALL_QUANTIZED_TYPES`. Polymorphic datasets should support all of these
// types; a typical use of this macro is generating the cases of a `switch`
// statement.
#define TF_CALL_DATASET_TYPES(m) \
  TF_CALL_ALL_TYPES(m)           \
  TF_CALL_QUANTIZED_TYPES(m)
55 | |
56 | namespace tensorflow { |
57 | |
// These types are only forward-declared so that this header does not pull in
// a dependency on the headers under "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;
62 | |
63 | namespace data { |
64 | |
65 | namespace internal { |
66 | // Merges Options from source to destination. If there is a conflict on a field, |
67 | // the field value from the source takes precedence. |
68 | void MergeOptions(const protobuf::Message& source, |
69 | protobuf::Message* destination); |
70 | void MergeOptions(const protobuf::MessageLite& source, |
71 | protobuf::MessageLite* destination); |
72 | } // namespace internal |
73 | |
74 | using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>; |
75 | |
// Name of the function attribute that `IsTFDataFunction()` looks up.
constexpr char kTFDataFunction[] = "_tf_data_function";

// Sentinel cardinality values.
// NOTE(review): presumably reported in place of an element count when a
// dataset is infinite / of unknown size -- confirm at the use sites.
constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// Magic number that is used (as a prefix) to identify the keys used for
// serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";

// Single-character string constants.
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";
constexpr char kTraceInfoUnavailable[] = "unavailable";
constexpr char kMetadata[] = "metadata";

// Attribute name under which cardinality is recorded when serializing for
// graph rewrites.
constexpr char kCardinalityAttrForRewrite[] = "_cardinality";
92 | |
// Forward declarations of types that are referenced before they are defined.
class DatasetBase;
class SerializationContext;
95 | |
96 | inline bool IsTFDataFunction(const FunctionDef& func) { |
97 | auto iter = func.attr().find(data::kTFDataFunction); |
98 | return (iter != func.attr().end() && iter->second.b()); |
99 | } |
100 | |
101 | // Interface for reading values from a key-value store. |
102 | // Used for restoring iterator state. This class is thread safe. |
103 | // Please see comment on IteratorStateWriter for guidance around using the |
104 | // Read*(key, val) vs Read*(name, key, val). |
105 | class IteratorStateReader { |
106 | public: |
107 | // Determines whether the iterator state contains the given key. |
108 | virtual bool Contains(StringPiece key) const = 0; |
109 | virtual bool Contains(StringPiece name, StringPiece key) const = 0; |
110 | |
111 | // Reads an integer for the given key. |
112 | virtual Status ReadScalar(StringPiece key, int64_t* val) const = 0; |
113 | virtual Status ReadScalar(StringPiece name, StringPiece key, |
114 | int64_t* val) const = 0; |
115 | |
116 | // Reads a string for the given key. |
117 | virtual Status ReadScalar(StringPiece key, tstring* val) const = 0; |
118 | virtual Status ReadScalar(StringPiece name, StringPiece key, |
119 | tstring* val) const = 0; |
120 | |
121 | // Reads a tensor for the given key. |
122 | // TODO(jsimsa): Remove non-FLR overrides once all callers are updated. |
123 | virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0; |
124 | virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, |
125 | Tensor* val) const = 0; |
126 | virtual Status ReadTensor(StringPiece name, StringPiece key, |
127 | Tensor* val) const = 0; |
128 | virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, |
129 | StringPiece key, Tensor* val) const = 0; |
130 | |
131 | virtual ~IteratorStateReader() {} |
132 | }; |
133 | |
134 | // Interface for writing values to a key-value store. |
135 | // Used for saving iterator state. Not thread safe. |
136 | // The IteratorStateWriter creates a tensor for each unique iterator name it |
137 | // sees. For the Write*(key, val) API's the key is expected to encode this |
138 | // name as keys are required to be produced using the full_name() method. |
139 | // Each tensor has an upper limit of 2 GB and so if the state for an iterator |
140 | // might exceed the 2 GB limit, you can pass an explicit name in via the |
141 | // Write*(name, key, val) APIs allowing you to further split up the state |
142 | // into more manageable chunks. |
143 | class IteratorStateWriter { |
144 | public: |
145 | // Writes an integer for the given key. |
146 | virtual Status WriteScalar(StringPiece key, const int64_t val) = 0; |
147 | virtual Status WriteScalar(StringPiece name, StringPiece key, |
148 | const int64_t val) = 0; |
149 | |
150 | // Writes a string for the given key. |
151 | virtual Status WriteScalar(StringPiece key, const tstring& val) = 0; |
152 | virtual Status WriteScalar(StringPiece name, StringPiece key, |
153 | const tstring& val) = 0; |
154 | |
155 | // Writes a tensor for the given key. |
156 | virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; |
157 | virtual Status WriteTensor(StringPiece name, StringPiece key, |
158 | const Tensor& val) = 0; |
159 | |
160 | virtual ~IteratorStateWriter() {} |
161 | }; |
162 | |
163 | // Generates a full name key for iterator checkpointing. All keys generated for |
164 | // iterator checkpoints should go through this function. |
165 | std::string FullName(const std::string& prefix, const std::string& name); |
166 | |
167 | // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. |
168 | class GraphDefBuilderWrapper { |
169 | public: |
170 | explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} |
171 | |
172 | // Adds a Const node with scalar value to the Graph. |
173 | // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
174 | // non-null if the method returns with an OK status. |
175 | // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. |
176 | template <typename T> |
177 | Status AddScalar(const T& val, Node** output) { |
178 | Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({})); |
179 | val_t.scalar<T>()() = val; |
180 | AddTensorInternal(val_t, output); |
181 | if (*output == nullptr) { |
182 | return errors::Internal("AddScalar: Failed to build Const op." ); |
183 | } |
184 | return OkStatus(); |
185 | } |
186 | |
187 | // Adds a Const node with vector value to the Graph. |
188 | // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
189 | // non-null if the method returns with an OK status. |
190 | // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. |
191 | // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? |
192 | template <typename T> |
193 | Status AddVector(const std::vector<T>& val, Node** output) { |
194 | Tensor val_t = Tensor(DataTypeToEnum<T>::v(), |
195 | TensorShape({static_cast<int64_t>(val.size())})); |
196 | for (size_t i = 0; i < val.size(); i++) { |
197 | val_t.flat<T>()(i) = val[i]; |
198 | } |
199 | AddTensorInternal(val_t, output); |
200 | if (*output == nullptr) { |
201 | return errors::Internal("AddVector: Failed to build Const op." ); |
202 | } |
203 | return OkStatus(); |
204 | } |
205 | |
206 | Status AddVector(const std::vector<string>& val, Node** output) { |
207 | Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(), |
208 | TensorShape({static_cast<int64_t>(val.size())})); |
209 | for (size_t i = 0; i < val.size(); i++) { |
210 | val_t.flat<tstring>()(i) = val[i]; |
211 | } |
212 | AddTensorInternal(val_t, output); |
213 | if (*output == nullptr) { |
214 | return errors::Internal("AddVector: Failed to build Const op." ); |
215 | } |
216 | return OkStatus(); |
217 | } |
218 | |
219 | // Adds a `Const` node for the given tensor value to the graph. |
220 | // |
221 | // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
222 | // non-null if the method returns with an OK status. The returned `Node` |
223 | // pointer is owned by the backing graph of `GraphDefBuilder`. |
224 | Status AddTensor(const Tensor& val, Node** output) { |
225 | AddTensorInternal(val, output); |
226 | if (*output == nullptr) { |
227 | return errors::Internal("AddTensor: Failed to build Const op." ); |
228 | } |
229 | return OkStatus(); |
230 | } |
231 | |
232 | // Adds a `Placeholder` node for the given tensor value to the graph. |
233 | // |
234 | // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
235 | // non-null if the method returns with an OK status. The returned `Node` |
236 | // pointer is owned by the backing graph of `GraphDefBuilder`. |
237 | Status AddPlaceholder(const Tensor& val, Node** output) { |
238 | AddPlaceholderInternal(val, output); |
239 | if (*output == nullptr) { |
240 | return errors::Internal( |
241 | "AddPlaceholder: Failed to build Placeholder op." ); |
242 | } |
243 | return OkStatus(); |
244 | } |
245 | |
246 | // Adds a node for the given dataset to the `Graph`. The value of |
247 | // `DatasetBase::type_string()` is used as the op type for the node. Values |
248 | // for the `output_types` and `output_shapes` node attributes are also written |
249 | // if those attributes are defined in the `OpDef`. |
250 | // |
251 | // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is |
252 | // used as the op name for the node. This argument should only be set when |
253 | // serializing `DatasetBase` instances which might not have been created |
254 | // through op kernel execution to make sure the dataset op name is preserved |
255 | // across serialization boundaries, which is in turn needed to make sure |
256 | // iterator checkpoints are valid across serialization boundaries. When |
257 | // `use_dataset_name` is set, the caller is responsible for making sure that |
258 | // the op name is unique across the graph. |
259 | // |
260 | // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
261 | // non-null if the method returns with an OK status. The returned `Node` |
262 | // pointer is owned by the backing `Graph` of `GraphDefBuilder`. |
263 | Status AddDataset(const DatasetBase* dataset, |
264 | const std::vector<Node*>& inputs, Node** output); |
265 | Status AddDataset(const DatasetBase* dataset, |
266 | const std::vector<Node*>& inputs, |
267 | const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
268 | Node** output); |
269 | Status AddDataset( |
270 | const DatasetBase* dataset, |
271 | const std::vector<std::pair<size_t, Node*>>& inputs, |
272 | const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, |
273 | const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
274 | Node** output); |
275 | Status AddDataset( |
276 | const DatasetBase* dataset, |
277 | const std::vector<std::pair<size_t, Node*>>& inputs, |
278 | const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, |
279 | const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
280 | bool use_dataset_name, Node** output); |
281 | |
282 | // Adds a user-defined function with name `function_name` to the graph and |
283 | // recursively adds all functions it references. If a function with a matching |
284 | // name has already been added, returns with OK status. If a user-defined with |
285 | // name `function_name` is not found in the context's function library, |
286 | // returns an InvalidArgumentError. If the function with name `function_name` |
287 | // or any of its dependent functions are stateful, and the context does not |
288 | // explicitly permit stateful functions, returns an InvalidArgument error. |
289 | Status AddFunction(SerializationContext* ctx, const string& function_name, |
290 | const FunctionLibraryDefinition& lib_def); |
291 | |
292 | template <typename T> |
293 | void BuildAttrValue(const T& value, AttrValue* attr) { |
294 | SetAttrValue(value, attr); |
295 | } |
296 | |
297 | template <typename T> |
298 | AttrValue BuildAttrValue(const T& value) { |
299 | AttrValue attr; |
300 | SetAttrValue(value, &attr); |
301 | return attr; |
302 | } |
303 | |
304 | protected: |
305 | GraphDefBuilder* builder() { return b_; } |
306 | |
307 | private: |
308 | void AddPlaceholderInternal(const Tensor& val, Node** output); |
309 | void AddTensorInternal(const Tensor& val, Node** output); |
310 | bool HasAttr(const string& op_type_name, const string& attr_name) const; |
311 | |
312 | bool HasAttr(const OpDef* op_def, const string& attr_name) const { |
313 | for (const auto& attr : op_def->attr()) { |
314 | if (attr.name() == attr_name) { |
315 | return true; |
316 | } |
317 | } |
318 | return false; |
319 | } |
320 | |
321 | Status AddAttrFunctions(SerializationContext* ctx, |
322 | const AttrValue& attr_value, |
323 | const FunctionLibraryDefinition& lib_def) { |
324 | if (attr_value.has_func()) { |
325 | TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def)); |
326 | } else if (attr_value.has_list()) { |
327 | for (const NameAttrList& name_attr_list : attr_value.list().func()) { |
328 | TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def)); |
329 | } |
330 | } |
331 | return OkStatus(); |
332 | } |
333 | |
334 | GraphDefBuilder* b_; |
335 | }; |
336 | |
class StatsAggregator;

// Utility class for running a function while ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() = default;

  // Runs the function `f`.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns the global singleton `Runner`.
  static Runner* get();
};
351 | |
352 | // A class which provides a sequence of splits. Splits represent subdivisions of |
353 | // a dataset, e.g. filenames or ranges within files. We use splitting to |
354 | // partition input data into smaller pieces for distributed processing (see |
355 | // go/tf-data-splitting-design). |
356 | // |
357 | // Datasets provide a `MakeSplitProvider` method to expose a listing of their |
358 | // splits. |
359 | // |
360 | // Iterators created with a split provider will only iterate over the splits |
361 | // provided by the split provider. |
362 | class SplitProvider { |
363 | public: |
364 | virtual ~SplitProvider() {} |
365 | // Stores the next split in `*split`, setting `*end_of_splits` to indicate |
366 | // whether there were any splits left. |
367 | virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0; |
368 | // Resets the split provider to its beginning. |
369 | virtual Status Reset() = 0; |
370 | // Saves the state of this split provider. |
371 | virtual Status Save(std::function<std::string(std::string)> full_name, |
372 | IteratorStateWriter* writer) = 0; |
373 | // Restores the state of this split provider. |
374 | virtual Status Restore(std::function<std::string(std::string)> full_name, |
375 | IteratorStateReader* reader) = 0; |
376 | }; |
377 | |
378 | // Returns the runner threadpool size from an OpKernelContext. |
379 | int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx); |
380 | |
381 | // A cut-down version of `OpKernelContext` for running computations in |
382 | // iterators. Note that we cannot simply use `OpKernelContext` here because we |
383 | // might run computation in an iterator whose lifetime is not nested within the |
384 | // lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching). |
385 | // |
386 | // TODO(mrry): We're making some daring assumptions about the lifetime of the |
387 | // runner passed in here. A runner will be deleted when the original step ends, |
388 | // but all existing runners only close over session-lifetime (or longer-lived) |
389 | // state, so we can make a copy of the function. There's nothing in the |
390 | // definition of the API from which we took the runner to guarantee that what we |
391 | // are doing is safe. We should formalize the properties here. |
392 | class IteratorContext { |
393 | public: |
394 | struct Params { |
395 | explicit Params(IteratorContext* ctx) |
396 | : allocator_getter(ctx->allocator_getter()), |
397 | cancellation_manager(ctx->cancellation_manager()), |
398 | collective_executor(ctx->collective_executor()), |
399 | env(ctx->env()), |
400 | flr(ctx->flr()), |
401 | function_handle_cache(ctx->function_handle_cache()), |
402 | interleave_depth(ctx->interleave_depth()), |
403 | is_restoring(ctx->is_restoring()), |
404 | model(ctx->model()), |
405 | options(ctx->options()), |
406 | resource_mgr(ctx->resource_mgr()), |
407 | runner(*(ctx->runner())), |
408 | runner_threadpool_size(ctx->runner_threadpool_size()), |
409 | split_providers(ctx->split_providers()), |
410 | stats_aggregator(ctx->stats_aggregator()), |
411 | thread_factory(ctx->thread_factory()), |
412 | thread_pool(ctx->thread_pool()) {} |
413 | |
414 | explicit Params(OpKernelContext* ctx) |
415 | : collective_executor(ctx->collective_executor()), |
416 | env(ctx->env()), |
417 | flr(ctx->function_library()) { |
418 | // NOTE: need reinterpret_cast because function.h forward-declares Device. |
419 | DeviceBase* device = |
420 | reinterpret_cast<DeviceBase*>(ctx->function_library()->device()); |
421 | allocator_getter = [device](AllocatorAttributes attrs) { |
422 | return device->GetAllocator(attrs); |
423 | }; |
424 | |
425 | runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx); |
426 | |
427 | // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so |
428 | // that a symbol in the tensorflow::data namespace is always on the stack |
429 | // when executing a function inside a Dataset. |
430 | runner = std::bind( |
431 | []( |
432 | // Note: `runner` is a const reference to avoid copying it. |
433 | const std::function<void(std::function<void()>)>& ctx_runner, |
434 | std::function<void()> fn) { |
435 | std::function<void()> wrapped_fn = std::bind( |
436 | [](const std::function<void()>& fn) { Runner::get()->Run(fn); }, |
437 | std::move(fn)); |
438 | ctx_runner(std::move(wrapped_fn)); |
439 | }, |
440 | *ctx->runner(), std::placeholders::_1); |
441 | } |
442 | |
443 | // The Allocator to be used to allocate the output of an iterator. |
444 | std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr; |
445 | |
446 | // The CancellationManager to be used to cancel execution of ops. |
447 | CancellationManager* cancellation_manager = nullptr; |
448 | |
449 | // Collective support. |
450 | CollectiveExecutor* collective_executor = nullptr; |
451 | |
452 | // Interface to operating system functionality. |
453 | Env* env = nullptr; |
454 | |
455 | // The FunctionLibraryRuntime object to be used to make function calls. |
456 | FunctionLibraryRuntime* flr = nullptr; |
457 | |
458 | // A FunctionHandleCache that owns all the function handles. Not owned. |
459 | FunctionHandleCache* function_handle_cache = nullptr; |
460 | |
461 | // Records the number of ParallelInterleave operations in the path from the |
462 | // root node to this node (not including this node) in the input pipeline |
463 | // tree. |
464 | int64 interleave_depth = 0; |
465 | |
466 | // Marks whether the iterator is restored from a checkpoint. |
467 | bool is_restoring = false; |
468 | |
469 | // If non-null, identifies the object used for performance modeling. |
470 | std::shared_ptr<model::Model> model = nullptr; |
471 | |
472 | // The input pipeline options. |
473 | const Options* options = nullptr; |
474 | |
475 | // A resource manager for storing dataset-related state, e.g. random |
476 | // seeds or cached tensors. Not owned. |
477 | ResourceMgr* resource_mgr = nullptr; |
478 | |
479 | // Function call support. |
480 | std::function<void(std::function<void()>)> runner = nullptr; |
481 | |
482 | // Number of threads used for executing user-defined functions. |
483 | int32 runner_threadpool_size = 0; |
484 | |
485 | // Split providers indicating which splits to process. May be empty, |
486 | // indicating that the iterator should process all splits. |
487 | std::vector<std::shared_ptr<SplitProvider>> split_providers; |
488 | |
489 | // The `StatsAggregator` object to record statistics about the iterator. |
490 | // |
491 | // TODO(b/147325552): Remove this API and any of its uses after we switch to |
492 | // using C++ based implementation for tf.data options (on 4/12/2021). |
493 | std::shared_ptr<StatsAggregator> stats_aggregator = nullptr; |
494 | |
495 | // A factory for creating threads to perform blocking work. |
496 | std::shared_ptr<ThreadFactory> thread_factory = nullptr; |
497 | |
498 | // A shared thread pool to schedule computation into. |
499 | thread::ThreadPoolInterface* thread_pool = nullptr; |
500 | }; |
501 | |
502 | explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {} |
503 | |
504 | explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {} |
505 | |
506 | explicit IteratorContext(Params params) : params_(std::move(params)) {} |
507 | |
508 | Allocator* allocator(AllocatorAttributes attrs) { |
509 | return params_.allocator_getter(attrs); |
510 | } |
511 | |
512 | std::function<Allocator*(AllocatorAttributes)> allocator_getter() { |
513 | return params_.allocator_getter; |
514 | } |
515 | |
516 | CancellationManager* cancellation_manager() { |
517 | return params_.cancellation_manager; |
518 | } |
519 | |
520 | CollectiveExecutor* collective_executor() { |
521 | return params_.collective_executor; |
522 | } |
523 | |
524 | Env* env() const { return params_.env; } |
525 | |
526 | FunctionLibraryRuntime* flr() { return params_.flr; } |
527 | |
528 | FunctionHandleCache* function_handle_cache() { |
529 | return params_.function_handle_cache; |
530 | } |
531 | |
532 | int64 interleave_depth() { return params_.interleave_depth; } |
533 | |
534 | bool is_restoring() { return params_.is_restoring; } |
535 | |
536 | const std::shared_ptr<model::Model>& model() { return params_.model; } |
537 | |
538 | const Options* options() { return params_.options; } |
539 | |
540 | ResourceMgr* resource_mgr() { return params_.resource_mgr; } |
541 | |
542 | std::function<void(std::function<void()>)>* runner() { |
543 | return ¶ms_.runner; |
544 | } |
545 | |
546 | int32 runner_threadpool_size() { return params_.runner_threadpool_size; } |
547 | |
548 | std::vector<std::shared_ptr<SplitProvider>> split_providers() { |
549 | return params_.split_providers; |
550 | } |
551 | |
552 | std::shared_ptr<StatsAggregator> stats_aggregator() { |
553 | return params_.stats_aggregator; |
554 | } |
555 | |
556 | const std::shared_ptr<ThreadFactory>& thread_factory() { |
557 | return params_.thread_factory; |
558 | } |
559 | |
560 | thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; } |
561 | |
562 | std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name, |
563 | int num_threads) { |
564 | if (params_.thread_pool) { |
565 | // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which |
566 | // is an instance of `thread::ThreadPoolInterface`). Notably, the |
567 | // ownership of `params_.thread_pool` is *not* transferred onto the newly |
568 | // created `ThreadPool` instance. |
569 | return absl::make_unique<thread::ThreadPool>(params_.thread_pool); |
570 | } else { |
571 | return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(), |
572 | name, num_threads, |
573 | /*low_latency_hint=*/false); |
574 | } |
575 | } |
576 | |
577 | std::unique_ptr<Thread> StartThread(const string& name, |
578 | std::function<void()> fn) { |
579 | if (params_.thread_factory) { |
580 | return params_.thread_factory->StartThread(name, std::move(fn)); |
581 | } else { |
582 | return absl::WrapUnique( |
583 | Env::Default()->StartThread({}, name, std::move(fn))); |
584 | } |
585 | } |
586 | |
587 | private: |
588 | Params params_; |
589 | }; |
590 | |
591 | // Aggregates runtime support needed for dataset and iterator serialization. |
592 | class SerializationContext { |
593 | public: |
594 | // Enum describing what to do during serialization when external state is |
595 | // encountered. |
596 | enum class ExternalStatePolicy : int64 { |
597 | // Proceed with serialization, but log a warning about what state will be |
598 | // lost. |
599 | kWarn = 0, |
600 | // Proceed with serialization without logging any warning. |
601 | kIgnore = 1, |
602 | // Fail the serialization with an error. |
603 | kFail = 2, |
604 | }; |
605 | |
606 | // Handles the CheckExternalState status according to the external state |
607 | // policy. |
608 | Status HandleCheckExternalStateStatus(Status s) { |
609 | if (s.ok()) { |
610 | return s; |
611 | } |
612 | switch (params_.external_state_policy) { |
613 | case ExternalStatePolicy::kWarn: |
614 | LOG(WARNING) << s.ToString(); |
615 | return OkStatus(); |
616 | case ExternalStatePolicy::kIgnore: |
617 | VLOG(2) << "Ignoring error status: " << s.ToString(); |
618 | return OkStatus(); |
619 | case ExternalStatePolicy::kFail: |
620 | return s; |
621 | } |
622 | LOG(FATAL) << "Control should never reach here" ; |
623 | } |
624 | |
625 | struct Params { |
626 | explicit Params() {} |
627 | |
628 | explicit Params(OpKernelContext* ctx) |
629 | : resource_mgr(ctx->resource_manager()), |
630 | device_name(ctx->device()->attributes().name()) {} |
631 | |
632 | std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned. |
633 | |
634 | // Indicates what to do if the dataset depends on external state. |
635 | ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn; |
636 | |
637 | // Indicates whether the serialization is for rewrites. |
638 | // |
639 | // If true: |
640 | // * A dataset that doesn't implement serialization is replaced with a |
641 | // placeholder returned in `input_list`. |
642 | // * Data tensors are replaced with a placeholder returned in |
643 | // `input_list`. |
644 | // * Datasets that use random seeds should not serialize the random seeds. |
645 | // This doesn't affect datasets that use fixed seeds; fixed seeds will |
646 | // always be preserved. |
647 | // * Cardinality is serialized as an unregistered attribute |
648 | // `_cardinality`. |
649 | // If false: |
650 | // * A dataset that doesn't implement serialization should result in an |
651 | // error. |
652 | // * Data tensors (potentially large) should be serialized. |
653 | // * Datasets that use random seeds should serialize the random seeds. |
654 | bool is_graph_rewrite = false; |
655 | |
656 | // A resource manager for looking up resources during serialization. |
657 | ResourceMgr* resource_mgr; |
658 | |
659 | // The name of the device doing the serialization. |
660 | std::string device_name; |
661 | }; |
662 | |
663 | explicit SerializationContext(Params params) : params_(params) {} |
664 | |
665 | std::vector<std::pair<string, Tensor>>* input_list() { |
666 | return params_.input_list; |
667 | } |
668 | |
669 | ExternalStatePolicy external_state_policy() const { |
670 | return params_.external_state_policy; |
671 | } |
672 | |
673 | bool is_graph_rewrite() const { return params_.is_graph_rewrite; } |
674 | |
675 | const ResourceMgr* resource_mgr() const { return params_.resource_mgr; } |
676 | |
677 | const std::string& device_name() const { return params_.device_name; } |
678 | |
679 | private: |
680 | Params params_; |
681 | |
682 | TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext); |
683 | }; |
684 | |
685 | // Represents the current position in a range of outputs, where the |
686 | // range of outputs is typically represented by an `DatasetBase`, |
687 | // defined below. |
688 | class IteratorBase { |
689 | public: |
690 | virtual ~IteratorBase() { |
691 | for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { |
692 | (*rit)(); |
693 | } |
694 | } |
695 | |
696 | // Gets the next output from the range that this iterator is traversing. |
697 | // |
698 | // If at least one output remains in this iterator's range, that |
699 | // output will be stored in `*out_tensors` and `false` will be |
700 | // stored in `*end_of_sequence`. |
701 | // |
702 | // If no more outputs remain in this iterator's range, `true` will be stored |
703 | // in `*end_of_sequence`, and `*out_tensors` will be empty. |
704 | // |
705 | // Implementations should never return `OutOfRange` error. If at end of |
706 | // sequence, set `*end_of_sequence = true` and return `Status::OK()`. |
707 | // Internally raised `OutOfRange` errors that do not imply end of sequence |
708 | // should be converted to a different error type before being propagated to |
709 | // the caller. |
710 | // |
711 | // Implementations must explicitly set `*end_of_sequence = false` if an |
712 | // `Status::OK()` status is returned and the iterator is not at the end of the |
713 | // sequence. |
714 | // |
715 | // This method is thread-safe. |
716 | // |
717 | // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and |
718 | // potentially remove this method. |
719 | virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, |
720 | bool* end_of_sequence) = 0; |
721 | |
722 | Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors, |
723 | bool* end_of_sequence) { |
724 | return GetNext(&ctx, out_tensors, end_of_sequence); |
725 | } |
726 | |
727 | // Skips the next `num_to_skip` outputs from the range that this iterator |
728 | // is traversing. |
729 | // |
730 | // If there are not enough outputs to skip, it will set |
731 | // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will |
732 | // store the number of outputs that are skipped. When `*end_of_sequence` is |
733 | // `false`, `*num_skipped` should equal to `num_to_skip`. |
734 | virtual Status Skip(IteratorContext* ctx, int num_to_skip, |
735 | bool* end_of_sequence, int* num_skipped) = 0; |
736 | |
737 | virtual Status Skip(IteratorContext&& ctx, int num_to_skip, |
738 | bool* end_of_sequence, int* num_skipped) { |
739 | return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped); |
740 | } |
741 | |
742 | // Returns a vector of DataType values, representing the respective |
743 | // element types of each tuple component in the outputs of this |
744 | // iterator. |
745 | virtual const DataTypeVector& output_dtypes() const = 0; |
746 | |
747 | // Returns a vector of tensor shapes, representing the respective |
748 | // (and possibly partially defined) shapes of each tuple component |
749 | // in the outputs of this iterator. |
750 | virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; |
751 | |
752 | // Returns a string that identifies the sequence of iterators leading up to |
753 | // this iterator. |
754 | virtual const string& prefix() const = 0; |
755 | |
756 | // Performs initialization that needs to happen outside of a constructor to |
757 | // properly propagate errors. |
758 | virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); } |
759 | |
760 | // Performs initialization of the base iterator. |
761 | Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); |
762 | |
763 | // Saves the state of this iterator. |
764 | virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { |
765 | int64_t start_us = EnvTime::NowMicros(); |
766 | TF_RETURN_IF_ERROR(SaveInternal(ctx, writer)); |
767 | VLOG(1) << "Saved " << prefix() << " in " |
768 | << (EnvTime::NowMicros() - start_us) << "us" ; |
769 | return OkStatus(); |
770 | } |
771 | |
772 | protected: |
773 | // Returns a node that models this iterator. |
774 | virtual std::shared_ptr<model::Node> CreateNode( |
775 | IteratorContext* ctx, model::Node::Args args) const = 0; |
776 | |
777 | // Restores the state of this iterator. |
778 | virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { |
779 | int64_t start_us = EnvTime::NowMicros(); |
780 | TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader)); |
781 | VLOG(1) << "Restored " << prefix() << " in " |
782 | << (EnvTime::NowMicros() - start_us) << "us" ; |
783 | return OkStatus(); |
784 | } |
785 | |
786 | // This is needed so that sub-classes of IteratorBase can call |
787 | // `SaveInternal` on their input iterators. |
788 | Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, |
789 | const std::unique_ptr<IteratorBase>& input) { |
790 | return input->Save(ctx, writer); |
791 | } |
792 | |
793 | // This is needed so that sub-classes of IteratorBase can call |
794 | // `RestoreInternal` on their input iterators. |
795 | Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, |
796 | const std::unique_ptr<IteratorBase>& input) { |
797 | return input->Restore(ctx, reader); |
798 | } |
799 | |
800 | Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader, |
801 | const std::unique_ptr<IteratorBase>& input) { |
802 | return RestoreInput(&ctx, reader, input); |
803 | } |
804 | |
805 | // Saves the state of this iterator. |
806 | // |
807 | // This method is used to store the state of the iterator in a checkpoint. |
808 | // implementations have an override. |
809 | virtual Status SaveInternal(SerializationContext* ctx, |
810 | IteratorStateWriter* writer) = 0; |
811 | |
812 | // Restores the state of this iterator. |
813 | // |
814 | // This method is used to restore the state of the iterator from a checkpoint. |
815 | // |
816 | // Implementations may assume that the iterator is in a clean state. That is, |
817 | // its `Initialize` method has been called, but its `GetNext` method has |
818 | // never been called. |
819 | // implementations have an override. |
820 | virtual Status RestoreInternal(IteratorContext* ctx, |
821 | IteratorStateReader* reader) = 0; |
822 | |
823 | // Returns a pointer to the node representing this iterator in the performance |
824 | // model. It may be null, if performance modeling is not enabled for this |
825 | // iterator. |
826 | std::shared_ptr<model::Node> model_node() const { return node_; } |
827 | |
828 | // Returns the number of elements produced by this iterator. |
829 | int64_t num_elements() const { |
830 | if (node_) return node_->num_elements(); |
831 | return 0; |
832 | } |
833 | |
834 | private: |
835 | // For access to `AddCleanupFunction` and `Restore`. |
836 | friend class DatasetBase; |
837 | friend class DatasetBaseIterator; // for access to `node_` |
838 | |
839 | std::vector<std::function<void()>> cleanup_fns_; |
840 | std::shared_ptr<model::Node> node_ = nullptr; |
841 | const IteratorBase* parent_ = nullptr; // Not owned. |
842 | int64_t id_ = 0; |
843 | int64_t parent_id_ = 0; |
844 | }; |
845 | |
846 | // Represents runtime information needed to construct a dataset. |
847 | class DatasetContext { |
848 | public: |
849 | struct Params { |
850 | string type_string; // op type name of this dataset. |
851 | string node_name; // graph node name of this dataset op, uniquely |
852 | // identifying the dataset in the graph. |
853 | }; |
854 | |
855 | explicit DatasetContext(Params params) : params_(std::move(params)) {} |
856 | |
857 | explicit DatasetContext(OpKernelContext* ctx) { |
858 | params_.type_string = ctx->op_kernel().type_string(); |
859 | params_.node_name = ctx->op_kernel().name(); |
860 | } |
861 | |
862 | const string& type_string() const { return params_.type_string; } |
863 | const string& node_name() const { return params_.node_name; } |
864 | |
865 | private: |
866 | Params params_; |
867 | }; |
868 | |
// Returns the number of bytes allocated for the tensors of the given element.
int64_t GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the tensors of the given
// element.
int64_t GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor()
// (declared below).
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
890 | |
891 | // Represents a (potentially infinite) range of outputs, where each |
892 | // output is a tuple of tensors. |
893 | class DatasetBase : public core::RefCounted { |
894 | public: |
895 | // Key for storing the Dataset graph in the serialized format. |
896 | TF_EXPORT static const char kDatasetGraphKey[]; |
897 | |
898 | // Key for storing the output node of the Dataset graph in the serialized |
899 | // format. |
900 | TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; |
901 | |
902 | explicit DatasetBase(DatasetContext&& ctx) |
903 | : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {} |
904 | |
905 | // Op type name of this dataset. |
906 | const string& type_string() const { return type_string_; } |
907 | |
908 | // Graph node name of this dataset op, uniquely identifying the dataset in |
909 | // the graph. |
910 | const string& node_name() const { return node_name_; } |
911 | |
912 | // Initializes the dataset. |
913 | void Initialize(const Metadata& metadata); |
914 | |
915 | const Metadata& metadata() const { return metadata_; } |
916 | |
917 | const Options& options() const { return options_; } |
918 | |
919 | int64_t num_sources() const { return num_sources_; } |
920 | |
921 | // Returns a new iterator for iterating over the range of elements in |
922 | // this dataset. |
923 | // |
924 | // This method may be called multiple times on the same instance, |
925 | // and the resulting iterators will have distinct state. Each |
926 | // iterator will traverse all elements in this dataset from the |
927 | // start. |
928 | // |
929 | // The prefix identifies the sequence of iterators leading up to the newly |
930 | // created iterator. |
931 | Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent, |
932 | const string& output_prefix, |
933 | std::unique_ptr<IteratorBase>* iterator) const; |
934 | |
935 | Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent, |
936 | const string& output_prefix, |
937 | std::unique_ptr<IteratorBase>* iterator) const { |
938 | return MakeIterator(&ctx, parent, output_prefix, iterator); |
939 | } |
940 | |
941 | // Returns a new iterator restored from the checkpoint data in `reader`. |
942 | Status MakeIteratorFromCheckpoint( |
943 | IteratorContext* ctx, const string& output_prefix, |
944 | IteratorStateReader* reader, |
945 | std::unique_ptr<IteratorBase>* iterator) const { |
946 | std::unique_ptr<IteratorBase> it; |
947 | IteratorContext::Params params(ctx); |
948 | params.is_restoring = true; |
949 | IteratorContext restore_ctx(std::move(params)); |
950 | TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx, |
951 | /*parent=*/nullptr, output_prefix, &it)); |
952 | TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader)); |
953 | *iterator = std::move(it); |
954 | return OkStatus(); |
955 | } |
956 | |
957 | Status MakeIteratorFromCheckpoint( |
958 | IteratorContext&& ctx, const string& output_prefix, |
959 | IteratorStateReader* reader, |
960 | std::unique_ptr<IteratorBase>* iterator) const { |
961 | return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator); |
962 | } |
963 | |
964 | // Returns a split provider which partitions the dataset's data into splits |
965 | // and provides them in a sequence. The split provider is stored in |
966 | // `*split_provider`. |
967 | virtual Status MakeSplitProviders( |
968 | std::vector<std::unique_ptr<SplitProvider>>* split_providers) const; |
969 | |
970 | // Returns a vector of DataType values, representing the respective |
971 | // element types of each tuple component in the outputs of this |
972 | // dataset. |
973 | virtual const DataTypeVector& output_dtypes() const = 0; |
974 | |
975 | // Returns a vector of tensor shapes, representing the respective |
976 | // (and possibly partially defined) shapes of each tuple component |
977 | // in the outputs of this dataset. |
978 | virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; |
979 | |
980 | // Returns the number of bytes allocated for tensors of this dataset. |
981 | virtual int64_t AllocatedBytes() const { return 0; } |
982 | |
983 | // Returns the estimated number of bytes used for tensors of this dataset. |
984 | virtual int64_t TotalBytes() const { return 0; } |
985 | |
986 | // Returns the cardinality of this dataset. |
987 | // TODO(shilpakrish): Remove this overload once all callers are migrated |
988 | // to the API which passes in the options parameter. |
989 | ABSL_DEPRECATED("Use the overload that passes in the options parameter." ) |
990 | int64_t Cardinality() const; |
991 | |
992 | // Returns the cardinality of this dataset based on the options. |
993 | int64_t Cardinality(CardinalityOptions options) const; |
994 | |
995 | // Internal implementation of cardinality for a dataset. |
996 | // TODO(shilpakrish): Remove this overload once all callers are migrated |
997 | // to the API which passes in the options parameter. |
998 | ABSL_DEPRECATED("Use the overload that passes in the options parameter." ) |
999 | virtual int64_t CardinalityInternal() const { return kUnknownCardinality; } |
1000 | |
1001 | // Internal implementation of cardinality for a dataset based on the options. |
1002 | virtual int64_t CardinalityInternal(CardinalityOptions options) const { |
1003 | return kUnknownCardinality; |
1004 | } |
1005 | |
1006 | // A human-readable debug string for this dataset. |
1007 | virtual string DebugString() const = 0; |
1008 | |
1009 | // Stores the dataset's input datasets in `*inputs`. The pointers stored in |
1010 | // `*inputs` are borrowed. The only valid non-ok return status is |
1011 | // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset |
1012 | // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a |
1013 | // default implementation of `MakeSplitProvider` when there is a single input |
1014 | // dataset. |
1015 | virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const; |
1016 | |
1017 | // Indicates whether the dataset depends on any external state which would |
1018 | // prevent it from being serializable. If so, the method returns |
1019 | // `errors::FailedPrecondition` with a message that identifies the external |
1020 | // state. Otherwise, the method returns `Status::OK()`. |
1021 | virtual Status CheckExternalState() const = 0; |
1022 | |
1023 | // Indicates whether the dataset is compatible with random access. |
1024 | Status CheckRandomAccessCompatible(const int64 index) const; |
1025 | |
1026 | // Return the element at a particular index for a randomly accessible dataset. |
1027 | virtual Status Get(OpKernelContext* ctx, int64 index, |
1028 | std::vector<Tensor>* out_tensors) const; |
1029 | |
1030 | // Return a finalized version of the dataset. The returned DatasetBase is |
1031 | // unowned and lives for as long as this dataset. |
1032 | virtual StatusOr<DatasetBase*> Finalize( |
1033 | OpKernelContext* ctx, |
1034 | std::function<StatusOr<core::RefCountPtr<DatasetBase>>()> |
1035 | make_finalized_dataset) const; |
1036 | |
1037 | // Wrapper around a GraphDefBuilder which provides support for serializing |
1038 | // Datasets as GraphDefs. |
1039 | class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { |
1040 | public: |
1041 | explicit DatasetGraphDefBuilder(GraphDefBuilder* b) |
1042 | : GraphDefBuilderWrapper(b) {} |
1043 | Status AddInputDataset(SerializationContext* ctx, |
1044 | const DatasetBase* dataset, Node** output); |
1045 | Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val, |
1046 | Node** output); |
1047 | Status AddIdentity(SerializationContext* ctx, |
1048 | const std::string& name_prefix, Node** input, |
1049 | Node** output); |
1050 | |
1051 | private: |
1052 | Status AddDatasetOrTensorHelper(SerializationContext* ctx, |
1053 | const Tensor& val, Node** output); |
1054 | Status AddResourceHelper(SerializationContext* ctx, const Tensor& val, |
1055 | Node** output); |
1056 | }; |
1057 | |
1058 | protected: |
1059 | friend class CapturedFunction; |
1060 | |
1061 | // Serializes the dataset into a `GraphDef`, which has two uses: |
1062 | // |
1063 | // 1) To perform static input pipeline optimizations, tf.data serializes the |
1064 | // dataset graph, applies graph rewrites, and then deserializes the graph. |
1065 | // If a subclass of `DatasetBase` does not implement this method, then it will |
1066 | // be excluded from static optimizations (and so will any upstream datasets). |
1067 | // |
1068 | // 2) To save the dataset so that it can restore at a later point (possibly in |
1069 | // different environment). If a subclass of `DatasetBase` does not implement |
1070 | // this method, then this migration will not be possible. |
1071 | virtual Status AsGraphDefInternal(SerializationContext* ctx, |
1072 | DatasetGraphDefBuilder* b, |
1073 | Node** node) const = 0; |
1074 | |
1075 | virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( |
1076 | const string& prefix) const = 0; |
1077 | |
1078 | void set_options(const Options& options) { options_ = options; } |
1079 | |
1080 | private: |
1081 | // Computes and stores the cardinality of a given dataset. |
1082 | Status ComputeCardinality(); |
1083 | |
1084 | // Computes the number of source datasets feeding into this dataset. A source |
1085 | // dataset is a leaf in the subtree of dataset inputs. |
1086 | Status ComputeNumSources(); |
1087 | |
1088 | // Merges options from inputs to this dataset. If there is a conflict in a |
1089 | // field value, the options set on this dataset takes precedence over those in |
1090 | // the inputs. The order of precedence on the inputs is in the same order as |
1091 | // how they appear for this dataset. |
1092 | Status MergeOptionsFromInputs(); |
1093 | |
1094 | const string type_string_; |
1095 | const string node_name_; |
1096 | Metadata metadata_; |
1097 | Options options_; |
1098 | mutable mutex mu_; |
1099 | mutable mutex cardinality_mu_; |
1100 | mutable core::RefCountPtr<DatasetBase> finalized_dataset_; |
1101 | // The number of source datasets feeding into the dataset. A source dataset |
1102 | // is a leaf in the subtree of dataset inputs. |
1103 | int64_t num_sources_ = -1; |
1104 | mutable int64_t cardinality_ TF_GUARDED_BY(cardinality_mu_) = |
1105 | kUnknownCardinality; |
1106 | }; |
1107 | |
1108 | // Represents an iterator that is associated with a particular dataset. |
1109 | class DatasetBaseIterator : public IteratorBase { |
1110 | public: |
1111 | struct BaseParams { |
1112 | // Owns one reference on the shared dataset object. |
1113 | const DatasetBase* dataset; |
1114 | |
1115 | // Identifies the sequence of iterators leading up to this iterator. |
1116 | const string prefix; |
1117 | }; |
1118 | |
1119 | explicit DatasetBaseIterator(const BaseParams& params); |
1120 | |
1121 | ~DatasetBaseIterator() override; |
1122 | |
1123 | virtual const DatasetBase* dataset() const { return params_.dataset; } |
1124 | |
1125 | const DataTypeVector& output_dtypes() const override { |
1126 | return params_.dataset->output_dtypes(); |
1127 | } |
1128 | |
1129 | const std::vector<PartialTensorShape>& output_shapes() const override { |
1130 | return params_.dataset->output_shapes(); |
1131 | } |
1132 | |
1133 | const string& prefix() const override { return params_.prefix; } |
1134 | |
1135 | // Returns a name to be used for the TraceMe event. |
1136 | // |
1137 | // NOTE: TraceMe supports passing key-value pairs of "arguments" using the |
1138 | // following format "name#arg_1=value_,...,arg_n=value_n". |
1139 | string BuildTraceMeName(); |
1140 | |
1141 | Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, |
1142 | bool* end_of_sequence) final; |
1143 | |
1144 | Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors, |
1145 | bool* end_of_sequence) { |
1146 | return GetNext(&ctx, out_tensors, end_of_sequence); |
1147 | } |
1148 | |
1149 | Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, |
1150 | int* num_skipped) final; |
1151 | |
1152 | Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final { |
1153 | VLOG(2) << "Attempting to save checkpoints on iterator (prefix: " |
1154 | << prefix() << ") from " << dataset()->DebugString(); |
1155 | return IteratorBase::Save(ctx, writer); |
1156 | } |
1157 | |
1158 | protected: |
1159 | Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final { |
1160 | VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: " |
1161 | << prefix() << ") from " << dataset()->DebugString(); |
1162 | return IteratorBase::Restore(ctx, reader); |
1163 | } |
1164 | |
1165 | // Internal implementation of GetNext that is wrapped in tracing logic. |
1166 | // |
1167 | // See the docstring of `GetNext` method regaring the contract for |
1168 | // `out_tensors` and `end_of_sequence`. Implementations may assume that |
1169 | // `*out_tensors` is empty. |
1170 | virtual Status GetNextInternal(IteratorContext* ctx, |
1171 | std::vector<Tensor>* out_tensors, |
1172 | bool* end_of_sequence) = 0; |
1173 | |
1174 | // Internal implementation of Skip that is wrapped in tracing logic |
1175 | virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip, |
1176 | bool* end_of_sequence, int* num_skipped); |
1177 | |
1178 | string full_name(const string& name) const { |
1179 | return FullName(params_.prefix, name); |
1180 | } |
1181 | |
1182 | // Returns a map of key-value pairs to included in the TraceMe string. |
1183 | virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; } |
1184 | |
1185 | // By default we model iterators using an unknown node, which acts as |
1186 | // pass-through with respect to performance modeling. |
1187 | std::shared_ptr<model::Node> CreateNode( |
1188 | IteratorContext* ctx, model::Node::Args args) const override { |
1189 | return model::MakeUnknownNode(std::move(args)); |
1190 | } |
1191 | |
1192 | // When modeling is enabled, this method disables autotuning for the given |
1193 | // iterator (and the transitive closure of its inputs). |
1194 | void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) { |
1195 | if (iterator->node_) { |
1196 | iterator->node_->set_autotune(false); |
1197 | } |
1198 | } |
1199 | |
1200 | // When modeling is enabled, this method enables autotuning for the given |
1201 | // iterator (and the transitive closure of its inputs). |
1202 | void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) { |
1203 | if (iterator->node_) { |
1204 | iterator->node_->set_autotune(true); |
1205 | } |
1206 | } |
1207 | |
1208 | // When modeling is enabled, this method records the fact that this iterator |
1209 | // has dequeued an element from an internal buffer. |
1210 | void RecordBufferDequeue(IteratorContext* ctx, |
1211 | const std::vector<Tensor>& element) { |
1212 | if (collect_resource_usage(ctx)) { |
1213 | node_->record_buffer_event(-GetAllocatedBytes(element), -1); |
1214 | DCHECK_GE(node_->buffered_elements(), 0); |
1215 | } |
1216 | } |
1217 | |
1218 | // When modeling is enabled, this method records the fact that this iterator |
1219 | // has enqueued an element in an internal buffer. |
1220 | void RecordBufferEnqueue(IteratorContext* ctx, |
1221 | const std::vector<Tensor>& element) { |
1222 | if (collect_resource_usage(ctx)) { |
1223 | node_->record_buffer_event(GetAllocatedBytes(element), 1); |
1224 | } |
1225 | } |
1226 | |
1227 | // When modeling is enabled, this method records the fact that this iterator |
1228 | // has produced an element and its size in bytes. |
1229 | void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) { |
1230 | if (collect_resource_usage(ctx)) { |
1231 | int64_t num_bytes = GetAllocatedBytes(*out_tensors); |
1232 | node_->record_element(); |
1233 | node_->record_bytes_produced(num_bytes); |
1234 | if (node_->output()) { |
1235 | node_->output()->record_bytes_consumed(num_bytes); |
1236 | } |
1237 | } |
1238 | } |
1239 | |
1240 | // When modeling is enabled, this method records the fact that a thread of |
1241 | // this iterator has started work. |
1242 | void RecordStart(IteratorContext* ctx) { |
1243 | if (collect_resource_usage(ctx)) { |
1244 | int64_t now_nanos = EnvTime::NowNanos(); |
1245 | node_->record_start(now_nanos); |
1246 | } |
1247 | } |
1248 | |
1249 | // When modeling is enabled, this method records the fact that a thread of |
1250 | // this iterator has stopped work. |
1251 | void RecordStop(IteratorContext* ctx) { |
1252 | if (collect_resource_usage(ctx)) { |
1253 | int64_t now_nanos = EnvTime::NowNanos(); |
1254 | node_->record_stop(now_nanos); |
1255 | } |
1256 | } |
1257 | |
1258 | // Returns whether work is currently being recorded, i.e. whether we are |
1259 | // currently between a `RecordStart` and a `RecordStop`. |
1260 | bool IsRecording(IteratorContext* ctx) { |
1261 | return node_ && node_->is_recording(); |
1262 | } |
1263 | |
1264 | private: |
1265 | bool collect_resource_usage(IteratorContext* ctx) { |
1266 | return ctx->model() && node_; |
1267 | } |
1268 | |
1269 | string traceme_metadata_; |
1270 | BaseParams params_; |
1271 | }; |
1272 | |
1273 | // Represents an iterator that is associated with a particular dataset |
1274 | // with a particular type. |
1275 | template <class DatasetType> |
1276 | class DatasetIterator : public DatasetBaseIterator { |
1277 | public: |
1278 | struct Params { |
1279 | // Borrowed pointer to the dataset. |
1280 | const DatasetType* dataset; |
1281 | |
1282 | // Identifies the sequence of iterators leading up to this iterator. |
1283 | const string prefix; |
1284 | }; |
1285 | |
1286 | explicit DatasetIterator(const Params& params) |
1287 | : DatasetBaseIterator({params.dataset, params.prefix}), |
1288 | typed_dataset_(params.dataset) {} |
1289 | |
1290 | // The dataset from which this iterator was created. |
1291 | const DatasetType* dataset() const final { return typed_dataset_; } |
1292 | |
1293 | private: |
1294 | const DatasetType* const typed_dataset_; // Not owned. |
1295 | }; |
1296 | |
1297 | template <typename T> |
1298 | Status ParseScalarArgument(OpKernelContext* ctx, |
1299 | const StringPiece& argument_name, T* output) { |
1300 | const Tensor* argument_t; |
1301 | TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); |
1302 | if (!TensorShapeUtils::IsScalar(argument_t->shape())) { |
1303 | return errors::InvalidArgument(argument_name, " must be a scalar" ); |
1304 | } |
1305 | *output = argument_t->scalar<T>()(); |
1306 | return OkStatus(); |
1307 | } |
1308 | |
1309 | template <typename T> |
1310 | Status ParseVectorArgument(OpKernelContext* ctx, |
1311 | const StringPiece& argument_name, |
1312 | std::vector<T>* output) { |
1313 | const Tensor* argument_t; |
1314 | TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); |
1315 | if (!TensorShapeUtils::IsVector(argument_t->shape())) { |
1316 | return errors::InvalidArgument(argument_name, " must be a vector" ); |
1317 | } |
1318 | int size = argument_t->vec<T>().size(); |
1319 | output->reserve(size); |
1320 | for (int i = 0; i < size; ++i) { |
1321 | output->push_back(argument_t->vec<T>()(i)); |
1322 | } |
1323 | return OkStatus(); |
1324 | } |
1325 | |
1326 | // Encapsulates the work required to plug a DatasetBase into the core TensorFlow |
1327 | // graph execution engine. |
1328 | class DatasetOpKernel : public OpKernel { |
1329 | public: |
1330 | explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1331 | if (ctx->HasAttr(kMetadata)) { |
1332 | std::string serialized_metadata; |
1333 | OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata)); |
1334 | OP_REQUIRES(ctx, metadata_.ParseFromString(serialized_metadata), |
1335 | errors::InvalidArgument(absl::StrCat( |
1336 | "Could not parse the 'metadata' attribute." ))); |
1337 | } |
1338 | } |
1339 | |
1340 | void Compute(OpKernelContext* ctx) final; |
1341 | |
1342 | // Checks whether the given op is a tf.data operation. |
1343 | // |
1344 | // NOTE: The check uses a heuristic and can produce both false positives and |
1345 | // false negatives. In particular, tf.data operations are expected to use |
1346 | // names that end with "Dataset" or "DatasetV[0-9]+". |
1347 | static bool IsDatasetOp(const OpDef& op_def); |
1348 | |
1349 | string TraceString(const OpKernelContext& ctx, bool verbose) const override; |
1350 | |
1351 | protected: |
1352 | // Subclasses should implement this method. It will be called during Compute |
1353 | // execution. |
1354 | virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; |
1355 | |
1356 | private: |
1357 | Metadata metadata_; |
1358 | }; |
1359 | |
// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  // Final override of the `DatasetOpKernel` hook; subclasses implement the
  // three-argument overload below instead.
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  // Subclasses implement this overload, which additionally receives the
  // `input` dataset.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
1372 | |
// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  // Final override of the `DatasetOpKernel` hook; subclasses implement the
  // four-argument overload below instead.
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  // Subclasses implement this overload, which additionally receives the two
  // datasets `input` and `another_input`.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};
1386 | |
// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  // NOTE(review): `name` is held as a raw `const char*` (see `name_`), not
  // copied into a string, so it presumably must outlive the worker -- confirm
  // against the implementation.
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  // Hands `work_item` to the worker for asynchronous execution.
  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  // Guards `cancelled_` and `work_queue_`.
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
1420 | |
1421 | } // namespace data |
1422 | } // namespace tensorflow |
1423 | |
1424 | #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |
1425 | |