/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string,
                           DatasetExperimentRegistry::ExperimentSelector>*
get_dataset_experiments() {
  static absl::flat_hash_map<
      string, DatasetExperimentRegistry::ExperimentSelector>* experiments =
      new absl::flat_hash_map<string,
                              DatasetExperimentRegistry::ExperimentSelector>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";

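// Computes the default set of tf.data graph rewrites from `options`. Rewrites
// that the user explicitly enabled or disabled go into `optimization_enabled`
// and `optimization_disabled`; rewrites that apply by default (default
// optimizations are on and the user expressed no preference) go into
// `optimization_default`.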
void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
    if (optimization_options.optional_parallel_batch_case() !=
        OptimizationOptions::kParallelBatch) {
      optimization_default->insert(kParallelBatchOpt);
    }
  }
  if (OpDeterminismRequired()) {
    optimization_enabled->insert(kMakeDeterministicOpt);
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_filter_parallelization_case() ==
      OptimizationOptions::kFilterParallelization) {
    if (optimization_options.filter_parallelization()) {
      optimization_enabled->insert(kFilterParallelizationOpt);
    } else {
      optimization_disabled->insert(kFilterParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_inject_prefetch_case() ==
      OptimizationOptions::kInjectPrefetch) {
    if (optimization_options.inject_prefetch()) {
      optimization_enabled->insert(kInjectPrefetchOpt);
    } else {
      optimization_disabled->insert(kInjectPrefetchOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

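// If both seeds are zero (i.e. unspecified), replaces them with freshly
// generated random seeds; otherwise returns the seeds unchanged.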
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return OkStatus();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return OkStatus();
}

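// Adds the functions in `to_add` to `base`. A function that already exists in
// `base` with an identical signature is replaced; a function with the same
// name but a different signature results in an error.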
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

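// Checks whether `function_def` (including, recursively, functions reached
// through the branch/body attributes of `If` and `While` nodes) uses any
// stateful op that is not allowlisted. Returns a FailedPrecondition error
// naming the first offending op, or OkStatus() otherwise.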
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return OkStatus();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return OkStatus();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return OkStatus();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return OkStatus();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return OkStatus();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

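// Wraps `runner` so that every closure it executes first installs a
// `ScopedPerThreadMaxParallelism` guard, capping intra-op parallelism at
// `max_parallelism` for the duration of the closure.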
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return OkStatus();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

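// Returns true if `op_to_match` is `op_prefix` itself or a versioned variant
// of it, i.e. `op_prefix` followed by "V<digits>". For example,
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") returns true, while
// MatchesAnyVersion("BatchDataset", "PaddedBatchDataset") returns false.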
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(tsl::port::JobName(), tsl::port::TaskId(),
                        [](const tstring& str) { return Hash64(str); });
}

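// Computes the set of tf.data experiments to apply to this job/task. The
// TF_DATA_EXPERIMENT_OPT_IN and TF_DATA_EXPERIMENT_OPT_OUT environment
// variables hold comma-separated experiment names (or the special values
// "all" / "all_except_opt_in"); opt-outs take precedence over both opt-ins
// and random rollout. For example, running with
//   TF_DATA_EXPERIMENT_OPT_IN=all TF_DATA_EXPERIMENT_OPT_OUT=inject_prefetch
// (an illustrative configuration) selects every registered experiment except
// "inject_prefetch".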
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, int64_t task_id,
    std::function<uint64_t(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;
  if (job_name.empty() || task_id < 0) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  auto live_experiments = DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  if (opt_outs_raw == "all_except_opt_in") {
    return experiments;
  }
  // Stochastically include live experiments unless they are opted out.
  for (const auto& [experiment_name, experiment_selector] : live_experiments) {
    if (experiment_selector.job_selector(hash_func, experiment_name,
                                         job_name) &&
        experiment_selector.task_selector(task_id) &&
        !opt_outs.contains(experiment_name)) {
      experiments.insert(experiment_name);
    }
  }

  return experiments;
}

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    constexpr float TEN_MINUTES = 60.0 * 10.0;
    LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
        << "The input pipeline is subject to the following tf.data experiments:"
        << " " << absl::StrJoin(experiments, ", ") << ". "
        << "See `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (!OpDeterminismRequired() &&
      options.optional_deterministic_case() == Options::kDeterministic &&
      !options.deterministic()) {
    optimizations_enabled->insert(kMakeSloppyOpt);
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

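// Returns the `index`-th subslice of `tensor` along its first dimension,
// deep-copying it only when the slice is not aligned (an aligned slice can
// safely alias the original buffer).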
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) {
  Tensor slice = tensor.SubSlice(index);
  if (slice.IsAligned()) {
    return slice;
  } else {
    return tensorflow::tensor::DeepCopy(slice);
  }
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

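// Copies the first `num_elements` rows of `value` into the corresponding rows
// of `output`. The two tensors are expected to have the same dtype and
// compatible trailing dimensions.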
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}

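// Restores a batch of tensors from the iterator checkpoint written under
// `iterator_prefix`/`batch_prefix`. Partial batches are re-expanded to
// `batch_size` rows along the leading dimension, with only the saved rows
// copied in.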
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return OkStatus();
}

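// Serializes a batch of tensors into the iterator checkpoint under
// `iterator_prefix`/`batch_prefix`. For a partial batch only the first
// `num_elements` rows of each component are written.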
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return OkStatus();
}

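// Restores a `Status` previously saved with `WriteStatus` from the iterator
// checkpoint.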
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = OkStatus();
  }
  return OkStatus();
}

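// Saves `status` into the iterator checkpoint as an error code plus, for
// non-OK statuses, an error message.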
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64_t>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return OkStatus();
}

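// Converts the `num_elements` elements accumulated in `batch` into `output`,
// propagating a non-OK `status` and setting `*end_of_sequence` when the input
// is exhausted. Partial batches are either dropped (`drop_remainder`) or
// trimmed to `num_elements` rows.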
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}

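// Batches `batch_elements` into `out_tensors`, one output tensor per tuple
// component. When `parallel_copy` is set and a component is larger than 32KB,
// the per-element copies are sharded across the runner threadpool.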
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}

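// Produces per-rewrite config strings of the form
// "<optimization>:<attribute>:<value>" consumed by the tf.data graph
// rewriters, e.g. toggling the "autotune" attribute of autotune-only rewrites
// and setting the slack period from the number of devices.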
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt,
      kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt,
      kEnableGradientDescentOpt,
      kFilterParallelizationOpt,
      kMapParallelizationOpt,
      kInjectPrefetchOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

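// Caps the default autotune parallelism (`kAutotuneDefaultParallelism`) at the
// size of the runner threadpool.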
int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) {
  return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         JobSelector job_selector,
                                         TaskSelector task_selector) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(
      std::make_pair(experiment, DatasetExperimentRegistry::ExperimentSelector{
                                     job_selector, task_selector}));
}

// static
absl::flat_hash_map<string, DatasetExperimentRegistry::ExperimentSelector>
DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

// Selects `rollout_pct` percent of jobs at random. `hash_func` maps a string
// to a uint64 hash; the hash modulo 100 is compared against `rollout_pct` to
// decide whether a given job participates in the experiment.
template <int64_t rollout_pct>
bool RandomJobSamplePercentage(std::function<uint64_t(const string&)> hash_func,
                               const std::string& experiment_name,
                               const std::string& job_name) {
  return hash_func(strings::StrCat(job_name, experiment_name)) % 100 <
         rollout_pct;
}

bool AllTasks(int64_t task_id) { return true; }

REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations",
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization",
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt,
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", RandomJobSamplePercentage<100>,
                            AllTasks);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism",
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch",
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length",
                            RandomJobSamplePercentage<0>, AllTasks);
REGISTER_DATASET_EXPERIMENT("stage_based_autotune",
                            RandomJobSamplePercentage<0>, AllTasks);
}  // namespace
}  // namespace data
}  // namespace tensorflow