1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ |
16 | #define TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ |
17 | |
18 | #include "tensorflow/core/framework/dataset.h" |
19 | #include "tensorflow/core/framework/model.h" |
20 | #include "tensorflow/core/framework/model.pb.h" |
21 | #include "tensorflow/core/platform/refcount.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace data { |
25 | |
26 | // Dataset transformation responsible for internal tf.data logic such as |
27 | // autotuning, applying threading configuration. |
28 | class RootDataset : public DatasetBase { |
29 | public: |
30 | struct Params { |
31 | bool autotune = true; |
32 | model::AutotuneAlgorithm autotune_algorithm; |
33 | int64_t autotune_cpu_budget = 0; |
34 | int64_t autotune_ram_budget = 0; |
35 | int64_t max_intra_op_parallelism = 1; |
36 | int64_t private_threadpool_size = 0; |
37 | }; |
38 | |
39 | static Status FromOptions(const DatasetBase* input, DatasetBase** output); |
40 | static Status FromOptions(core::RefCountPtr<DatasetBase> input, |
41 | DatasetBase** output); |
42 | |
43 | ~RootDataset() override; |
44 | |
45 | const DataTypeVector& output_dtypes() const override; |
46 | const std::vector<PartialTensorShape>& output_shapes() const override; |
47 | |
48 | int64_t CardinalityInternal() const override; |
49 | int64_t CardinalityInternal(CardinalityOptions options) const override; |
50 | Status Get(OpKernelContext* ctx, int64 index, |
51 | std::vector<Tensor>* out_tensors) const override; |
52 | Status CheckExternalState() const override; |
53 | string DebugString() const override; |
54 | Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override; |
55 | std::unique_ptr<IteratorBase> MakeIteratorInternal( |
56 | const string& prefix) const override; |
57 | |
58 | protected: |
59 | Status AsGraphDefInternal(SerializationContext* ctx, |
60 | DatasetGraphDefBuilder* b, |
61 | Node** output) const override; |
62 | |
63 | private: |
64 | class Iterator; |
65 | |
66 | RootDataset(const DatasetBase* input, const Params& params); |
67 | |
68 | RootDataset(core::RefCountPtr<DatasetBase> input, const Params& params); |
69 | |
70 | const DatasetBase* input_; |
71 | core::RefCountPtr<DatasetBase> owned_input_; |
72 | const Params params_; |
73 | TraceMeMetadata traceme_metadata_; |
74 | }; |
75 | |
76 | // Finalizes the `input` dataset, which is expected to be called before the |
77 | // dataset is about to be iterated. This can for instance apply static graph |
78 | // optimizations or inject internal tf.data transformations responsible for |
79 | // autotuning or threading configuration. The caller must ensure that the |
80 | // input dataset to be finalized outlives the output. |
81 | Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, |
82 | DatasetBase** output); |
83 | |
84 | } // namespace data |
85 | } // namespace tensorflow |
86 | |
87 | #endif // TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ |
88 | |