1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/data/root_dataset.h"
17
18#include <algorithm>
19#include <functional>
20#include <string>
21#include <utility>
22
23#include "tensorflow/core/data/dataset_utils.h"
24#include "tensorflow/core/data/name_utils.h"
25#include "tensorflow/core/data/rewrite_utils.h"
26#include "tensorflow/core/framework/model.pb.h"
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/host_info.h"
29#include "tensorflow/core/platform/refcount.h"
30#include "tensorflow/core/platform/stringprintf.h"
31
32namespace tensorflow {
33namespace data {
34namespace {
35
constexpr char kDatasetType[] = "Root";

// Keys used for the TraceMe metadata attached to this dataset's iterator
// (see `AddTraceMetadata` and `Iterator::GetTraceMeMetadata` below).
constexpr char kAlgorithm[] = "algorithm";
constexpr char kCpuBudget[] = "cpu_budget";
constexpr char kExperiments[] = "experiments";
// Marker optimization used by `FinalizeDataset` to probe whether
// `inject_prefetch` would change the graph without actually modifying it.
constexpr char kInjectPrefetchEligibleOpt[] = "inject_prefetch_eligible";
constexpr char kIntraOpParallelism[] = "intra_op_parallelism";
constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec";
constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
constexpr char kRamBudget[] = "ram_budget_megabytes";
constexpr char kRamUsage[] = "ram_usage_megabytes";
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
48
// Returns the default value `z` when `x` equals the sentinel `y`; otherwise
// returns `x` unchanged.
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
  if (x == y) {
    return z;
  }
  return x;
}
53
54void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
55 if (ShouldConfigureMaxIntraOpParallelism(options)) {
56 params->max_intra_op_parallelism =
57 options.threading_options().max_intra_op_parallelism();
58 }
59 if (ShouldUsePrivateThreadPool(options)) {
60 params->private_threadpool_size =
61 options.threading_options().private_threadpool_size();
62 }
63 params->autotune = ShouldUseAutotuning(options);
64 if (params->autotune) {
65 params->autotune_algorithm = model::AutotuneAlgorithm::DEFAULT;
66 if (GetExperiments().contains("stage_based_autotune")) {
67 params->autotune_algorithm = model::AutotuneAlgorithm::STAGE_BASED;
68 }
69 if (options.autotune_options().optional_autotune_algorithm_case() ==
70 AutotuneOptions::kAutotuneAlgorithm) {
71 params->autotune_algorithm =
72 options.autotune_options().autotune_algorithm();
73 }
74 params->autotune_cpu_budget = value_or_default(
75 options.autotune_options().cpu_budget(), 0, GetCpuBudget());
76 params->autotune_ram_budget =
77 value_or_default(options.autotune_options().ram_budget(), 0,
78 model::kRamBudgetShare * port::AvailableRam());
79 }
80}
81
82void AddTraceMetadata(const RootDataset::Params& params,
83 TraceMeMetadata* trace_metadata) {
84 if (params.autotune) {
85 trace_metadata->push_back(std::make_pair(
86 kAlgorithm, model::AutotuneAlgorithm_Name(params.autotune_algorithm)));
87 trace_metadata->push_back(std::make_pair(
88 kCpuBudget, strings::Printf("%lld", static_cast<long long>(
89 params.autotune_cpu_budget))));
90 trace_metadata->push_back(std::make_pair(
91 kRamBudget,
92 strings::Printf("%lld", static_cast<long long>(
93 params.autotune_ram_budget / 1.0e6))));
94 }
95 if (params.max_intra_op_parallelism >= 0) {
96 trace_metadata->push_back(std::make_pair(
97 kIntraOpParallelism,
98 strings::Printf("%lld", static_cast<long long>(value_or_default(
99 params.max_intra_op_parallelism, 0,
100 port::MaxParallelism())))));
101 }
102 if (params.private_threadpool_size >= 0) {
103 trace_metadata->push_back(std::make_pair(
104 kPrivateThreadpoolSize,
105 strings::Printf("%lld", static_cast<long long>(value_or_default(
106 params.private_threadpool_size, 0,
107 port::MaxParallelism())))));
108 }
109 auto experiments = GetExperiments();
110 if (!experiments.empty()) {
111 trace_metadata->push_back(
112 std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
113 }
114}
115} // namespace
116
117// static
118Status RootDataset::FromOptions(const DatasetBase* input,
119 DatasetBase** output) {
120 Params params;
121 SetRootDatasetParams(input->options(), &params);
122 *output = new RootDataset(input, params);
123 (*output)->Initialize(/*metadata=*/{});
124 return OkStatus();
125}
126
127Status RootDataset::FromOptions(core::RefCountPtr<DatasetBase> input,
128 DatasetBase** output) {
129 Params params;
130 SetRootDatasetParams(input->options(), &params);
131 *output = new RootDataset(std::move(input), params);
132 (*output)->Initialize(/*metadata=*/{});
133 return OkStatus();
134}
135
// Iterator for RootDataset. Responsible for:
//   * owning the autotuning model and the background optimization thread,
//   * configuring the private threadpool / intra-op parallelism limits that
//     downstream iterators inherit via `CreateParams()`,
//   * recording iterator "gap" time (time spent outside GetNext) for the
//     autotuning model.
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<RootDataset>(params) {
    // Create the autotuning model only when autotuning is enabled; it is
    // shared with downstream iterators through `CreateParams()`.
    if (dataset()->params_.autotune) {
      model_ = std::make_shared<model::Model>();
      if (GetExperiments().contains("autotune_buffer_optimization")) {
        model_->SetExperiment("autotune_buffer_optimization");
      }
    }
    // For both limits, 0 is a sentinel meaning "use the host's maximum
    // parallelism"; negative values leave the corresponding member unset
    // (it is then never read — see `CreateParams()`).
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      max_intra_op_parallelism_ =
          value_or_default(dataset()->params_.max_intra_op_parallelism, 0,
                           port::MaxParallelism());
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      threadpool_size_ =
          value_or_default(dataset()->params_.private_threadpool_size, 0,
                           port::MaxParallelism());
      thread_pool_ = std::make_unique<thread::ThreadPool>(
          Env::Default(), ThreadOptions{}, "data_private_threadpool",
          threadpool_size_);
    }
    cancellation_manager_ = std::make_unique<CancellationManager>();
  }

  // Signals the model optimization thread to stop; the thread itself is
  // joined when `model_thread_` is destroyed (member declaration order
  // guarantees this happens before `cancellation_manager_` goes away).
  ~Iterator() override { cancellation_manager_->StartCancel(); }

  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
                                           this, prefix(), &input_impl_);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    // Report the time elapsed since the previous GetNext returned ("gap
    // time") to the autotuning model. Read under a shared lock since other
    // GetNext calls may be updating `end_time_usec_` concurrently.
    {
      tf_shared_lock l(mu_);
      if (model_ != nullptr && end_time_usec_ > 0) {
        model_->RecordIteratorGapTime(ctx->env()->NowMicros() - end_time_usec_);
      }
    }
    // Lazily start the background model-optimization thread on first use.
    if (dataset()->params_.autotune) {
      TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
    }
    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                            out_tensors, end_of_sequence));
    // Record when this call finished; `max` keeps the latest end time when
    // multiple GetNext calls race.
    {
      mutex_lock l(mu_);
      end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
    }
    return OkStatus();
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    // The root produces exactly one output element per input element.
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    // The root itself holds no checkpointable state beyond its input.
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
    return OkStatus();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    TF_RETURN_IF_ERROR(
        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
    return OkStatus();
  }

  // Combines the dataset's static trace metadata with live host metrics
  // (memory bandwidth, RAM usage) and model buffer statistics.
  TraceMeMetadata GetTraceMeMetadata() const override {
    tensorflow::data::TraceMeMetadata traceme_metadata =
        dataset()->traceme_metadata_;
    const int64_t mem_bw = port::GetMemoryBandwidthInfo().bw_used;
    // INT64_MAX indicates bandwidth info is unavailable on this platform.
    if (mem_bw != INT64_MAX) {
      traceme_metadata.push_back(std::make_pair(
          kMemBandwidth,
          strings::Printf("%lld", static_cast<long long>(mem_bw))));
    }
    const auto memory_info = port::GetMemoryInfo();
    const auto memory_usage = memory_info.total - memory_info.free;
    traceme_metadata.push_back(std::make_pair(
        kRamUsage,
        strings::Printf("%lld out of %lld (%.2f%%)",
                        static_cast<long long>(memory_usage / 1.0e6),
                        static_cast<long long>(memory_info.total / 1.0e6),
                        static_cast<double>(100 * memory_usage) /
                            static_cast<double>(memory_info.total))));
    if (model_node() != nullptr) {
      traceme_metadata.push_back(std::make_pair(
          kMaxBufferBytes,
          strings::Printf(
              "%lld", static_cast<long long>(
                          model_node()->TotalMaximumBufferedBytes() / 1.0e6))));
    }
    return traceme_metadata;
  }

 private:
  // Builds the IteratorContext::Params propagated to downstream iterators:
  // installs the shared autotuning model, the private threadpool runner,
  // and/or the max-intra-op-parallelism-limited runner, as configured.
  IteratorContext::Params CreateParams(IteratorContext* ctx) {
    IteratorContext::Params params(ctx);
    if (dataset()->params_.autotune) {
      params.model = model_;
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      // NOTE(review): the lambda captures a raw pointer to `thread_pool_`;
      // this assumes downstream iterators do not outlive this iterator.
      params.runner = [pool = thread_pool_.get()](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = threadpool_size_;
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      // Wraps whichever runner is currently set (the ctx default or the
      // private-threadpool runner installed above).
      params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
    }
    params.options = &dataset()->options();
    return params;
  }

  // Starts the background thread that runs the model optimization loop.
  // Idempotent: only the first call creates the thread.
  Status EnsureModelThreadStarted(IteratorContext* ctx) {
    mutex_lock l(mu_);
    if (!model_thread_) {
      model_thread_ = ctx->StartThread("tf_data_model", [this]() {
        Status status =
            model_->OptimizeLoop(dataset()->params_.autotune_algorithm,
                                 dataset()->params_.autotune_cpu_budget,
                                 dataset()->params_.autotune_ram_budget,
                                 cancellation_manager_.get());
        if (!status.ok()) {
          // Best-effort: a failed optimization loop does not fail iteration.
          LOG(WARNING) << "Optimization loop failed: " << status.ToString();
        }
      });
    }
    return OkStatus();
  }

  // Autotuning model; null when autotuning is disabled.
  std::shared_ptr<model::Model> model_ = nullptr;
  // Controls cancellation of `model_thread_`. Must be ordered before
  // `model_thread_` so that `model_thread_` is destroyed first.
  std::unique_ptr<CancellationManager> cancellation_manager_;
  mutex mu_;
  std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
  // Only meaningful when the corresponding params_ value is >= 0; otherwise
  // left uninitialized and never read.
  int64_t max_intra_op_parallelism_;
  int64_t threadpool_size_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // The end time of the previous `GetNextInternal` call.
  uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;

  // Must be ordered last as its execution may depend on other members.
  std::unique_ptr<IteratorBase> input_impl_;
};
289
290RootDataset::RootDataset(const DatasetBase* input, const Params& params)
291 : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
292 name_utils::OpName(kDatasetType)})),
293 input_(input),
294 params_(std::move(params)) {
295 AddTraceMetadata(params_, &traceme_metadata_);
296}
297
298RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
299 const Params& params)
300 : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
301 name_utils::OpName(kDatasetType)})),
302 params_(std::move(params)) {
303 owned_input_ = std::move(input);
304 input_ = owned_input_.get();
305 AddTraceMetadata(params_, &traceme_metadata_);
306}
307
308RootDataset::~RootDataset() {}
309
310std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
311 const string& prefix) const {
312 return std::make_unique<Iterator>(
313 Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)});
314}
315
// Forwards to the wrapped input dataset.
const DataTypeVector& RootDataset::output_dtypes() const {
  return input_->output_dtypes();
}
319
// Forwards to the wrapped input dataset.
const std::vector<PartialTensorShape>& RootDataset::output_shapes() const {
  return input_->output_shapes();
}
323
// Human-readable name derived from the "Root" dataset type.
string RootDataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}
327
// Cardinality is that of the wrapped input (root adds no elements).
int64_t RootDataset::CardinalityInternal() const {
  return input_->Cardinality();
}
331
// Options-aware overload; also forwards to the wrapped input.
int64_t RootDataset::CardinalityInternal(CardinalityOptions options) const {
  return input_->Cardinality(options);
}
335
336Status RootDataset::Get(OpKernelContext* ctx, int64 index,
337 std::vector<Tensor>* out_tensors) const {
338 std::vector<const DatasetBase*> inputs;
339 TF_RETURN_IF_ERROR(this->InputDatasets(&inputs));
340 return inputs[0]->Get(ctx, index, out_tensors);
341}
342
// A RootDataset has exactly one input: the dataset it wraps.
Status RootDataset::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  inputs->push_back(input_);
  return OkStatus();
}
348
// Forwards the external-state check to the wrapped input dataset.
Status RootDataset::CheckExternalState() const {
  return input_->CheckExternalState();
}
352
// RootDataset is a runtime-only wrapper and is never serialized to a graph.
Status RootDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  return errors::Unimplemented("RootDataset does not support serialization.");
}
358
359#if !defined(IS_MOBILE_PLATFORM)
360Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
361 DatasetBase** output) {
362 const Options& options = input->options();
363 absl::flat_hash_set<tstring> optimizations_enabled;
364 absl::flat_hash_set<tstring> optimizations_disabled;
365 absl::flat_hash_set<tstring> optimizations_default;
366 GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
367 &optimizations_default);
368 // Disable `enable_gradient_descent` as it assumes presence of ModelDatasetOp.
369 optimizations_disabled.insert("enable_gradient_descent");
370 if (!port::JobName().empty()) {
371 // Enable kInjectPrefetchEligibleOpt that does not modify the graph and is
372 // used to check whether the `inject_prefetch` optimization would modify the
373 // graph.
374 optimizations_enabled.insert(kInjectPrefetchEligibleOpt);
375 }
376
377 auto experiments = GetExperiments();
378 LogAndRecordExperiments(experiments);
379 auto optimizations =
380 SelectOptimizations(experiments, optimizations_enabled,
381 optimizations_disabled, optimizations_default);
382 if (optimizations.empty()) {
383 return RootDataset::FromOptions(input, output);
384 }
385
386 auto optimization_configs = CreateGraphRewriteConfigs(options);
387 auto config_factory = [&optimizations, &optimization_configs]() {
388 return CreateRewriterConfig(optimizations, optimization_configs);
389 };
390 core::RefCountPtr<DatasetBase> rewritten_output;
391 Status s = RewriteDataset(ctx, input, std::move(config_factory),
392 /*record_fingerprint=*/true, &rewritten_output);
393
394 *output = rewritten_output.get();
395 bool rewritten = (*output != input);
396 if (errors::IsDeadlineExceeded(s)) {
397 // Ignore DeadlineExceeded as it implies that the attempted rewrite took too
398 // long which should not prevent further computation.
399 LOG(WARNING) << s.ToString();
400 } else if (!s.ok()) {
401 return s;
402 }
403 if (!rewritten) {
404 return RootDataset::FromOptions(input, output);
405 } else {
406 return RootDataset::FromOptions(std::move(rewritten_output), output);
407 }
408 return OkStatus();
409}
410
411#else // !IS_MOBILE_PLATFORM
// Mobile build: graph rewrites are unavailable, so finalization just wraps
// the input in a RootDataset.
Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
                       DatasetBase** output) {
  return RootDataset::FromOptions(input, output);
}
416#endif // !IS_MOBILE_PLATFORM
417
418} // namespace data
419} // namespace tensorflow
420