1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
16#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
17
18#include <deque>
19#include <memory>
20#include <unordered_map>
21
22#include "absl/memory/memory.h"
23#include "tensorflow/core/framework/attr_value.pb.h"
24#include "tensorflow/core/framework/attr_value_util.h"
25#include "tensorflow/core/framework/cancellation.h"
26#include "tensorflow/core/framework/collective.h"
27#include "tensorflow/core/framework/dataset_metadata.pb.h"
28#include "tensorflow/core/framework/dataset_options.pb.h"
29#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
30#include "tensorflow/core/framework/function.h"
31#include "tensorflow/core/framework/function_handle_cache.h"
32#include "tensorflow/core/framework/graph.pb.h"
33#include "tensorflow/core/framework/model.h"
34#include "tensorflow/core/framework/node_def.pb.h"
35#include "tensorflow/core/framework/op_kernel.h"
36#include "tensorflow/core/framework/register_types.h"
37#include "tensorflow/core/framework/thread_factory.h"
38#include "tensorflow/core/framework/types.pb.h"
39#include "tensorflow/core/framework/variant_encode_decode.h"
40#include "tensorflow/core/framework/variant_tensor_data.h"
41#include "tensorflow/core/lib/core/errors.h"
42#include "tensorflow/core/lib/core/threadpool.h"
43#include "tensorflow/core/lib/core/threadpool_interface.h"
44#include "tensorflow/core/lib/strings/str_util.h"
45#include "tensorflow/core/lib/strings/strcat.h"
46#include "tensorflow/core/platform/cpu_info.h"
47#include "tensorflow/core/platform/env.h"
48#include "tensorflow/core/platform/refcount.h"
49#include "tensorflow/core/platform/tracing.h"
50
// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
//
// Expands to `TF_CALL_ALL_TYPES(m)` immediately followed by
// `TF_CALL_QUANTIZED_TYPES(m)`.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
55
56namespace tensorflow {
57
58// Forward declarations to avoid introducing a dependency on headers in
59// "tensorflow/core/graph/...".
60class GraphDefBuilder;
61class Node;
62
63namespace data {
64
namespace internal {
// Merges Options from source to destination. If there is a conflict on a field,
// the field value from the source takes precedence.
//
// Two overloads are declared: one for full `protobuf::Message` objects and one
// for `protobuf::MessageLite` objects. NOTE(review): presumably the Lite
// overload serves builds that only link the lite protobuf runtime -- confirm
// in the implementation file.
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination);
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination);
}  // namespace internal
73
// List of (key, value) string pairs. NOTE(review): presumably attached as
// metadata to TraceMe profiler events -- confirm at the use sites.
using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;

// Name of the boolean function attribute checked by `IsTFDataFunction` below.
constexpr char kTFDataFunction[] = "_tf_data_function";

// Sentinel cardinality values. NOTE(review): presumably reported by datasets
// whose element count is infinite or cannot be determined -- confirm against
// the dataset cardinality API.
constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
// Single-character separators. NOTE(review): presumably used when composing
// checkpoint keys (see `FullName`) -- confirm in the implementation file.
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";
constexpr char kTraceInfoUnavailable[] = "unavailable";
constexpr char kMetadata[] = "metadata";

// Name of the unregistered attribute under which cardinality is serialized
// when serializing for graph rewrites (see
// `SerializationContext::Params::is_graph_rewrite`).
constexpr char kCardinalityAttrForRewrite[] = "_cardinality";
92
93class DatasetBase;
94class SerializationContext;
95
96inline bool IsTFDataFunction(const FunctionDef& func) {
97 auto iter = func.attr().find(data::kTFDataFunction);
98 return (iter != func.attr().end() && iter->second.b());
99}
100
// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// Please see comment on IteratorStateWriter for guidance around using the
// Read*(key, val) vs Read*(name, key, val).
//
// Each accessor has a form that takes only `key` and a form that additionally
// takes an explicit `name`. NOTE(review): presumably `name` must match the
// `name` given to the corresponding IteratorStateWriter::Write*(name, key, val)
// call -- confirm against the implementations.
class IteratorStateReader {
 public:
  // Determines whether the iterator state contains the given key.
  virtual bool Contains(StringPiece key) const = 0;
  virtual bool Contains(StringPiece name, StringPiece key) const = 0;

  // Reads an integer for the given key into `*val`.
  virtual Status ReadScalar(StringPiece key, int64_t* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            int64_t* val) const = 0;

  // Reads a string for the given key into `*val`.
  virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) const = 0;

  // Reads a tensor for the given key into `*val`.
  // The overloads taking a `FunctionLibraryRuntime* flr` are the ones that
  // remain once the TODO below is done.
  // TODO(jsimsa): Remove non-FLR overrides once all callers are updated.
  virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
                            StringPiece key, Tensor* val) const = 0;

  virtual ~IteratorStateReader() {}
};
133
// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs the key is expected to encode this
// name as keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB and so if the state for an iterator
// might exceed the 2 GB limit, you can pass an explicit name in via the
// Write*(name, key, val) APIs allowing you to further split up the state
// into more manageable chunks.
class IteratorStateWriter {
 public:
  // Writes an integer for the given key.
  virtual Status WriteScalar(StringPiece key, const int64_t val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64_t val) = 0;

  // Writes a string for the given key.
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;

