1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/data/root_dataset.h" |
17 | |
18 | #include <algorithm> |
19 | #include <functional> |
20 | #include <string> |
21 | #include <utility> |
22 | |
23 | #include "tensorflow/core/data/dataset_utils.h" |
24 | #include "tensorflow/core/data/name_utils.h" |
25 | #include "tensorflow/core/data/rewrite_utils.h" |
26 | #include "tensorflow/core/framework/model.pb.h" |
27 | #include "tensorflow/core/platform/errors.h" |
28 | #include "tensorflow/core/platform/host_info.h" |
29 | #include "tensorflow/core/platform/refcount.h" |
30 | #include "tensorflow/core/platform/stringprintf.h" |
31 | |
32 | namespace tensorflow { |
33 | namespace data { |
34 | namespace { |
35 | |
// Op-name stem for this dataset ("Root").
constexpr char kDatasetType[] = "Root";

// Keys used for trace metadata emitted via `AddTraceMetadata` and
// `Iterator::GetTraceMeMetadata` (consumed by profiler tooling), plus the
// name of the graph-rewrite pass enabled in `FinalizeDataset`.
constexpr char kAlgorithm[] = "algorithm";
constexpr char kCpuBudget[] = "cpu_budget";
constexpr char kExperiments[] = "experiments";
// Rewrite pass that only checks eligibility for `inject_prefetch`; it does
// not modify the graph.
constexpr char kInjectPrefetchEligibleOpt[] = "inject_prefetch_eligible";
constexpr char kIntraOpParallelism[] = "intra_op_parallelism";
constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec";
constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
constexpr char kRamBudget[] = "ram_budget_megabytes";
constexpr char kRamUsage[] = "ram_usage_megabytes";
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
48 | |
49 | // If value `x` matches `y`, returns default value `z`. Otherwise, return `x`. |
// Returns `x`, unless `x` equals the sentinel value `y`, in which case the
// fallback `z` is returned instead.
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
  if (x == y) {
    return z;
  }
  return x;
}
53 | |
54 | void SetRootDatasetParams(const Options& options, RootDataset::Params* params) { |
55 | if (ShouldConfigureMaxIntraOpParallelism(options)) { |
56 | params->max_intra_op_parallelism = |
57 | options.threading_options().max_intra_op_parallelism(); |
58 | } |
59 | if (ShouldUsePrivateThreadPool(options)) { |
60 | params->private_threadpool_size = |
61 | options.threading_options().private_threadpool_size(); |
62 | } |
63 | params->autotune = ShouldUseAutotuning(options); |
64 | if (params->autotune) { |
65 | params->autotune_algorithm = model::AutotuneAlgorithm::DEFAULT; |
66 | if (GetExperiments().contains("stage_based_autotune" )) { |
67 | params->autotune_algorithm = model::AutotuneAlgorithm::STAGE_BASED; |
68 | } |
69 | if (options.autotune_options().optional_autotune_algorithm_case() == |
70 | AutotuneOptions::kAutotuneAlgorithm) { |
71 | params->autotune_algorithm = |
72 | options.autotune_options().autotune_algorithm(); |
73 | } |
74 | params->autotune_cpu_budget = value_or_default( |
75 | options.autotune_options().cpu_budget(), 0, GetCpuBudget()); |
76 | params->autotune_ram_budget = |
77 | value_or_default(options.autotune_options().ram_budget(), 0, |
78 | model::kRamBudgetShare * port::AvailableRam()); |
79 | } |
80 | } |
81 | |
82 | void AddTraceMetadata(const RootDataset::Params& params, |
83 | TraceMeMetadata* trace_metadata) { |
84 | if (params.autotune) { |
85 | trace_metadata->push_back(std::make_pair( |
86 | kAlgorithm, model::AutotuneAlgorithm_Name(params.autotune_algorithm))); |
87 | trace_metadata->push_back(std::make_pair( |
88 | kCpuBudget, strings::Printf("%lld" , static_cast<long long>( |
89 | params.autotune_cpu_budget)))); |
90 | trace_metadata->push_back(std::make_pair( |
91 | kRamBudget, |
92 | strings::Printf("%lld" , static_cast<long long>( |
93 | params.autotune_ram_budget / 1.0e6)))); |
94 | } |
95 | if (params.max_intra_op_parallelism >= 0) { |
96 | trace_metadata->push_back(std::make_pair( |
97 | kIntraOpParallelism, |
98 | strings::Printf("%lld" , static_cast<long long>(value_or_default( |
99 | params.max_intra_op_parallelism, 0, |
100 | port::MaxParallelism()))))); |
101 | } |
102 | if (params.private_threadpool_size >= 0) { |
103 | trace_metadata->push_back(std::make_pair( |
104 | kPrivateThreadpoolSize, |
105 | strings::Printf("%lld" , static_cast<long long>(value_or_default( |
106 | params.private_threadpool_size, 0, |
107 | port::MaxParallelism()))))); |
108 | } |
109 | auto experiments = GetExperiments(); |
110 | if (!experiments.empty()) { |
111 | trace_metadata->push_back( |
112 | std::make_pair(kExperiments, absl::StrJoin(experiments, " " ))); |
113 | } |
114 | } |
115 | } // namespace |
116 | |
117 | // static |
118 | Status RootDataset::FromOptions(const DatasetBase* input, |
119 | DatasetBase** output) { |
120 | Params params; |
121 | SetRootDatasetParams(input->options(), ¶ms); |
122 | *output = new RootDataset(input, params); |
123 | (*output)->Initialize(/*metadata=*/{}); |
124 | return OkStatus(); |
125 | } |
126 | |
127 | Status RootDataset::FromOptions(core::RefCountPtr<DatasetBase> input, |
128 | DatasetBase** output) { |
129 | Params params; |
130 | SetRootDatasetParams(input->options(), ¶ms); |
131 | *output = new RootDataset(std::move(input), params); |
132 | (*output)->Initialize(/*metadata=*/{}); |
133 | return OkStatus(); |
134 | } |
135 | |
// Iterator over the wrapped input dataset. All per-iteration machinery of
// RootDataset lives here: the optional autotuning model and its background
// optimization thread, the optional private threadpool, and the cap on
// intra-op parallelism — each configured from `dataset()->params_`.
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<RootDataset>(params) {
    // `model_` is only created when autotuning is enabled; a null `model_`
    // is how the rest of this class detects "autotuning off".
    if (dataset()->params_.autotune) {
      model_ = std::make_shared<model::Model>();
      if (GetExperiments().contains("autotune_buffer_optimization")) {
        model_->SetExperiment("autotune_buffer_optimization");
      }
    }
    // For both knobs below, a configured value of 0 means "use the host's
    // default parallelism"; a negative value means the knob is unset.
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      max_intra_op_parallelism_ =
          value_or_default(dataset()->params_.max_intra_op_parallelism, 0,
                           port::MaxParallelism());
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      threadpool_size_ =
          value_or_default(dataset()->params_.private_threadpool_size, 0,
                           port::MaxParallelism());
      thread_pool_ = std::make_unique<thread::ThreadPool>(
          Env::Default(), ThreadOptions{}, "data_private_threadpool",
          threadpool_size_);
    }
    cancellation_manager_ = std::make_unique<CancellationManager>();
  }

  // Asks the model thread (if started) to stop; the thread is joined when
  // `model_thread_` is destroyed after `cancellation_manager_` signals it.
  ~Iterator() override { cancellation_manager_->StartCancel(); }

  // Creates the iterator for the wrapped input dataset, propagating this
  // iterator's model/runner/threadpool via `CreateParams`.
  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
                                           this, prefix(), &input_impl_);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    {
      // Record the "gap" time — how long since the previous GetNext call
      // finished — so the model can account for time spent outside tf.data.
      // A zero `end_time_usec_` means no previous call has completed yet.
      tf_shared_lock l(mu_);
      if (model_ != nullptr && end_time_usec_ > 0) {
        model_->RecordIteratorGapTime(ctx->env()->NowMicros() - end_time_usec_);
      }
    }
    if (dataset()->params_.autotune) {
      TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
    }
    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                            out_tensors, end_of_sequence));
    {
      // `max` keeps the recorded end time monotonic — presumably to guard
      // against concurrent GetNext calls finishing out of order (confirm).
      mutex_lock l(mu_);
      end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
    }
    return OkStatus();
  }

 protected:
  // The root node passes elements through 1:1, hence a known-ratio node
  // with ratio 1.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  // Checkpointing delegates entirely to the input iterator; this iterator
  // keeps no state worth persisting of its own.
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
    return OkStatus();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    TF_RETURN_IF_ERROR(
        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
    return OkStatus();
  }

  // Returns the dataset's static trace metadata augmented with live system
  // stats: memory bandwidth (when available), RAM usage, and the model's
  // maximum buffered bytes.
  TraceMeMetadata GetTraceMeMetadata() const override {
    tensorflow::data::TraceMeMetadata traceme_metadata =
        dataset()->traceme_metadata_;
    // INT64_MAX appears to be the "unavailable" sentinel from
    // GetMemoryBandwidthInfo — TODO confirm.
    const int64_t mem_bw = port::GetMemoryBandwidthInfo().bw_used;
    if (mem_bw != INT64_MAX) {
      traceme_metadata.push_back(std::make_pair(
          kMemBandwidth,
          strings::Printf("%lld", static_cast<long long>(mem_bw))));
    }
    const auto memory_info = port::GetMemoryInfo();
    const auto memory_usage = memory_info.total - memory_info.free;
    traceme_metadata.push_back(std::make_pair(
        kRamUsage,
        strings::Printf("%lld out of %lld (%.2f%%)",
                        static_cast<long long>(memory_usage / 1.0e6),
                        static_cast<long long>(memory_info.total / 1.0e6),
                        static_cast<double>(100 * memory_usage) /
                            static_cast<double>(memory_info.total))));
    if (model_node() != nullptr) {
      traceme_metadata.push_back(std::make_pair(
          kMaxBufferBytes,
          strings::Printf(
              "%lld", static_cast<long long>(
                          model_node()->TotalMaximumBufferedBytes() / 1.0e6))));
    }
    return traceme_metadata;
  }

 private:
  // Builds the IteratorContext params passed down to the input iterator:
  // installs the autotuning model, routes execution through the private
  // threadpool, and caps intra-op parallelism.
  IteratorContext::Params CreateParams(IteratorContext* ctx) {
    IteratorContext::Params params(ctx);
    if (dataset()->params_.autotune) {
      params.model = model_;
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      params.runner = [pool = thread_pool_.get()](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = threadpool_size_;
    }
    // Order matters: the parallelism cap wraps whichever runner was chosen
    // above (private threadpool or the parent context's runner).
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
    }
    params.options = &dataset()->options();
    return params;
  }

  // Lazily starts the background optimization thread exactly once; later
  // calls are no-ops. The thread runs until `cancellation_manager_` fires.
  Status EnsureModelThreadStarted(IteratorContext* ctx) {
    mutex_lock l(mu_);
    if (!model_thread_) {
      model_thread_ = ctx->StartThread("tf_data_model", [this]() {
        Status status =
            model_->OptimizeLoop(dataset()->params_.autotune_algorithm,
                                 dataset()->params_.autotune_cpu_budget,
                                 dataset()->params_.autotune_ram_budget,
                                 cancellation_manager_.get());
        if (!status.ok()) {
          // Autotuning is best-effort; a failed optimization loop must not
          // fail iteration itself.
          LOG(WARNING) << "Optimization loop failed: " << status.ToString();
        }
      });
    }
    return OkStatus();
  }

  // Autotuning model; null when autotuning is disabled.
  std::shared_ptr<model::Model> model_ = nullptr;
  // Controls cancellation of `model_thread_`. Must be ordered before
  // `model_thread_` so that `model_thread_` is destroyed first.
  std::unique_ptr<CancellationManager> cancellation_manager_;
  // Guards `model_thread_` and `end_time_usec_`.
  mutex mu_;
  std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
  // Resolved parallelism values; only meaningful when the corresponding
  // params_ knobs are >= 0 (see the constructor).
  int64_t max_intra_op_parallelism_;
  int64_t threadpool_size_;
  // Private threadpool; null unless `private_threadpool_size` was set.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // The end time of the previous `GetNextInternal` call.
  uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;

  // Must be ordered last as its execution may depend on other members.
  std::unique_ptr<IteratorBase> input_impl_;
};
289 | |
290 | RootDataset::RootDataset(const DatasetBase* input, const Params& params) |
291 | : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType), |
292 | name_utils::OpName(kDatasetType)})), |
293 | input_(input), |
294 | params_(std::move(params)) { |
295 | AddTraceMetadata(params_, &traceme_metadata_); |
296 | } |
297 | |
298 | RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input, |
299 | const Params& params) |
300 | : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType), |
301 | name_utils::OpName(kDatasetType)})), |
302 | params_(std::move(params)) { |
303 | owned_input_ = std::move(input); |
304 | input_ = owned_input_.get(); |
305 | AddTraceMetadata(params_, &traceme_metadata_); |
306 | } |
307 | |
308 | RootDataset::~RootDataset() {} |
309 | |
310 | std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal( |
311 | const string& prefix) const { |
312 | return std::make_unique<Iterator>( |
313 | Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)}); |
314 | } |
315 | |
// Output dtypes are exactly those of the wrapped input dataset.
const DataTypeVector& RootDataset::output_dtypes() const {
  return input_->output_dtypes();
}
319 | |
// Output shapes are exactly those of the wrapped input dataset.
const std::vector<PartialTensorShape>& RootDataset::output_shapes() const {
  return input_->output_shapes();
}
323 | |
// Human-readable name derived from the "Root" dataset type.
string RootDataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}
327 | |
// RootDataset passes elements through 1:1, so its cardinality is the
// input's cardinality.
int64_t RootDataset::CardinalityInternal() const {
  return input_->Cardinality();
}
331 | |
// Same as above, forwarding the cardinality-computation options.
int64_t RootDataset::CardinalityInternal(CardinalityOptions options) const {
  return input_->Cardinality(options);
}
335 | |
336 | Status RootDataset::Get(OpKernelContext* ctx, int64 index, |
337 | std::vector<Tensor>* out_tensors) const { |
338 | std::vector<const DatasetBase*> inputs; |
339 | TF_RETURN_IF_ERROR(this->InputDatasets(&inputs)); |
340 | return inputs[0]->Get(ctx, index, out_tensors); |
341 | } |
342 | |
343 | Status RootDataset::InputDatasets( |
344 | std::vector<const DatasetBase*>* inputs) const { |
345 | inputs->push_back(input_); |
346 | return OkStatus(); |
347 | } |
348 | |
// External-state check is delegated to the wrapped input dataset.
Status RootDataset::CheckExternalState() const {
  return input_->CheckExternalState();
}
352 | |
// RootDataset cannot be converted to a GraphDef; serialization is
// intentionally unsupported (see the returned error).
Status RootDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  return errors::Unimplemented("RootDataset does not support serialization.");
}
358 | |
359 | #if !defined(IS_MOBILE_PLATFORM) |
360 | Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, |
361 | DatasetBase** output) { |
362 | const Options& options = input->options(); |
363 | absl::flat_hash_set<tstring> optimizations_enabled; |
364 | absl::flat_hash_set<tstring> optimizations_disabled; |
365 | absl::flat_hash_set<tstring> optimizations_default; |
366 | GetOptimizations(options, &optimizations_enabled, &optimizations_disabled, |
367 | &optimizations_default); |
368 | // Disable `enable_gradient_descent` as it assumes presence of ModelDatasetOp. |
369 | optimizations_disabled.insert("enable_gradient_descent" ); |
370 | if (!port::JobName().empty()) { |
371 | // Enable kInjectPrefetchEligibleOpt that does not modify the graph and is |
372 | // used to check whether the `inject_prefetch` optimization would modify the |
373 | // graph. |
374 | optimizations_enabled.insert(kInjectPrefetchEligibleOpt); |
375 | } |
376 | |
377 | auto experiments = GetExperiments(); |
378 | LogAndRecordExperiments(experiments); |
379 | auto optimizations = |
380 | SelectOptimizations(experiments, optimizations_enabled, |
381 | optimizations_disabled, optimizations_default); |
382 | if (optimizations.empty()) { |
383 | return RootDataset::FromOptions(input, output); |
384 | } |
385 | |
386 | auto optimization_configs = CreateGraphRewriteConfigs(options); |
387 | auto config_factory = [&optimizations, &optimization_configs]() { |
388 | return CreateRewriterConfig(optimizations, optimization_configs); |
389 | }; |
390 | core::RefCountPtr<DatasetBase> rewritten_output; |
391 | Status s = RewriteDataset(ctx, input, std::move(config_factory), |
392 | /*record_fingerprint=*/true, &rewritten_output); |
393 | |
394 | *output = rewritten_output.get(); |
395 | bool rewritten = (*output != input); |
396 | if (errors::IsDeadlineExceeded(s)) { |
397 | // Ignore DeadlineExceeded as it implies that the attempted rewrite took too |
398 | // long which should not prevent further computation. |
399 | LOG(WARNING) << s.ToString(); |
400 | } else if (!s.ok()) { |
401 | return s; |
402 | } |
403 | if (!rewritten) { |
404 | return RootDataset::FromOptions(input, output); |
405 | } else { |
406 | return RootDataset::FromOptions(std::move(rewritten_output), output); |
407 | } |
408 | return OkStatus(); |
409 | } |
410 | |
411 | #else // !IS_MOBILE_PLATFORM |
// Mobile build: graph-rewrite optimizations are compiled out, so finalizing
// simply wraps `input` in a RootDataset.
Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
                       DatasetBase** output) {
  return RootDataset::FromOptions(input, output);
}
416 | #endif // !IS_MOBILE_PLATFORM |
417 | |
418 | } // namespace data |
419 | } // namespace tensorflow |
420 | |