1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/data/dataset_utils.h" |
17 | |
18 | #include <algorithm> |
19 | #include <functional> |
20 | #include <memory> |
21 | #include <queue> |
22 | #include <string> |
23 | #include <utility> |
24 | |
25 | #include "absl/container/flat_hash_map.h" |
26 | #include "absl/container/flat_hash_set.h" |
27 | #include "absl/strings/str_join.h" |
28 | #include "tensorflow/core/common_runtime/function.h" |
29 | #include "tensorflow/core/framework/attr_value.pb.h" |
30 | #include "tensorflow/core/framework/dataset.h" |
31 | #include "tensorflow/core/framework/function.h" |
32 | #include "tensorflow/core/framework/node_def_util.h" |
33 | #include "tensorflow/core/framework/op_def_builder.h" |
34 | #include "tensorflow/core/framework/op_def_util.h" |
35 | #include "tensorflow/core/framework/op_kernel.h" |
36 | #include "tensorflow/core/framework/tensor.pb.h" |
37 | #include "tensorflow/core/framework/tensor_util.h" |
38 | #include "tensorflow/core/framework/types.h" |
39 | #include "tensorflow/core/graph/graph_def_builder.h" |
40 | #include "tensorflow/core/lib/core/errors.h" |
41 | #include "tensorflow/core/lib/hash/hash.h" |
42 | #include "tensorflow/core/lib/strings/proto_serialization.h" |
43 | #include "tensorflow/core/platform/blocking_counter.h" |
44 | #include "tensorflow/core/platform/host_info.h" |
45 | #include "tensorflow/core/platform/regexp.h" |
46 | #include "tensorflow/core/util/determinism.h" |
47 | #include "tensorflow/core/util/work_sharder.h" |
48 | |
49 | namespace tensorflow { |
50 | namespace data { |
51 | namespace { |
52 | |
// Key suffixes used when (de)serializing batches and statuses to/from
// iterator checkpoints (see ReadBatch/WriteBatch/ReadStatus/WriteStatus).
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";
57 | |
58 | static mutex* get_dataset_experiment_registry_lock() { |
59 | static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED); |
60 | return &dataset_experiment_registry_lock; |
61 | } |
62 | |
63 | static absl::flat_hash_map<string, |
64 | DatasetExperimentRegistry::ExperimentSelector>* |
65 | get_dataset_experiments() { |
66 | static absl::flat_hash_map< |
67 | string, DatasetExperimentRegistry::ExperimentSelector>* experiments = |
68 | new absl::flat_hash_map<string, |
69 | DatasetExperimentRegistry::ExperimentSelector>; |
70 | return experiments; |
71 | } |
72 | |
// Use "Opt" suffix so that they are not confused with the enums in Options
// proto. These strings are the Grappler rewrite names handed to the tf.data
// meta-optimizer.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";
96 | |
97 | void DefaultOptimizationGraphRewrites( |
98 | const Options& options, absl::flat_hash_set<tstring>* optimization_enabled, |
99 | absl::flat_hash_set<tstring>* optimization_disabled, |
100 | absl::flat_hash_set<tstring>* optimization_default) { |
101 | const auto& optimization_options = options.optimization_options(); |
102 | if (optimization_options.optional_apply_default_optimizations_case() != |
103 | OptimizationOptions::kApplyDefaultOptimizations || |
104 | optimization_options.apply_default_optimizations()) { |
105 | if (optimization_options.optional_map_and_batch_fusion_case() != |
106 | OptimizationOptions::kMapAndBatchFusion) { |
107 | optimization_default->insert(kMapAndBatchFusionOpt); |
108 | } |
109 | if (optimization_options.optional_noop_elimination_case() != |
110 | OptimizationOptions::kNoopElimination) { |
111 | optimization_default->insert(kNoopEliminationOpt); |
112 | } |
113 | if (optimization_options.optional_map_parallelization_case() != |
114 | OptimizationOptions::kMapParallelization) { |
115 | optimization_default->insert(kMapParallelizationOpt); |
116 | } |
117 | if (optimization_options.optional_shuffle_and_repeat_fusion_case() != |
118 | OptimizationOptions::kShuffleAndRepeatFusion) { |
119 | optimization_default->insert(kShuffleAndRepeatFusionOpt); |
120 | } |
121 | if (optimization_options.optional_parallel_batch_case() != |
122 | OptimizationOptions::kParallelBatch) { |
123 | optimization_default->insert(kParallelBatchOpt); |
124 | } |
125 | } |
126 | if (OpDeterminismRequired()) { |
127 | optimization_enabled->insert(kMakeDeterministicOpt); |
128 | } |
129 | if (optimization_options.optional_filter_fusion_case() == |
130 | OptimizationOptions::kFilterFusion) { |
131 | if (optimization_options.filter_fusion()) { |
132 | optimization_enabled->insert(kFilterFusionOpt); |
133 | } else { |
134 | optimization_disabled->insert(kFilterFusionOpt); |
135 | } |
136 | } |
137 | if (optimization_options.optional_map_and_batch_fusion_case() == |
138 | OptimizationOptions::kMapAndBatchFusion) { |
139 | if (optimization_options.map_and_batch_fusion()) { |
140 | optimization_enabled->insert(kMapAndBatchFusionOpt); |
141 | } else { |
142 | optimization_disabled->insert(kMapAndBatchFusionOpt); |
143 | } |
144 | } |
145 | if (optimization_options.optional_map_and_filter_fusion_case() == |
146 | OptimizationOptions::kMapAndFilterFusion) { |
147 | if (optimization_options.map_and_filter_fusion()) { |
148 | optimization_enabled->insert(kMapAndFilterFusionOpt); |
149 | } else { |
150 | optimization_disabled->insert(kMapAndFilterFusionOpt); |
151 | } |
152 | } |
153 | if (optimization_options.optional_map_parallelization_case() == |
154 | OptimizationOptions::kMapParallelization) { |
155 | if (optimization_options.map_parallelization()) { |
156 | optimization_enabled->insert(kMapParallelizationOpt); |
157 | } else { |
158 | optimization_disabled->insert(kMapParallelizationOpt); |
159 | } |
160 | } |
161 | if (optimization_options.optional_filter_parallelization_case() == |
162 | OptimizationOptions::kFilterParallelization) { |
163 | if (optimization_options.filter_parallelization()) { |
164 | optimization_enabled->insert(kFilterParallelizationOpt); |
165 | } else { |
166 | optimization_disabled->insert(kFilterParallelizationOpt); |
167 | } |
168 | } |
169 | if (optimization_options.optional_map_fusion_case() == |
170 | OptimizationOptions::kMapFusion) { |
171 | if (optimization_options.map_fusion()) { |
172 | optimization_enabled->insert(kMapFusionOpt); |
173 | } else { |
174 | optimization_disabled->insert(kMapFusionOpt); |
175 | } |
176 | } |
177 | if (optimization_options.optional_noop_elimination_case() == |
178 | OptimizationOptions::kNoopElimination) { |
179 | if (optimization_options.noop_elimination()) { |
180 | optimization_enabled->insert(kNoopEliminationOpt); |
181 | } else { |
182 | optimization_disabled->insert(kNoopEliminationOpt); |
183 | } |
184 | } |
185 | if (optimization_options.optional_parallel_batch_case() == |
186 | OptimizationOptions::kParallelBatch) { |
187 | if (optimization_options.parallel_batch()) { |
188 | optimization_enabled->insert(kParallelBatchOpt); |
189 | } else { |
190 | optimization_disabled->insert(kParallelBatchOpt); |
191 | } |
192 | } |
193 | if (optimization_options.optional_shuffle_and_repeat_fusion_case() == |
194 | OptimizationOptions::kShuffleAndRepeatFusion) { |
195 | if (optimization_options.shuffle_and_repeat_fusion()) { |
196 | optimization_enabled->insert(kShuffleAndRepeatFusionOpt); |
197 | } else { |
198 | optimization_disabled->insert(kShuffleAndRepeatFusionOpt); |
199 | } |
200 | } |
201 | if (optimization_options.optional_inject_prefetch_case() == |
202 | OptimizationOptions::kInjectPrefetch) { |
203 | if (optimization_options.inject_prefetch()) { |
204 | optimization_enabled->insert(kInjectPrefetchOpt); |
205 | } else { |
206 | optimization_disabled->insert(kInjectPrefetchOpt); |
207 | } |
208 | } |
209 | } |
210 | |
211 | // Returns whether an op has been allowlisted as stateless. Uses a heuristic to |
212 | // allowlist source dataset ops which have been marked stateful due to |
213 | // b/65524810. Also looks up the `op_def->name` in the global |
214 | // `AllowlistedStatefulOpRegistry`. |
215 | bool IsOpAllowlisted(const OpDef* op_def) { |
216 | return (op_def->output_arg_size() == 1 && |
217 | op_def->output_arg(0).type() == DT_VARIANT && |
218 | (absl::EndsWith(op_def->name(), "Dataset" ) || |
219 | absl::EndsWith(op_def->name(), "DatasetV2" ))) || |
220 | AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name()); |
221 | } |
222 | |
223 | } // namespace |
224 | |
225 | std::pair<int64_t, int64_t> MaybeOverrideSeeds( |
226 | std::pair<int64_t, int64_t> seeds) { |
227 | if (seeds.first == 0 && seeds.second == 0) { |
228 | return {random::New64(), random::New64()}; |
229 | } |
230 | return seeds; |
231 | } |
232 | |
233 | Status VerifyTypeMatch(const DataType& expected, const DataType& received, |
234 | int index) { |
235 | if (expected != received) { |
236 | return errors::InvalidArgument("Data type mismatch at component " , index, |
237 | ": expected " , DataTypeString(expected), |
238 | " but got " , DataTypeString(received), "." ); |
239 | } |
240 | return OkStatus(); |
241 | } |
242 | |
243 | Status VerifyTypesMatch(const DataTypeVector& expected, |
244 | const DataTypeVector& received) { |
245 | if (expected.size() != received.size()) { |
246 | return errors::InvalidArgument( |
247 | "Number of components does not match: expected " , expected.size(), |
248 | " types but got " , received.size(), "." ); |
249 | } |
250 | for (size_t i = 0; i < expected.size(); ++i) { |
251 | TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i)); |
252 | } |
253 | return OkStatus(); |
254 | } |
255 | |
256 | Status VerifyTypesMatch(const DataTypeVector& expected, |
257 | const std::vector<Tensor>& received) { |
258 | if (expected.size() != received.size()) { |
259 | return errors::InvalidArgument( |
260 | "Number of components does not match: expected " , expected.size(), |
261 | " types but got " , received.size(), "." ); |
262 | } |
263 | for (size_t i = 0; i < expected.size(); ++i) { |
264 | TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i)); |
265 | } |
266 | return OkStatus(); |
267 | } |
268 | |
269 | Status VerifyShapeCompatible(const PartialTensorShape& expected, |
270 | const PartialTensorShape& received, int index) { |
271 | if (!expected.IsCompatibleWith(received)) { |
272 | return errors::InvalidArgument("Incompatible shapes at component " , index, |
273 | ": expected " , expected.DebugString(), |
274 | " but got " , received.DebugString(), "." ); |
275 | } |
276 | return OkStatus(); |
277 | } |
278 | |
279 | Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
280 | const std::vector<PartialTensorShape>& received) { |
281 | if (expected.size() != received.size()) { |
282 | return errors::InvalidArgument( |
283 | "Number of components does not match: expected " , expected.size(), |
284 | " shapes but got " , received.size(), "." ); |
285 | } |
286 | for (size_t i = 0; i < expected.size(); ++i) { |
287 | TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i)); |
288 | } |
289 | |
290 | return OkStatus(); |
291 | } |
292 | |
293 | Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
294 | const std::vector<Tensor>& received) { |
295 | if (expected.size() != received.size()) { |
296 | return errors::InvalidArgument( |
297 | "Number of components does not match: expected " , expected.size(), |
298 | " shapes but got " , received.size(), "." ); |
299 | } |
300 | for (size_t i = 0; i < expected.size(); ++i) { |
301 | TF_RETURN_IF_ERROR( |
302 | VerifyShapeCompatible(expected[i], received[i].shape(), i)); |
303 | } |
304 | |
305 | return OkStatus(); |
306 | } |
307 | |
308 | Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
309 | const FunctionLibraryDefinition& to_add) { |
310 | for (const auto& fn : to_add.ListFunctionNames()) { |
311 | if (auto found = base->Find(fn)) { |
312 | if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) { |
313 | return errors::InvalidArgument("Cannot add function '" , fn, |
314 | "' because a different function with " |
315 | "the same signature already exists." ); |
316 | } |
317 | TF_RETURN_IF_ERROR(base->RemoveFunction(fn)); |
318 | } |
319 | } |
320 | return base->AddLibrary(to_add); |
321 | } |
322 | |
323 | Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
324 | const FunctionDefLibrary& to_add) { |
325 | for (const auto& fd : to_add.function()) { |
326 | if (auto found = base->Find(fd.signature().name())) { |
327 | if (!OpDefEqual(found->signature(), fd.signature())) { |
328 | return errors::InvalidArgument("Cannot add function '" , |
329 | fd.signature().name(), |
330 | "' because a different function with " |
331 | "the same signature already exists." ); |
332 | } |
333 | TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name())); |
334 | } |
335 | } |
336 | return base->AddLibrary(to_add); |
337 | } |
338 | |
339 | Status IsFunctionStateful(const FunctionLibraryDefinition& library, |
340 | const FunctionDef& function_def) { |
341 | if (!function_def.signature().is_stateful()) { |
342 | return OkStatus(); |
343 | } |
344 | |
345 | for (const NodeDef& node_def : function_def.node_def()) { |
346 | TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def)); |
347 | } |
348 | return OkStatus(); |
349 | } |
350 | |
351 | Status IsNodeStateful(const FunctionLibraryDefinition& library, |
352 | const NodeDef& node) { |
353 | const OpDef* op_def; |
354 | |
355 | // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore |
356 | // `LookUpOpDef` errors here. |
357 | if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() || |
358 | IsOpAllowlisted(op_def) || !op_def->is_stateful() || |
359 | op_def->name() == "Assert" ) { |
360 | return OkStatus(); |
361 | } |
362 | |
363 | if (op_def->name() == "If" ) { |
364 | const FunctionDef* then_func = |
365 | library.Find(node.attr().at("then_branch" ).func().name()); |
366 | const FunctionDef* else_func = |
367 | library.Find(node.attr().at("else_branch" ).func().name()); |
368 | if (then_func != nullptr) { |
369 | TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func)); |
370 | } |
371 | if (else_func != nullptr) { |
372 | TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func)); |
373 | } |
374 | return OkStatus(); |
375 | } |
376 | |
377 | if (op_def->name() == "While" ) { |
378 | const FunctionDef* cond_func = |
379 | library.Find(node.attr().at("cond" ).func().name()); |
380 | const FunctionDef* body_func = |
381 | library.Find(node.attr().at("body" ).func().name()); |
382 | if (cond_func != nullptr) { |
383 | TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func)); |
384 | } |
385 | if (body_func != nullptr) { |
386 | TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func)); |
387 | } |
388 | return OkStatus(); |
389 | } |
390 | |
391 | return errors::FailedPrecondition(op_def->name(), " is stateful." ); |
392 | } |
393 | |
394 | std::function<void(std::function<void()>)> RunnerWithMaxParallelism( |
395 | std::function<void(std::function<void()>)> runner, int max_parallelism) { |
396 | return std::bind( |
397 | [max_parallelism]( |
398 | // Note: `runner` is a const reference to avoid copying it. |
399 | const std::function<void(std::function<void()>)>& runner, |
400 | std::function<void()> fn) { |
401 | std::function<void()> scoped_fn = std::bind( |
402 | [max_parallelism](const std::function<void()>& fn) { |
403 | ScopedPerThreadMaxParallelism scope(max_parallelism); |
404 | fn(); |
405 | }, |
406 | std::move(fn)); |
407 | runner(std::move(scoped_fn)); |
408 | }, |
409 | std::move(runner), std::placeholders::_1); |
410 | } |
411 | |
412 | Status DeterminismPolicy::FromString(const std::string& s, |
413 | DeterminismPolicy* out) { |
414 | DeterminismPolicy::Type type; |
415 | if (s == DeterminismPolicy::kDeterministic) { |
416 | type = DeterminismPolicy::Type::kDeterministic; |
417 | } else if (s == DeterminismPolicy::kNondeterministic) { |
418 | type = DeterminismPolicy::Type::kNondeterministic; |
419 | } else if (s == DeterminismPolicy::kDefault) { |
420 | type = DeterminismPolicy::Type::kDefault; |
421 | } else { |
422 | return errors::InvalidArgument("Unrecognized determinism policy: " , s); |
423 | } |
424 | *out = DeterminismPolicy(type); |
425 | return OkStatus(); |
426 | } |
427 | |
428 | DeterminismPolicy::DeterminismPolicy(bool is_deterministic) { |
429 | if (is_deterministic) { |
430 | determinism_ = DeterminismPolicy::Type::kDeterministic; |
431 | } else { |
432 | determinism_ = DeterminismPolicy::Type::kNondeterministic; |
433 | } |
434 | } |
435 | |
436 | std::string DeterminismPolicy::String() const { |
437 | switch (determinism_) { |
438 | case DeterminismPolicy::Type::kDeterministic: |
439 | return DeterminismPolicy::kDeterministic; |
440 | case DeterminismPolicy::Type::kNondeterministic: |
441 | return DeterminismPolicy::kNondeterministic; |
442 | case DeterminismPolicy::Type::kDefault: |
443 | return DeterminismPolicy::kDefault; |
444 | default: |
445 | LOG(ERROR) << "Unrecognized determinism value" ; |
446 | return "Unrecognized" ; |
447 | } |
448 | } |
449 | |
450 | bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) { |
451 | if (!absl::StartsWith(op_to_match, op_prefix)) { |
452 | return false; |
453 | } |
454 | if (op_to_match.length() == op_prefix.length()) { |
455 | return true; |
456 | } |
457 | size_t index = op_to_match.length() - 1; |
458 | while (isdigit(op_to_match[index])) { |
459 | index--; |
460 | } |
461 | return (op_to_match[index] == 'V') && (op_prefix.length() == index); |
462 | } |
463 | |
464 | absl::flat_hash_set<string> GetExperiments() { |
465 | return GetExperiments(tsl::port::JobName(), tsl::port::TaskId(), |
466 | [](const tstring& str) { return Hash64(str); }); |
467 | } |
468 | |
469 | absl::flat_hash_set<string> GetExperiments( |
470 | const string& job_name, int64_t task_id, |
471 | std::function<uint64_t(const string&)> hash_func) { |
472 | absl::flat_hash_set<string> experiments; |
473 | if (job_name.empty() || task_id < 0) { |
474 | return experiments; |
475 | } |
476 | |
477 | // Parse the opt-in and opt-out settings. |
478 | const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN" ); |
479 | const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT" ); |
480 | string opt_ins_raw; |
481 | if (opt_ins_raw_cs != nullptr) { |
482 | opt_ins_raw = string(opt_ins_raw_cs); |
483 | } |
484 | string opt_outs_raw; |
485 | if (opt_outs_raw_cs != nullptr) { |
486 | opt_outs_raw = string(opt_outs_raw_cs); |
487 | } |
488 | |
489 | // Identify opted out experiments. |
490 | auto live_experiments = DatasetExperimentRegistry::Experiments(); |
491 | absl::flat_hash_set<string> opt_outs; |
492 | if (opt_outs_raw == "all" ) { |
493 | for (const auto& pair : live_experiments) { |
494 | opt_outs.insert(pair.first); |
495 | } |
496 | } else { |
497 | for (const auto& experiment : |
498 | str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) { |
499 | opt_outs.insert(experiment); |
500 | } |
501 | } |
502 | |
503 | // Include opted in experiments unless they are opted out. |
504 | if (opt_ins_raw == "all" ) { |
505 | for (const auto& pair : live_experiments) { |
506 | auto experiment = pair.first; |
507 | if (!opt_outs.contains(experiment)) { |
508 | experiments.insert(experiment); |
509 | } |
510 | } |
511 | } else { |
512 | for (const auto& experiment : |
513 | str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) { |
514 | if (!opt_outs.contains(experiment)) { |
515 | experiments.insert(experiment); |
516 | } |
517 | } |
518 | } |
519 | |
520 | if (opt_outs_raw == "all_except_opt_in" ) { |
521 | return experiments; |
522 | } |
523 | // Stochastically include live experiments unless they are opted out. |
524 | for (const auto& [experiment_name, experiment_selector] : live_experiments) { |
525 | if (experiment_selector.job_selector(hash_func, experiment_name, |
526 | job_name) && |
527 | experiment_selector.task_selector(task_id) && |
528 | !opt_outs.contains(experiment_name)) { |
529 | experiments.insert(experiment_name); |
530 | } |
531 | } |
532 | |
533 | return experiments; |
534 | } |
535 | |
536 | void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) { |
537 | if (!experiments.empty()) { |
538 | constexpr float TEN_MINUTES = 60.0 * 10.0; |
539 | LOG_EVERY_N_SEC(INFO, TEN_MINUTES) |
540 | << "The input pipeline is subject to the following tf.data experiments:" |
541 | << " " << absl::StrJoin(experiments, ", " ) << ". " |
542 | << "See `go/tf-data-experiments` for more details." ; |
543 | } |
544 | for (auto& experiment : experiments) { |
545 | metrics::RecordTFDataExperiment(experiment); |
546 | } |
547 | } |
548 | |
549 | void GetOptimizations(const Options& options, |
550 | absl::flat_hash_set<tstring>* optimizations_enabled, |
551 | absl::flat_hash_set<tstring>* optimizations_disabled, |
552 | absl::flat_hash_set<tstring>* optimizations_default) { |
553 | DefaultOptimizationGraphRewrites(options, optimizations_enabled, |
554 | optimizations_disabled, |
555 | optimizations_default); |
556 | if (!OpDeterminismRequired() && |
557 | options.optional_deterministic_case() == Options::kDeterministic && |
558 | !options.deterministic()) { |
559 | optimizations_enabled->insert(kMakeSloppyOpt); |
560 | } |
561 | if (options.optional_slack_case() == Options::kSlack) { |
562 | if (options.slack()) { |
563 | optimizations_enabled->insert(kSlackOpt); |
564 | } else { |
565 | optimizations_disabled->insert(kSlackOpt); |
566 | } |
567 | } |
568 | } |
569 | |
570 | Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) { |
571 | Tensor slice = tensor.SubSlice(index); |
572 | if (slice.IsAligned()) { |
573 | return slice; |
574 | } else { |
575 | return tensorflow::tensor::DeepCopy(slice); |
576 | } |
577 | } |
578 | |
579 | void StripDevicePlacement(FunctionDefLibrary* library) { |
580 | for (auto& function : (*library->mutable_function())) { |
581 | for (auto& node : (*function.mutable_node_def())) { |
582 | if (!node.device().empty()) { |
583 | *node.mutable_device() = "" ; |
584 | } |
585 | } |
586 | } |
587 | } |
588 | |
// Copies the first `num_elements` slices along dimension 0 of `value` into
// `output`. Used when restoring a partially-filled batch from a checkpoint:
// only the valid prefix of the batch was serialized, so the remainder of
// `output` is left untouched.
// Returns InvalidArgument for dtypes outside TF_CALL_DATASET_TYPES.
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
// Per-dtype copy: view both tensors as 2-D (batch x flattened element) and
// copy row-by-row via chip<0>.
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}
609 | |
610 | Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, |
611 | int64_t batch_size, const string& iterator_prefix, |
612 | const string& batch_prefix, std::vector<Tensor>* batch) { |
613 | int64_t output_size; |
614 | TF_RETURN_IF_ERROR(reader->ReadScalar( |
615 | FullName(iterator_prefix, |
616 | strings::StrCat(batch_prefix, "_" , kOutputSize)), |
617 | &output_size)); |
618 | batch->reserve(output_size); |
619 | for (int i = 0; i < output_size; i++) { |
620 | Tensor t; |
621 | TF_RETURN_IF_ERROR( |
622 | reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix), |
623 | strings::StrCat(kOutput, "_" , i), &t)); |
624 | // If the batch was not full, we may have stored only the relevant slice. |
625 | // Since tensors in `BatchResult.output` are expected to have the leading |
626 | // dimension of size batch_size, we build a larger tensor and copy the slice |
627 | // read from the checkpoint into it. |
628 | if (t.dim_size(0) < batch_size) { |
629 | TensorShape component_shape(t.shape()); |
630 | component_shape.set_dim(0, batch_size); |
631 | AllocatorAttributes attr; |
632 | attr.set_gpu_compatible(true); |
633 | Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape); |
634 | TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t)); |
635 | batch->emplace_back(std::move(new_t)); |
636 | } else { |
637 | batch->emplace_back(std::move(t)); |
638 | } |
639 | } |
640 | return OkStatus(); |
641 | } |
642 | |
643 | Status WriteBatch(int64_t batch_size, int64_t num_elements, |
644 | const string& iterator_prefix, const string& batch_prefix, |
645 | IteratorStateWriter* writer, std::vector<Tensor>* batch) { |
646 | TF_RETURN_IF_ERROR(writer->WriteScalar( |
647 | FullName(iterator_prefix, |
648 | strings::StrCat(batch_prefix, "_" , kOutputSize)), |
649 | batch->size())); |
650 | for (int i = 0; i < batch->size(); i++) { |
651 | // If the batch is not full, we only store the first `num_elements` values. |
652 | // The rest of the batch tensor is *uninitialized* and accessing that will |
653 | // raise msan errors. |
654 | if (num_elements < batch_size) { |
655 | TF_RETURN_IF_ERROR( |
656 | writer->WriteTensor(FullName(iterator_prefix, batch_prefix), |
657 | strings::StrCat(kOutput, "_" , i), |
658 | (*batch)[i].Slice(0, num_elements))); |
659 | } else { |
660 | TF_RETURN_IF_ERROR( |
661 | writer->WriteTensor(FullName(iterator_prefix, batch_prefix), |
662 | strings::StrCat(kOutput, "_" , i), (*batch)[i])); |
663 | } |
664 | } |
665 | return OkStatus(); |
666 | } |
667 | |
668 | Status ReadStatus(const string& iterator_prefix, const string& prefix, |
669 | IteratorStateReader* reader, Status* status) { |
670 | int64_t code_int; |
671 | TF_RETURN_IF_ERROR(reader->ReadScalar( |
672 | FullName(iterator_prefix, strings::StrCat(prefix, "_" , kCode)), |
673 | &code_int)); |
674 | error::Code code = static_cast<error::Code>(code_int); |
675 | |
676 | if (code != error::Code::OK) { |
677 | tstring error_message; |
678 | TF_RETURN_IF_ERROR(reader->ReadScalar( |
679 | FullName(iterator_prefix, strings::StrCat(prefix, "_" , kMessage)), |
680 | &error_message)); |
681 | *status = Status(code, error_message); |
682 | } else { |
683 | *status = OkStatus(); |
684 | } |
685 | return OkStatus(); |
686 | } |
687 | |
688 | Status WriteStatus(const string& iterator_prefix, const string& prefix, |
689 | const Status& status, IteratorStateWriter* writer) { |
690 | TF_RETURN_IF_ERROR(writer->WriteScalar( |
691 | FullName(iterator_prefix, strings::StrCat(prefix, "_" , kCode)), |
692 | static_cast<int64_t>(status.code()))); |
693 | if (!status.ok()) { |
694 | TF_RETURN_IF_ERROR(writer->WriteScalar( |
695 | FullName(iterator_prefix, strings::StrCat(prefix, "_" , kMessage)), |
696 | status.error_message())); |
697 | } |
698 | return OkStatus(); |
699 | } |
700 | |
// Finalizes a batch of `num_elements` elements into `*output`, setting
// `*end_of_sequence` according to `status` and `drop_remainder`.
//
// Contract (as implemented below):
//  - Empty batch + OK/OutOfRange status => end of sequence, OK.
//  - Any other non-OK, non-OutOfRange status is propagated as an error.
//  - Partial batch: dropped (end of sequence) if `drop_remainder`, otherwise
//    trimmed copies of the first `num_elements` rows are produced.
//  - Full batch: moved into `*output` without copying; `*batch` is consumed.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    // OutOfRange here is the benign "iterator exhausted" signal, not an error.
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    // Partial batch kept: allocate right-sized tensors and copy the valid
    // prefix of each component into them.
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    // Full batch: hand the tensors over without copying.
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}
743 | |
// Batches `batch_elements` (a vector of per-element tensor tuples) into
// `out_tensors`, one stacked tensor per tuple component. Allocation happens
// first for all components; `allocation_callback`, if set, runs after
// allocation and before any copying. Copies run either sequentially or, when
// `parallel_copy` is set and a component is large enough (> 32 KiB), sharded
// across `params.runner` threads.
//
// Requires: `batch_elements` is non-empty and every element has the same
// number of components with matching shapes (shape mismatch is reported as
// InvalidArgument during the copy phase).
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  // Phase 1: allocate one batched output tensor per component, shaped
  // [num_batch_elements] + element_shape.
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  // Phase 2: copy each element's slice into the batched tensors.
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                           &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    // Parallel path: only worthwhile for components larger than 32 KiB.
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      // The counter is decremented once per batch element (not per thread),
      // so Wait() returns only after every element has been copied.
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              // Collect per-element statuses under the mutex; the first
              // failure wins via Status::Update.
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}
828 | |
829 | absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) { |
830 | absl::flat_hash_set<tstring> configs; |
831 | const auto& autotune_options = options.autotune_options(); |
832 | std::vector<tstring> autotune_only_optimizations = { |
833 | kAutotuneBufferSizesOpt, |
834 | kBatchParallelizationOpt, |
835 | kDisablePrefetchLegacyAutotuneOpt, |
836 | kEnableGradientDescentOpt, |
837 | kFilterParallelizationOpt, |
838 | kMapParallelizationOpt, |
839 | kInjectPrefetchOpt}; |
840 | |
841 | if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled && |
842 | !autotune_options.enabled()) { |
843 | for (const auto& optimization : autotune_only_optimizations) { |
844 | configs.insert( |
845 | absl::StrCat(optimization.data(), ":" , kAutotuneOpt, ":false" )); |
846 | } |
847 | } else { |
848 | for (const auto& optimization : autotune_only_optimizations) { |
849 | configs.insert( |
850 | absl::StrCat(optimization.data(), ":" , kAutotuneOpt, ":true" )); |
851 | } |
852 | } |
853 | if (options.slack()) { |
854 | int num_devices = 1; |
855 | if (options.distribute_options().optional_num_devices_case() == |
856 | DistributeOptions::kNumDevices) { |
857 | num_devices = options.distribute_options().num_devices(); |
858 | } |
859 | configs.insert( |
860 | absl::StrCat(kSlackOpt, ":" , kSlackPeriodOpt, ":" , num_devices)); |
861 | } |
862 | return configs; |
863 | } |
864 | |
865 | bool ShouldConfigureMaxIntraOpParallelism(const Options& options) { |
866 | return options.threading_options().optional_max_intra_op_parallelism_case() == |
867 | ThreadingOptions::kMaxIntraOpParallelism; |
868 | } |
869 | |
870 | bool ShouldUsePrivateThreadPool(const Options& options) { |
871 | return options.threading_options().optional_private_threadpool_size_case() == |
872 | ThreadingOptions::kPrivateThreadpoolSize; |
873 | } |
874 | |
875 | bool ShouldUseAutotuning(const Options& options) { |
876 | return options.autotune_options().optional_enabled_case() != |
877 | AutotuneOptions::kEnabled || |
878 | options.autotune_options().enabled(); |
879 | } |
880 | |
881 | bool ShouldApplyOptimizations( |
882 | const Options& options, |
883 | const absl::flat_hash_set<tstring>& optimizations_enabled, |
884 | const absl::flat_hash_set<tstring>& optimizations_default) { |
885 | return (options.optimization_options() |
886 | .optional_apply_default_optimizations_case() != |
887 | OptimizationOptions::kApplyDefaultOptimizations || |
888 | options.optimization_options().apply_default_optimizations() || |
889 | !optimizations_enabled.empty() || !optimizations_default.empty()); |
890 | } |
891 | |
892 | int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) { |
893 | return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size()); |
894 | } |
895 | |
896 | // static |
897 | void DatasetExperimentRegistry::Register(const string& experiment, |
898 | JobSelector job_selector, |
899 | TaskSelector task_selector) { |
900 | mutex_lock l(*get_dataset_experiment_registry_lock()); |
901 | get_dataset_experiments()->insert( |
902 | std::make_pair(experiment, DatasetExperimentRegistry::ExperimentSelector{ |
903 | job_selector, task_selector})); |
904 | } |
905 | |
906 | // static |
907 | absl::flat_hash_map<string, DatasetExperimentRegistry::ExperimentSelector> |
908 | DatasetExperimentRegistry::Experiments() { |
909 | mutex_lock l(*get_dataset_experiment_registry_lock()); |
910 | return *get_dataset_experiments(); |
911 | } |
912 | |
913 | namespace { |
914 | |
915 | // Select `rollout_pct` percent of jobs at random. `hash_func` takes a string |
916 | // and returns a uint64 between 0 and 1. |
917 | template <int64_t rollout_pct> |
918 | bool RandomJobSamplePercentage(std::function<uint64_t(const string&)> hash_func, |
919 | const std::string& experiment_name, |
920 | const std::string& job_name) { |
921 | return hash_func(strings::StrCat(job_name, experiment_name)) % 100 < |
922 | rollout_pct; |
923 | } |
924 | bool AllTasks(int64_t task_id) { return true; } |
925 | |
926 | REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations" , |
927 | RandomJobSamplePercentage<0>, AllTasks); |
928 | REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization" , |
929 | RandomJobSamplePercentage<0>, AllTasks); |
930 | REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt, |
931 | RandomJobSamplePercentage<0>, AllTasks); |
932 | REGISTER_DATASET_EXPERIMENT("inject_prefetch" , RandomJobSamplePercentage<100>, |
933 | AllTasks); |
934 | REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism" , |
935 | RandomJobSamplePercentage<0>, AllTasks); |
936 | REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch" , |
937 | RandomJobSamplePercentage<0>, AllTasks); |
938 | REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length" , |
939 | RandomJobSamplePercentage<0>, AllTasks); |
940 | REGISTER_DATASET_EXPERIMENT("stage_based_autotune" , |
941 | RandomJobSamplePercentage<0>, AllTasks); |
942 | } // namespace |
943 | } // namespace data |
944 | } // namespace tensorflow |
945 | |