  // Writes a tensor for the given key.
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
162
// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
//
// NOTE(review): the definition is not in this header; presumably it joins
// `prefix` and `name` using the `kFullNameRandomHex` / separator constants
// declared above -- confirm in the implementation file.
std::string FullName(const std::string& prefix, const std::string& name);
166
167// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
168class GraphDefBuilderWrapper {
169 public:
170 explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
171
172 // Adds a Const node with scalar value to the Graph.
173 // `*output` contains a pointer to the output `Node`. It is guaranteed to be
174 // non-null if the method returns with an OK status.
175 // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
176 template <typename T>
177 Status AddScalar(const T& val, Node** output) {
178 Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
179 val_t.scalar<T>()() = val;
180 AddTensorInternal(val_t, output);
181 if (*output == nullptr) {
182 return errors::Internal("AddScalar: Failed to build Const op.");
183 }
184 return OkStatus();
185 }
186
187 // Adds a Const node with vector value to the Graph.
188 // `*output` contains a pointer to the output `Node`. It is guaranteed to be
189 // non-null if the method returns with an OK status.
190 // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
191 // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
192 template <typename T>
193 Status AddVector(const std::vector<T>& val, Node** output) {
194 Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
195 TensorShape({static_cast<int64_t>(val.size())}));
196 for (size_t i = 0; i < val.size(); i++) {
197 val_t.flat<T>()(i) = val[i];
198 }
199 AddTensorInternal(val_t, output);
200 if (*output == nullptr) {
201 return errors::Internal("AddVector: Failed to build Const op.");
202 }
203 return OkStatus();
204 }
205
206 Status AddVector(const std::vector<string>& val, Node** output) {
207 Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
208 TensorShape({static_cast<int64_t>(val.size())}));
209 for (size_t i = 0; i < val.size(); i++) {
210 val_t.flat<tstring>()(i) = val[i];
211 }
212 AddTensorInternal(val_t, output);
213 if (*output == nullptr) {
214 return errors::Internal("AddVector: Failed to build Const op.");
215 }
216 return OkStatus();
217 }
218
219 // Adds a `Const` node for the given tensor value to the graph.
220 //
221 // `*output` contains a pointer to the output `Node`. It is guaranteed to be
222 // non-null if the method returns with an OK status. The returned `Node`
223 // pointer is owned by the backing graph of `GraphDefBuilder`.
224 Status AddTensor(const Tensor& val, Node** output) {
225 AddTensorInternal(val, output);
226 if (*output == nullptr) {
227 return errors::Internal("AddTensor: Failed to build Const op.");
228 }
229 return OkStatus();
230 }
231
232 // Adds a `Placeholder` node for the given tensor value to the graph.
233 //
234 // `*output` contains a pointer to the output `Node`. It is guaranteed to be
235 // non-null if the method returns with an OK status. The returned `Node`
236 // pointer is owned by the backing graph of `GraphDefBuilder`.
237 Status AddPlaceholder(const Tensor& val, Node** output) {
238 AddPlaceholderInternal(val, output);
239 if (*output == nullptr) {
240 return errors::Internal(
241 "AddPlaceholder: Failed to build Placeholder op.");
242 }
243 return OkStatus();
244 }
245
246 // Adds a node for the given dataset to the `Graph`. The value of
247 // `DatasetBase::type_string()` is used as the op type for the node. Values
248 // for the `output_types` and `output_shapes` node attributes are also written
249 // if those attributes are defined in the `OpDef`.
250 //
251 // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is
252 // used as the op name for the node. This argument should only be set when
253 // serializing `DatasetBase` instances which might not have been created
254 // through op kernel execution to make sure the dataset op name is preserved
255 // across serialization boundaries, which is in turn needed to make sure
256 // iterator checkpoints are valid across serialization boundaries. When
257 // `use_dataset_name` is set, the caller is responsible for making sure that
258 // the op name is unique across the graph.
259 //
260 // `*output` contains a pointer to the output `Node`. It is guaranteed to be
261 // non-null if the method returns with an OK status. The returned `Node`
262 // pointer is owned by the backing `Graph` of `GraphDefBuilder`.
263 Status AddDataset(const DatasetBase* dataset,
264 const std::vector<Node*>& inputs, Node** output);
265 Status AddDataset(const DatasetBase* dataset,
266 const std::vector<Node*>& inputs,
267 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
268 Node** output);
269 Status AddDataset(
270 const DatasetBase* dataset,
271 const std::vector<std::pair<size_t, Node*>>& inputs,
272 const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
273 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
274 Node** output);
275 Status AddDataset(
276 const DatasetBase* dataset,
277 const std::vector<std::pair<size_t, Node*>>& inputs,
278 const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
279 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
280 bool use_dataset_name, Node** output);
281
282 // Adds a user-defined function with name `function_name` to the graph and
283 // recursively adds all functions it references. If a function with a matching
284 // name has already been added, returns with OK status. If a user-defined with
285 // name `function_name` is not found in the context's function library,
286 // returns an InvalidArgumentError. If the function with name `function_name`
287 // or any of its dependent functions are stateful, and the context does not
288 // explicitly permit stateful functions, returns an InvalidArgument error.
289 Status AddFunction(SerializationContext* ctx, const string& function_name,
290 const FunctionLibraryDefinition& lib_def);
291
292 template <typename T>
293 void BuildAttrValue(const T& value, AttrValue* attr) {
294 SetAttrValue(value, attr);
295 }
296
297 template <typename T>
298 AttrValue BuildAttrValue(const T& value) {
299 AttrValue attr;
300 SetAttrValue(value, &attr);
301 return attr;
302 }
303
304 protected:
305 GraphDefBuilder* builder() { return b_; }
306
307 private:
308 void AddPlaceholderInternal(const Tensor& val, Node** output);
309 void AddTensorInternal(const Tensor& val, Node** output);
310 bool HasAttr(const string& op_type_name, const string& attr_name) const;
311
312 bool HasAttr(const OpDef* op_def, const string& attr_name) const {
313 for (const auto& attr : op_def->attr()) {
314 if (attr.name() == attr_name) {
315 return true;
316 }
317 }
318 return false;
319 }
320
321 Status AddAttrFunctions(SerializationContext* ctx,
322 const AttrValue& attr_value,
323 const FunctionLibraryDefinition& lib_def) {
324 if (attr_value.has_func()) {
325 TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
326 } else if (attr_value.has_list()) {
327 for (const NameAttrList& name_attr_list : attr_value.list().func()) {
328 TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
329 }
330 }
331 return OkStatus();
332 }
333
334 GraphDefBuilder* b_;
335};
336
337class StatsAggregator;
338
// Abstract hook through which tf.data runs functions. Its purpose is to make
// sure a `tensorflow::data` symbol is always on the stack while the function
// executes.
class Runner {
 public:
  virtual ~Runner() = default;

  // Executes `f`; how it is executed is up to the implementation.
  virtual void Run(const std::function<void()>& f) = 0;

  // Accessor for the process-wide singleton instance.
  static Runner* get();
};
351
// A class which provides a sequence of splits. Splits represent subdivisions of
// a dataset, e.g. filenames or ranges within files. We use splitting to
// partition input data into smaller pieces for distributed processing (see
// go/tf-data-splitting-design).
//
// Datasets provide a `MakeSplitProvider` method to expose a listing of their
// splits.
//
// Iterators created with a split provider will only iterate over the splits
// provided by the split provider.
class SplitProvider {
 public:
  virtual ~SplitProvider() {}
  // Stores the next split in `*split`, setting `*end_of_splits` to indicate
  // whether there were any splits left.
  virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
  // Resets the split provider to its beginning.
  virtual Status Reset() = 0;
  // Saves the state of this split provider.
  // NOTE(review): `full_name` presumably maps a short key to a full checkpoint
  // key (compare the free function `FullName`) -- confirm with implementations.
  virtual Status Save(std::function<std::string(std::string)> full_name,
                      IteratorStateWriter* writer) = 0;
  // Restores the state of this split provider.
  // `full_name` is expected to map keys the same way as the one given to
  // `Save`.
  virtual Status Restore(std::function<std::string(std::string)> full_name,
                         IteratorStateReader* reader) = 0;
};
377
// Returns the runner threadpool size from an OpKernelContext.
//
// Used by `IteratorContext::Params(OpKernelContext*)` to populate
// `Params::runner_threadpool_size`.
int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);
380
381// A cut-down version of `OpKernelContext` for running computations in
382// iterators. Note that we cannot simply use `OpKernelContext` here because we
383// might run computation in an iterator whose lifetime is not nested within the
384// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
385//
386// TODO(mrry): We're making some daring assumptions about the lifetime of the
387// runner passed in here. A runner will be deleted when the original step ends,
388// but all existing runners only close over session-lifetime (or longer-lived)
389// state, so we can make a copy of the function. There's nothing in the
390// definition of the API from which we took the runner to guarantee that what we
391// are doing is safe. We should formalize the properties here.
392class IteratorContext {
393 public:
394 struct Params {
395 explicit Params(IteratorContext* ctx)
396 : allocator_getter(ctx->allocator_getter()),
397 cancellation_manager(ctx->cancellation_manager()),
398 collective_executor(ctx->collective_executor()),
399 env(ctx->env()),
400 flr(ctx->flr()),
401 function_handle_cache(ctx->function_handle_cache()),
402 interleave_depth(ctx->interleave_depth()),
403 is_restoring(ctx->is_restoring()),
404 model(ctx->model()),
405 options(ctx->options()),
406 resource_mgr(ctx->resource_mgr()),
407 runner(*(ctx->runner())),
408 runner_threadpool_size(ctx->runner_threadpool_size()),
409 split_providers(ctx->split_providers()),
410 stats_aggregator(ctx->stats_aggregator()),
411 thread_factory(ctx->thread_factory()),
412 thread_pool(ctx->thread_pool()) {}
413
414 explicit Params(OpKernelContext* ctx)
415 : collective_executor(ctx->collective_executor()),
416 env(ctx->env()),
417 flr(ctx->function_library()) {
418 // NOTE: need reinterpret_cast because function.h forward-declares Device.
419 DeviceBase* device =
420 reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
421 allocator_getter = [device](AllocatorAttributes attrs) {
422 return device->GetAllocator(attrs);
423 };
424
425 runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
426
427 // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
428 // that a symbol in the tensorflow::data namespace is always on the stack
429 // when executing a function inside a Dataset.
430 runner = std::bind(
431 [](
432 // Note: `runner` is a const reference to avoid copying it.
433 const std::function<void(std::function<void()>)>& ctx_runner,
434 std::function<void()> fn) {
435 std::function<void()> wrapped_fn = std::bind(
436 [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
437 std::move(fn));
438 ctx_runner(std::move(wrapped_fn));
439 },
440 *ctx->runner(), std::placeholders::_1);
441 }
442
443 // The Allocator to be used to allocate the output of an iterator.
444 std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
445
446 // The CancellationManager to be used to cancel execution of ops.
447 CancellationManager* cancellation_manager = nullptr;
448
449 // Collective support.
450 CollectiveExecutor* collective_executor = nullptr;
451
452 // Interface to operating system functionality.
453 Env* env = nullptr;
454
455 // The FunctionLibraryRuntime object to be used to make function calls.
456 FunctionLibraryRuntime* flr = nullptr;
457
458 // A FunctionHandleCache that owns all the function handles. Not owned.
459 FunctionHandleCache* function_handle_cache = nullptr;
460
461 // Records the number of ParallelInterleave operations in the path from the
462 // root node to this node (not including this node) in the input pipeline
463 // tree.
464 int64 interleave_depth = 0;
465
466 // Marks whether the iterator is restored from a checkpoint.
467 bool is_restoring = false;
468
469 // If non-null, identifies the object used for performance modeling.
470 std::shared_ptr<model::Model> model = nullptr;
471
472 // The input pipeline options.
473 const Options* options = nullptr;
474
475 // A resource manager for storing dataset-related state, e.g. random
476 // seeds or cached tensors. Not owned.
477 ResourceMgr* resource_mgr = nullptr;
478
479 // Function call support.
480 std::function<void(std::function<void()>)> runner = nullptr;
481
482 // Number of threads used for executing user-defined functions.
483 int32 runner_threadpool_size = 0;
484
485 // Split providers indicating which splits to process. May be empty,
486 // indicating that the iterator should process all splits.
487 std::vector<std::shared_ptr<SplitProvider>> split_providers;
488
489 // The `StatsAggregator` object to record statistics about the iterator.
490 //
491 // TODO(b/147325552): Remove this API and any of its uses after we switch to
492 // using C++ based implementation for tf.data options (on 4/12/2021).
493 std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
494
495 // A factory for creating threads to perform blocking work.
496 std::shared_ptr<ThreadFactory> thread_factory = nullptr;
497
498 // A shared thread pool to schedule computation into.
499 thread::ThreadPoolInterface* thread_pool = nullptr;
500 };
501
502 explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
503
504 explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}
505
506 explicit IteratorContext(Params params) : params_(std::move(params)) {}
507
508 Allocator* allocator(AllocatorAttributes attrs) {
509 return params_.allocator_getter(attrs);
510 }
511
512 std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
513 return params_.allocator_getter;
514 }
515
516 CancellationManager* cancellation_manager() {
517 return params_.cancellation_manager;
518 }
519
520 CollectiveExecutor* collective_executor() {
521 return params_.collective_executor;
522 }
523
524 Env* env() const { return params_.env; }
525
526 FunctionLibraryRuntime* flr() { return params_.flr; }
527
528 FunctionHandleCache* function_handle_cache() {
529 return params_.function_handle_cache;
530 }
531
532 int64 interleave_depth() { return params_.interleave_depth; }
533
534 bool is_restoring() { return params_.is_restoring; }
535
536 const std::shared_ptr<model::Model>& model() { return params_.model; }
537
538 const Options* options() { return params_.options; }
539
540 ResourceMgr* resource_mgr() { return params_.resource_mgr; }
541
542 std::function<void(std::function<void()>)>* runner() {
543 return &params_.runner;
544 }
545
546 int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
547
548 std::vector<std::shared_ptr<SplitProvider>> split_providers() {
549 return params_.split_providers;
550 }
551
552 std::shared_ptr<StatsAggregator> stats_aggregator() {
553 return params_.stats_aggregator;
554 }
555
556 const std::shared_ptr<ThreadFactory>& thread_factory() {
557 return params_.thread_factory;
558 }
559
560 thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
561
562 std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
563 int num_threads) {
564 if (params_.thread_pool) {
565 // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
566 // is an instance of `thread::ThreadPoolInterface`). Notably, the
567 // ownership of `params_.thread_pool` is *not* transferred onto the newly
568 // created `ThreadPool` instance.
569 return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
570 } else {
571 return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
572 name, num_threads,
573 /*low_latency_hint=*/false);
574 }
575 }
576
577 std::unique_ptr<Thread> StartThread(const string& name,
578 std::function<void()> fn) {
579 if (params_.thread_factory) {
580 return params_.thread_factory->StartThread(name, std::move(fn));
581 } else {
582 return absl::WrapUnique(
583 Env::Default()->StartThread({}, name, std::move(fn)));
584 }
585 }
586
587 private:
588 Params params_;
589};
590
591// Aggregates runtime support needed for dataset and iterator serialization.
592class SerializationContext {
593 public:
594 // Enum describing what to do during serialization when external state is
595 // encountered.
596 enum class ExternalStatePolicy : int64 {
597 // Proceed with serialization, but log a warning about what state will be
598 // lost.
599 kWarn = 0,
600 // Proceed with serialization without logging any warning.
601 kIgnore = 1,
602 // Fail the serialization with an error.
603 kFail = 2,
604 };
605
606 // Handles the CheckExternalState status according to the external state
607 // policy.
608 Status HandleCheckExternalStateStatus(Status s) {
609 if (s.ok()) {
610 return s;
611 }
612 switch (params_.external_state_policy) {
613 case ExternalStatePolicy::kWarn:
614 LOG(WARNING) << s.ToString();
615 return OkStatus();
616 case ExternalStatePolicy::kIgnore:
617 VLOG(2) << "Ignoring error status: " << s.ToString();
618 return OkStatus();
619 case ExternalStatePolicy::kFail:
620 return s;
621 }
622 LOG(FATAL) << "Control should never reach here";
623 }
624
625 struct Params {
626 explicit Params() {}
627
628 explicit Params(OpKernelContext* ctx)
629 : resource_mgr(ctx->resource_manager()),
630 device_name(ctx->device()->attributes().name()) {}
631
632 std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
633
634 // Indicates what to do if the dataset depends on external state.
635 ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;
636
637 // Indicates whether the serialization is for rewrites.
638 //
639 // If true:
640 // * A dataset that doesn't implement serialization is replaced with a
641 // placeholder returned in `input_list`.
642 // * Data tensors are replaced with a placeholder returned in
643 // `input_list`.
644 // * Datasets that use random seeds should not serialize the random seeds.
645 // This doesn't affect datasets that use fixed seeds; fixed seeds will
646 // always be preserved.
647 // * Cardinality is serialized as an unregistered attribute
648 // `_cardinality`.
649 // If false:
650 // * A dataset that doesn't implement serialization should result in an
651 // error.
652 // * Data tensors (potentially large) should be serialized.
653 // * Datasets that use random seeds should serialize the random seeds.
654 bool is_graph_rewrite = false;
655
656 // A resource manager for looking up resources during serialization.
657 ResourceMgr* resource_mgr;
658
659 // The name of the device doing the serialization.
660 std::string device_name;
661 };
662
663 explicit SerializationContext(Params params) : params_(params) {}
664
665 std::vector<std::pair<string, Tensor>>* input_list() {
666 return params_.input_list;
667 }
668
669 ExternalStatePolicy external_state_policy() const {
670 return params_.external_state_policy;
671 }
672
673 bool is_graph_rewrite() const { return params_.is_graph_rewrite; }
674
675 const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }
676
677 const std::string& device_name() const { return params_.device_name; }
678
679 private:
680 Params params_;
681
682 TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
683};
684
685// Represents the current position in a range of outputs, where the
686// range of outputs is typically represented by an `DatasetBase`,
687// defined below.
688class IteratorBase {
689 public:
690 virtual ~IteratorBase() {
691 for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
692 (*rit)();
693 }
694 }
695
696 // Gets the next output from the range that this iterator is traversing.
697 //
698 // If at least one output remains in this iterator's range, that
699 // output will be stored in `*out_tensors` and `false` will be
700 // stored in `*end_of_sequence`.
701 //
702 // If no more outputs remain in this iterator's range, `true` will be stored
703 // in `*end_of_sequence`, and `*out_tensors` will be empty.
704 //
705 // Implementations should never return `OutOfRange` error. If at end of
706 // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
707 // Internally raised `OutOfRange` errors that do not imply end of sequence
708 // should be converted to a different error type before being propagated to
709 // the caller.
710 //
711 // Implementations must explicitly set `*end_of_sequence = false` if an
712 // `Status::OK()` status is returned and the iterator is not at the end of the
713 // sequence.
714 //
715 // This method is thread-safe.
716 //
717 // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
718 // potentially remove this method.
719 virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
720 bool* end_of_sequence) = 0;
721
722 Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
723 bool* end_of_sequence) {
724 return GetNext(&ctx, out_tensors, end_of_sequence);
725 }
726
727 // Skips the next `num_to_skip` outputs from the range that this iterator
728 // is traversing.
729 //
730 // If there are not enough outputs to skip, it will set
731 // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
732 // store the number of outputs that are skipped. When `*end_of_sequence` is
733 // `false`, `*num_skipped` should equal to `num_to_skip`.
734 virtual Status Skip(IteratorContext* ctx, int num_to_skip,
735 bool* end_of_sequence, int* num_skipped) = 0;
736
737 virtual Status Skip(IteratorContext&& ctx, int num_to_skip,
738 bool* end_of_sequence, int* num_skipped) {
739 return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped);
740 }
741
742 // Returns a vector of DataType values, representing the respective
743 // element types of each tuple component in the outputs of this
744 // iterator.
745 virtual const DataTypeVector& output_dtypes() const = 0;
746
747 // Returns a vector of tensor shapes, representing the respective
748 // (and possibly partially defined) shapes of each tuple component
749 // in the outputs of this iterator.
750 virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
751
752 // Returns a string that identifies the sequence of iterators leading up to
753 // this iterator.
754 virtual const string& prefix() const = 0;
755
756 // Performs initialization that needs to happen outside of a constructor to
757 // properly propagate errors.
758 virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); }
759
760 // Performs initialization of the base iterator.
761 Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
762
763 // Saves the state of this iterator.
764 virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
765 int64_t start_us = EnvTime::NowMicros();
766 TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
767 VLOG(1) << "Saved " << prefix() << " in "
768 << (EnvTime::NowMicros() - start_us) << "us";
769 return OkStatus();
770 }
771
772 protected:
773 // Returns a node that models this iterator.
774 virtual std::shared_ptr<model::Node> CreateNode(
775 IteratorContext* ctx, model::Node::Args args) const = 0;
776
777 // Restores the state of this iterator.
778 virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
779 int64_t start_us = EnvTime::NowMicros();
780 TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
781 VLOG(1) << "Restored " << prefix() << " in "
782 << (EnvTime::NowMicros() - start_us) << "us";
783 return OkStatus();
784 }
785
786 // This is needed so that sub-classes of IteratorBase can call
787 // `SaveInternal` on their input iterators.
788 Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
789 const std::unique_ptr<IteratorBase>& input) {
790 return input->Save(ctx, writer);
791 }
792
793 // This is needed so that sub-classes of IteratorBase can call
794 // `RestoreInternal` on their input iterators.
795 Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
796 const std::unique_ptr<IteratorBase>& input) {
797 return input->Restore(ctx, reader);
798 }
799
800 Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
801 const std::unique_ptr<IteratorBase>& input) {
802 return RestoreInput(&ctx, reader, input);
803 }
804
805 // Saves the state of this iterator.
806 //
807 // This method is used to store the state of the iterator in a checkpoint.
808 // implementations have an override.
809 virtual Status SaveInternal(SerializationContext* ctx,
810 IteratorStateWriter* writer) = 0;
811
812 // Restores the state of this iterator.
813 //
814 // This method is used to restore the state of the iterator from a checkpoint.
815 //
816 // Implementations may assume that the iterator is in a clean state. That is,
817 // its `Initialize` method has been called, but its `GetNext` method has
818 // never been called.
819 // implementations have an override.
820 virtual Status RestoreInternal(IteratorContext* ctx,
821 IteratorStateReader* reader) = 0;
822
823 // Returns a pointer to the node representing this iterator in the performance
824 // model. It may be null, if performance modeling is not enabled for this
825 // iterator.
826 std::shared_ptr<model::Node> model_node() const { return node_; }
827
828 // Returns the number of elements produced by this iterator.
829 int64_t num_elements() const {
830 if (node_) return node_->num_elements();
831 return 0;
832 }
833
834 private:
835 // For access to `AddCleanupFunction` and `Restore`.
836 friend class DatasetBase;
837 friend class DatasetBaseIterator; // for access to `node_`
838
839 std::vector<std::function<void()>> cleanup_fns_;
840 std::shared_ptr<model::Node> node_ = nullptr;
841 const IteratorBase* parent_ = nullptr; // Not owned.
842 int64_t id_ = 0;
843 int64_t parent_id_ = 0;
844};
845
846// Represents runtime information needed to construct a dataset.
847class DatasetContext {
848 public:
849 struct Params {
850 string type_string; // op type name of this dataset.
851 string node_name; // graph node name of this dataset op, uniquely
852 // identifying the dataset in the graph.
853 };
854
855 explicit DatasetContext(Params params) : params_(std::move(params)) {}
856
857 explicit DatasetContext(OpKernelContext* ctx) {
858 params_.type_string = ctx->op_kernel().type_string();
859 params_.node_name = ctx->op_kernel().name();
860 }
861
862 const string& type_string() const { return params_.type_string; }
863 const string& node_name() const { return params_.node_name; }
864
865 private:
866 Params params_;
867};
868
869// Returns the number of bytes allocated for the given tensor.
870int64_t GetAllocatedBytes(const std::vector<Tensor>& element);
871
872// Returns the estimated memory usage in bytes of the given tensor.
873int64_t GetTotalBytes(const std::vector<Tensor>& element);
874
875// Validates and extracts a `DatasetBase` object from `tensor`.
876//
877// `tensor` must have been written by a call to SetVariantTensorToDataset().
878//
879// The retrieved pointer is a borrowed reference to the dataset, which is owned
880// by the tensor. The consumer must either acquire its own reference to the
881// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
882// destroyed or mutated while the retrieved pointer is in use.
883Status GetDatasetFromVariantTensor(const Tensor& tensor,
884 DatasetBase** out_dataset);
885
886// Stores a `DatasetBase` object in `tensor`.
887//
888// The ownership of `dataset` is transferred to `tensor`.
889Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
890
891// Represents a (potentially infinite) range of outputs, where each
892// output is a tuple of tensors.
893class DatasetBase : public core::RefCounted {
894 public:
895 // Key for storing the Dataset graph in the serialized format.
896 TF_EXPORT static const char kDatasetGraphKey[];
897
898 // Key for storing the output node of the Dataset graph in the serialized
899 // format.
900 TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
901
902 explicit DatasetBase(DatasetContext&& ctx)
903 : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}
904
905 // Op type name of this dataset.
906 const string& type_string() const { return type_string_; }
907
908 // Graph node name of this dataset op, uniquely identifying the dataset in
909 // the graph.
910 const string& node_name() const { return node_name_; }
911
912 // Initializes the dataset.
913 void Initialize(const Metadata& metadata);
914
915 const Metadata& metadata() const { return metadata_; }
916
917 const Options& options() const { return options_; }
918
919 int64_t num_sources() const { return num_sources_; }
920
921 // Returns a new iterator for iterating over the range of elements in
922 // this dataset.
923 //
924 // This method may be called multiple times on the same instance,
925 // and the resulting iterators will have distinct state. Each
926 // iterator will traverse all elements in this dataset from the
927 // start.
928 //
929 // The prefix identifies the sequence of iterators leading up to the newly
930 // created iterator.
931 Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
932 const string& output_prefix,
933 std::unique_ptr<IteratorBase>* iterator) const;
934
935 Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
936 const string& output_prefix,
937 std::unique_ptr<IteratorBase>* iterator) const {
938 return MakeIterator(&ctx, parent, output_prefix, iterator);
939 }
940
941 // Returns a new iterator restored from the checkpoint data in `reader`.
942 Status MakeIteratorFromCheckpoint(
943 IteratorContext* ctx, const string& output_prefix,
944 IteratorStateReader* reader,
945 std::unique_ptr<IteratorBase>* iterator) const {
946 std::unique_ptr<IteratorBase> it;
947 IteratorContext::Params params(ctx);
948 params.is_restoring = true;
949 IteratorContext restore_ctx(std::move(params));
950 TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx,
951 /*parent=*/nullptr, output_prefix, &it));
952 TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader));
953 *iterator = std::move(it);
954 return OkStatus();
955 }
956
957 Status MakeIteratorFromCheckpoint(
958 IteratorContext&& ctx, const string& output_prefix,
959 IteratorStateReader* reader,
960 std::unique_ptr<IteratorBase>* iterator) const {
961 return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
962 }
963
964 // Returns a split provider which partitions the dataset's data into splits
965 // and provides them in a sequence. The split provider is stored in
966 // `*split_provider`.
967 virtual Status MakeSplitProviders(
968 std::vector<std::unique_ptr<SplitProvider>>* split_providers) const;
969
970 // Returns a vector of DataType values, representing the respective
971 // element types of each tuple component in the outputs of this
972 // dataset.
973 virtual const DataTypeVector& output_dtypes() const = 0;
974
975 // Returns a vector of tensor shapes, representing the respective
976 // (and possibly partially defined) shapes of each tuple component
977 // in the outputs of this dataset.
978 virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
979
980 // Returns the number of bytes allocated for tensors of this dataset.
981 virtual int64_t AllocatedBytes() const { return 0; }
982
983 // Returns the estimated number of bytes used for tensors of this dataset.
984 virtual int64_t TotalBytes() const { return 0; }
985
986 // Returns the cardinality of this dataset.
987 // TODO(shilpakrish): Remove this overload once all callers are migrated
988 // to the API which passes in the options parameter.
989 ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
990 int64_t Cardinality() const;
991
992 // Returns the cardinality of this dataset based on the options.
993 int64_t Cardinality(CardinalityOptions options) const;
994
995 // Internal implementation of cardinality for a dataset.
996 // TODO(shilpakrish): Remove this overload once all callers are migrated
997 // to the API which passes in the options parameter.
998 ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
999 virtual int64_t CardinalityInternal() const { return kUnknownCardinality; }
1000
1001 // Internal implementation of cardinality for a dataset based on the options.
1002 virtual int64_t CardinalityInternal(CardinalityOptions options) const {
1003 return kUnknownCardinality;
1004 }
1005
1006 // A human-readable debug string for this dataset.
1007 virtual string DebugString() const = 0;
1008
1009 // Stores the dataset's input datasets in `*inputs`. The pointers stored in
1010 // `*inputs` are borrowed. The only valid non-ok return status is
1011 // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
1012 // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
1013 // default implementation of `MakeSplitProvider` when there is a single input
1014 // dataset.
1015 virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;
1016
1017 // Indicates whether the dataset depends on any external state which would
1018 // prevent it from being serializable. If so, the method returns
1019 // `errors::FailedPrecondition` with a message that identifies the external
1020 // state. Otherwise, the method returns `Status::OK()`.
1021 virtual Status CheckExternalState() const = 0;
1022
1023 // Indicates whether the dataset is compatible with random access.
1024 Status CheckRandomAccessCompatible(const int64 index) const;
1025
1026 // Return the element at a particular index for a randomly accessible dataset.
1027 virtual Status Get(OpKernelContext* ctx, int64 index,
1028 std::vector<Tensor>* out_tensors) const;
1029
1030 // Return a finalized version of the dataset. The returned DatasetBase is
1031 // unowned and lives for as long as this dataset.
1032 virtual StatusOr<DatasetBase*> Finalize(
1033 OpKernelContext* ctx,
1034 std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
1035 make_finalized_dataset) const;
1036
1037 // Wrapper around a GraphDefBuilder which provides support for serializing
1038 // Datasets as GraphDefs.
1039 class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
1040 public:
1041 explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
1042 : GraphDefBuilderWrapper(b) {}
1043 Status AddInputDataset(SerializationContext* ctx,
1044 const DatasetBase* dataset, Node** output);
1045 Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
1046 Node** output);
1047 Status AddIdentity(SerializationContext* ctx,
1048 const std::string& name_prefix, Node** input,
1049 Node** output);
1050
1051 private:
1052 Status AddDatasetOrTensorHelper(SerializationContext* ctx,
1053 const Tensor& val, Node** output);
1054 Status AddResourceHelper(SerializationContext* ctx, const Tensor& val,
1055 Node** output);
1056 };
1057
1058 protected:
1059 friend class CapturedFunction;
1060
1061 // Serializes the dataset into a `GraphDef`, which has two uses:
1062 //
1063 // 1) To perform static input pipeline optimizations, tf.data serializes the
1064 // dataset graph, applies graph rewrites, and then deserializes the graph.
1065 // If a subclass of `DatasetBase` does not implement this method, then it will
1066 // be excluded from static optimizations (and so will any upstream datasets).
1067 //
1068 // 2) To save the dataset so that it can restore at a later point (possibly in
1069 // different environment). If a subclass of `DatasetBase` does not implement
1070 // this method, then this migration will not be possible.
1071 virtual Status AsGraphDefInternal(SerializationContext* ctx,
1072 DatasetGraphDefBuilder* b,
1073 Node** node) const = 0;
1074
1075 virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
1076 const string& prefix) const = 0;
1077
1078 void set_options(const Options& options) { options_ = options; }
1079
1080 private:
1081 // Computes and stores the cardinality of a given dataset.
1082 Status ComputeCardinality();
1083
1084 // Computes the number of source datasets feeding into this dataset. A source
1085 // dataset is a leaf in the subtree of dataset inputs.
1086 Status ComputeNumSources();
1087
1088 // Merges options from inputs to this dataset. If there is a conflict in a
1089 // field value, the options set on this dataset takes precedence over those in
1090 // the inputs. The order of precedence on the inputs is in the same order as
1091 // how they appear for this dataset.
1092 Status MergeOptionsFromInputs();
1093
1094 const string type_string_;
1095 const string node_name_;
1096 Metadata metadata_;
1097 Options options_;
1098 mutable mutex mu_;
1099 mutable mutex cardinality_mu_;
1100 mutable core::RefCountPtr<DatasetBase> finalized_dataset_;
1101 // The number of source datasets feeding into the dataset. A source dataset
1102 // is a leaf in the subtree of dataset inputs.
1103 int64_t num_sources_ = -1;
1104 mutable int64_t cardinality_ TF_GUARDED_BY(cardinality_mu_) =
1105 kUnknownCardinality;
1106};
1107
1108// Represents an iterator that is associated with a particular dataset.
1109class DatasetBaseIterator : public IteratorBase {
1110 public:
1111 struct BaseParams {
1112 // Owns one reference on the shared dataset object.
1113 const DatasetBase* dataset;
1114
1115 // Identifies the sequence of iterators leading up to this iterator.
1116 const string prefix;
1117 };
1118
1119 explicit DatasetBaseIterator(const BaseParams& params);
1120
1121 ~DatasetBaseIterator() override;
1122
1123 virtual const DatasetBase* dataset() const { return params_.dataset; }
1124
1125 const DataTypeVector& output_dtypes() const override {
1126 return params_.dataset->output_dtypes();
1127 }
1128
1129 const std::vector<PartialTensorShape>& output_shapes() const override {
1130 return params_.dataset->output_shapes();
1131 }
1132
1133 const string& prefix() const override { return params_.prefix; }
1134
1135 // Returns a name to be used for the TraceMe event.
1136 //
1137 // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
1138 // following format "name#arg_1=value_,...,arg_n=value_n".
1139 string BuildTraceMeName();
1140
1141 Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
1142 bool* end_of_sequence) final;
1143
1144 Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
1145 bool* end_of_sequence) {
1146 return GetNext(&ctx, out_tensors, end_of_sequence);
1147 }
1148
1149 Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
1150 int* num_skipped) final;
1151
1152 Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
1153 VLOG(2) << "Attempting to save checkpoints on iterator (prefix: "
1154 << prefix() << ") from " << dataset()->DebugString();
1155 return IteratorBase::Save(ctx, writer);
1156 }
1157
1158 protected:
1159 Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final {
1160 VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: "
1161 << prefix() << ") from " << dataset()->DebugString();
1162 return IteratorBase::Restore(ctx, reader);
1163 }
1164
1165 // Internal implementation of GetNext that is wrapped in tracing logic.
1166 //
1167 // See the docstring of `GetNext` method regaring the contract for
1168 // `out_tensors` and `end_of_sequence`. Implementations may assume that
1169 // `*out_tensors` is empty.
1170 virtual Status GetNextInternal(IteratorContext* ctx,
1171 std::vector<Tensor>* out_tensors,
1172 bool* end_of_sequence) = 0;
1173
1174 // Internal implementation of Skip that is wrapped in tracing logic
1175 virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
1176 bool* end_of_sequence, int* num_skipped);
1177
1178 string full_name(const string& name) const {
1179 return FullName(params_.prefix, name);
1180 }
1181
1182 // Returns a map of key-value pairs to included in the TraceMe string.
1183 virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }
1184
1185 // By default we model iterators using an unknown node, which acts as
1186 // pass-through with respect to performance modeling.
1187 std::shared_ptr<model::Node> CreateNode(
1188 IteratorContext* ctx, model::Node::Args args) const override {
1189 return model::MakeUnknownNode(std::move(args));
1190 }
1191
1192 // When modeling is enabled, this method disables autotuning for the given
1193 // iterator (and the transitive closure of its inputs).
1194 void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1195 if (iterator->node_) {
1196 iterator->node_->set_autotune(false);
1197 }
1198 }
1199
1200 // When modeling is enabled, this method enables autotuning for the given
1201 // iterator (and the transitive closure of its inputs).
1202 void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1203 if (iterator->node_) {
1204 iterator->node_->set_autotune(true);
1205 }
1206 }
1207
1208 // When modeling is enabled, this method records the fact that this iterator
1209 // has dequeued an element from an internal buffer.
1210 void RecordBufferDequeue(IteratorContext* ctx,
1211 const std::vector<Tensor>& element) {
1212 if (collect_resource_usage(ctx)) {
1213 node_->record_buffer_event(-GetAllocatedBytes(element), -1);
1214 DCHECK_GE(node_->buffered_elements(), 0);
1215 }
1216 }
1217
1218 // When modeling is enabled, this method records the fact that this iterator
1219 // has enqueued an element in an internal buffer.
1220 void RecordBufferEnqueue(IteratorContext* ctx,
1221 const std::vector<Tensor>& element) {
1222 if (collect_resource_usage(ctx)) {
1223 node_->record_buffer_event(GetAllocatedBytes(element), 1);
1224 }
1225 }
1226
1227 // When modeling is enabled, this method records the fact that this iterator
1228 // has produced an element and its size in bytes.
1229 void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
1230 if (collect_resource_usage(ctx)) {
1231 int64_t num_bytes = GetAllocatedBytes(*out_tensors);
1232 node_->record_element();
1233 node_->record_bytes_produced(num_bytes);
1234 if (node_->output()) {
1235 node_->output()->record_bytes_consumed(num_bytes);
1236 }
1237 }
1238 }
1239
1240 // When modeling is enabled, this method records the fact that a thread of
1241 // this iterator has started work.
1242 void RecordStart(IteratorContext* ctx) {
1243 if (collect_resource_usage(ctx)) {
1244 int64_t now_nanos = EnvTime::NowNanos();
1245 node_->record_start(now_nanos);
1246 }
1247 }
1248
1249 // When modeling is enabled, this method records the fact that a thread of
1250 // this iterator has stopped work.
1251 void RecordStop(IteratorContext* ctx) {
1252 if (collect_resource_usage(ctx)) {
1253 int64_t now_nanos = EnvTime::NowNanos();
1254 node_->record_stop(now_nanos);
1255 }
1256 }
1257
1258 // Returns whether work is currently being recorded, i.e. whether we are
1259 // currently between a `RecordStart` and a `RecordStop`.
1260 bool IsRecording(IteratorContext* ctx) {
1261 return node_ && node_->is_recording();
1262 }
1263
1264 private:
1265 bool collect_resource_usage(IteratorContext* ctx) {
1266 return ctx->model() && node_;
1267 }
1268
1269 string traceme_metadata_;
1270 BaseParams params_;
1271};
1272
1273// Represents an iterator that is associated with a particular dataset
1274// with a particular type.
1275template <class DatasetType>
1276class DatasetIterator : public DatasetBaseIterator {
1277 public:
1278 struct Params {
1279 // Borrowed pointer to the dataset.
1280 const DatasetType* dataset;
1281
1282 // Identifies the sequence of iterators leading up to this iterator.
1283 const string prefix;
1284 };
1285
1286 explicit DatasetIterator(const Params& params)
1287 : DatasetBaseIterator({params.dataset, params.prefix}),
1288 typed_dataset_(params.dataset) {}
1289
1290 // The dataset from which this iterator was created.
1291 const DatasetType* dataset() const final { return typed_dataset_; }
1292
1293 private:
1294 const DatasetType* const typed_dataset_; // Not owned.
1295};
1296
1297template <typename T>
1298Status ParseScalarArgument(OpKernelContext* ctx,
1299 const StringPiece& argument_name, T* output) {
1300 const Tensor* argument_t;
1301 TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1302 if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
1303 return errors::InvalidArgument(argument_name, " must be a scalar");
1304 }
1305 *output = argument_t->scalar<T>()();
1306 return OkStatus();
1307}
1308
1309template <typename T>
1310Status ParseVectorArgument(OpKernelContext* ctx,
1311 const StringPiece& argument_name,
1312 std::vector<T>* output) {
1313 const Tensor* argument_t;
1314 TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1315 if (!TensorShapeUtils::IsVector(argument_t->shape())) {
1316 return errors::InvalidArgument(argument_name, " must be a vector");
1317 }
1318 int size = argument_t->vec<T>().size();
1319 output->reserve(size);
1320 for (int i = 0; i < size; ++i) {
1321 output->push_back(argument_t->vec<T>()(i));
1322 }
1323 return OkStatus();
1324}
1325
1326// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
1327// graph execution engine.
1328class DatasetOpKernel : public OpKernel {
1329 public:
1330 explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
1331 if (ctx->HasAttr(kMetadata)) {
1332 std::string serialized_metadata;
1333 OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata));
1334 OP_REQUIRES(ctx, metadata_.ParseFromString(serialized_metadata),
1335 errors::InvalidArgument(absl::StrCat(
1336 "Could not parse the 'metadata' attribute.")));
1337 }
1338 }
1339
1340 void Compute(OpKernelContext* ctx) final;
1341
1342 // Checks whether the given op is a tf.data operation.
1343 //
1344 // NOTE: The check uses a heuristic and can produce both false positives and
1345 // false negatives. In particular, tf.data operations are expected to use
1346 // names that end with "Dataset" or "DatasetV[0-9]+".
1347 static bool IsDatasetOp(const OpDef& op_def);
1348
1349 string TraceString(const OpKernelContext& ctx, bool verbose) const override;
1350
1351 protected:
1352 // Subclasses should implement this method. It will be called during Compute
1353 // execution.
1354 virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
1355
1356 private:
1357 Metadata metadata_;
1358};
1359
1360// Encapsulates the work required to plug unary Datasets into the core
1361// TensorFlow graph execution engine.
1362class UnaryDatasetOpKernel : public DatasetOpKernel {
1363 public:
1364 explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
1365 : DatasetOpKernel(ctx) {}
1366
1367 protected:
1368 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1369 virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1370 DatasetBase** output) = 0;
1371};
1372
1373// Encapsulates the work required to plug binary Datasets into the core
1374// TensorFlow graph execution engine.
1375class BinaryDatasetOpKernel : public DatasetOpKernel {
1376 public:
1377 explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
1378 : DatasetOpKernel(ctx) {}
1379
1380 protected:
1381 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1382 virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1383 DatasetBase* another_input,
1384 DatasetBase** output) = 0;
1385};
1386
1387// A simple background worker that executes closures asynchronously and without
1388// blocking.
1389//
1390// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
1391// to avoid blocking an executor thread that may be required by the blocking
1392// work.
1393//
1394// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
1395// purpose because its current implementation (in Eigen) uses a finite-length
1396// queue and will block the caller when full. This can lead to deadlock under
1397// heavy load. Since the number of concurrent work items in each user of a
1398// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
1399// overhead is tolerable.
1400class BackgroundWorker {
1401 public:
1402 BackgroundWorker(Env* env, const char* name);
1403
1404 ~BackgroundWorker();
1405
1406 void Schedule(std::function<void()> work_item);
1407
1408 private:
1409 void WorkerLoop();
1410
1411 Env* const env_;
1412 const char* const name_;
1413
1414 std::unique_ptr<Thread> thread_;
1415 mutex mu_;
1416 condition_variable cond_var_;
1417 bool cancelled_ TF_GUARDED_BY(mu_) = false;
1418 std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
1419};
1420
1421} // namespace data
1422} // namespace tensorflow
1423
1424#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
1425