/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

// Constant used to indicate that the argument of tf.data.Dataset.shard
// should be supplied by the auto-sharding rewrite.
constexpr int kShardHint = -1;

// The initial parallelism value before Autotune has a chance to optimize.
constexpr int kAutotuneDefaultParallelism = 16;

// Creates a resource handle with a unique name for the given resource, where
// the resource is managed by the Resource Manager.
template <typename T>
Status CreateWeakHandle(OpKernelContext* ctx, T* resource,
                        const string& container_name, ResourceHandle* handle) {
  static std::atomic<int64_t> resource_id_counter(0);
  string unique_name =
      strings::StrCat(container_name, resource_id_counter.fetch_add(1));
  ResourceMgr* mgr = ctx->resource_manager();
  TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));

  *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
                               TypeIndex::Make<T>());
  return OkStatus();
}

// Creates a ref-counting resource handle for the given resource, where the
// resource is owned by the handle.
template <typename T>
Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) {
  ResourceMgr* mgr = ctx->resource_manager();
  *handle =
      ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name());
  TF_RETURN_IF_ERROR(
      mgr->CreateUnowned<T>(handle->container(), handle->name(), resource));
  return OkStatus();
}
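
// A minimal usage sketch for the two helpers above, assuming a hypothetical
// ResourceBase subclass `MyResource` (not part of this header). Exactly one of
// the two calls would be used for a given resource:
//
//   MyResource* resource = new MyResource();
//   ResourceHandle handle;
//   // Weak handle: the ResourceMgr owns the resource.
//   TF_RETURN_IF_ERROR(CreateWeakHandle(
//       ctx, resource, /*container_name=*/"MyResource", &handle));
//   // Or, ref-counting handle: the handle itself owns the resource.
//   TF_RETURN_IF_ERROR(CreateHandle(ctx, resource, &handle));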

// TODO(b/198162355): Merge this class with ResourceOpKernel.
template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  // Creates an AnonymousResourceOp.
  // ref_counting: Determines if the Op returns a ref-counting ResourceHandle.
  // See go/tf-resource-handle-ref-count.
  // return_deleter: Determines if the Op outputs a deleter tensor in addition
  // to the resource handle tensor.
  // If the resource handle is ref-counting, a no-op deleter is returned.
  explicit AnonymousResourceOp(OpKernelConstruction* context, bool ref_counting,
                               bool return_deleter)
      : OpKernel(context),
        ref_counting_(ref_counting),
        return_deleter_(return_deleter) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    ResourceHandle handle;
    if (ref_counting_) {
      OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle));
    } else {
      OP_REQUIRES_OK(ctx, CreateWeakHandle(ctx, resource, name(), &handle));
    }
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->scalar<ResourceHandle>()() = handle;

    if (return_deleter_) {
      Tensor* deleter_t;
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
      // TODO(feyu): Consider returning an OptionalVariant.
      if (!ref_counting_) {
        // A deleter output that deletes the resource when destroyed.
        deleter_t->scalar<Variant>()() =
            ResourceDeleter(handle, ctx->resource_manager());
      }
    }
  }

 protected:
  virtual string name() = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

 private:
  const bool ref_counting_;
  const bool return_deleter_;
};
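
// A minimal sketch of a concrete subclass, assuming a hypothetical
// `MyResource` type; real subclasses include the anonymous iterator and
// seed generator ops:
//
//   class AnonymousMyResourceOp : public AnonymousResourceOp<MyResource> {
//    public:
//     explicit AnonymousMyResourceOp(OpKernelConstruction* ctx)
//         : AnonymousResourceOp<MyResource>(ctx, /*ref_counting=*/true,
//                                           /*return_deleter=*/false) {}
//
//    protected:
//     string name() override { return "AnonymousMyResource"; }
//
//     Status CreateResource(
//         OpKernelContext* ctx,
//         std::unique_ptr<FunctionLibraryDefinition> flib_def,
//         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
//         FunctionLibraryRuntime* lib, MyResource** resource) override {
//       *resource = new MyResource(std::move(flib_def), std::move(pflr), lib);
//       return OkStatus();
//     }
//   };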

// Returns OkStatus() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns OkStatus() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);
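
// A minimal sketch of validating restored tensors inside an iterator
// implementation, assuming the usual `dataset()` accessor is in scope:
//
//   std::vector<Tensor> batch;
//   // ... fill `batch`, e.g. from a checkpoint ...
//   TF_RETURN_IF_ERROR(VerifyTypesMatch(dataset()->output_dtypes(), batch));
//   TF_RETURN_IF_ERROR(
//       VerifyShapesCompatible(dataset()->output_shapes(), batch));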

// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the value of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one
  // of the string constants defined above.
  std::string String() const;

  // Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
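
// A minimal usage sketch, assuming the policy string comes from an op
// attribute named "deterministic":
//
//   std::string deterministic;
//   OP_REQUIRES_OK(ctx, ctx->GetAttr("deterministic", &deterministic));
//   DeterminismPolicy policy;
//   OP_REQUIRES_OK(ctx,
//                  DeterminismPolicy::FromString(deterministic, &policy));
//   if (policy.IsNondeterministic()) {
//     // ... allow reordering of elements for performance ...
//   }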

// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds);
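
// For example (sketch), a shuffle kernel can normalize user-provided seeds:
//
//   int64_t seed, seed2;
//   // ... read `seed` and `seed2` from the op's inputs ...
//   std::tie(seed, seed2) = MaybeOverrideSeeds({seed, seed2});
//   // If both inputs were 0, `seed` and `seed2` are now chosen
//   // non-deterministically; otherwise they are unchanged.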

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Determines whether the given function is stateful.
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def);

// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
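
// A minimal sketch of wrapping an existing runner, assuming an
// IteratorContext* `ctx` is in scope:
//
//   auto runner = RunnerWithMaxParallelism(*ctx->runner(),
//                                          /*max_parallelism=*/4);
//   runner([]() { /* work subject to the parallelism cap */ });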

// Op for creating a typed dummy resource.
//
// This op is used to provide a resource "placeholder" for ops such as
// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
// Originally, the lifetime of the resources passed into these ops was managed
// externally. After the implementation changed to manage the lifetime of the
// resources (including creation) by the ops themselves, the resource input is
// only needed to pass a resource handle through graph rewrites. When they are
// invoked from user code, the implementation passes in a dummy resource.
template <typename ResourceType>
class DummyResourceOp : public OpKernel {
 public:
  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
        ctx, /*container=*/"", /*name=*/"dummy_resource");
  }
};
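
// A sketch of how such a kernel might be registered for a resource type
// (hypothetical op name; the real registrations live in the corresponding
// dataset op files):
//
//   REGISTER_KERNEL_BUILDER(Name("DummySeedGenerator").Device(DEVICE_CPU),
//                           DummyResourceOp<SeedGenerator>);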

// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Returns the `index`-th slice of a given tensor. If the `index`-th slice of
// the tensor is not aligned, returns a deep copy of the slice.
Tensor MaybeCopySubSlice(const Tensor& tensor, int64_t index);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);

// Copies the first `num_elements` slices of `value` to `output`, for restoring
// a partial batch.
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch when restoring the iterator.
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator.
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);
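
// A minimal save/restore sketch for a batching iterator, assuming hypothetical
// members `batch_` (std::vector<Tensor>) and `status_` (Status), plus the
// usual SaveInternal / RestoreInternal signatures:
//
//   // In SaveInternal:
//   TF_RETURN_IF_ERROR(WriteStatus(prefix(), "status", status_, writer));
//   TF_RETURN_IF_ERROR(WriteBatch(batch_size, batch_.size(), prefix(),
//                                 "batch", writer, &batch_));
//   // In RestoreInternal:
//   TF_RETURN_IF_ERROR(ReadStatus(prefix(), "status", reader, &status_));
//   TF_RETURN_IF_ERROR(ReadBatch(ctx, reader, batch_size, prefix(), "batch",
//                                &batch_));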

// Processes a batch for output. If a partial batch is encountered, copies only
// the valid portion of the batch to `output`.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch);

// Constructs and stores the parameters for the CopyBatch function.
struct CopyBatchParams {
  Allocator* allocator;
  std::function<void(std::function<void()>)>* runner;
  int64_t runner_threadpool_size;

  explicit CopyBatchParams(IteratorContext* ctx) {
    allocator = ctx->allocator({});
    runner = ctx->runner();
    runner_threadpool_size = ctx->runner_threadpool_size();
  }

  explicit CopyBatchParams(OpKernelContext* ctx) {
    allocator = ctx->get_allocator({});
    runner = ctx->runner();
    runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
  }
};

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into
// a batch. The `parallel_copy` argument indicates whether to parallelize the
// copy. The `allocation_callback` argument can be used to pass a callback to
// invoke upon successful allocation of the memory for the batch. The
// `out_tensors` argument will be used to store the resulting batch (one tensor
// for each component of the input elements).
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors);
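
// A minimal usage sketch from inside an iterator's GetNextInternal, assuming
// `elements` already holds the fetched input elements:
//
//   std::vector<std::vector<Tensor>> elements;
//   // ... fill `elements` from the input iterator ...
//   std::vector<Tensor> out_tensors;
//   TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), elements,
//                                /*parallel_copy=*/true,
//                                /*allocation_callback=*/nullptr,
//                                &out_tensors));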

// Computes the set of experiments to apply based on the job name, task id,
// rollout percentage of registered experiments, and the
// TF_DATA_EXPERIMENT_OPT_IN and TF_DATA_EXPERIMENT_OPT_OUT environment
// variables.
absl::flat_hash_set<string> GetExperiments();
absl::flat_hash_set<string> GetExperiments(
    const std::string& job_name, int64_t task_id,
    std::function<uint64_t(const string&)> hash_func);

// Logs and records the experiments that will be applied.
void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments);
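
// A minimal sketch of the typical call sequence at dataset initialization:
//
//   absl::flat_hash_set<string> experiments = GetExperiments();
//   LogAndRecordExperiments(experiments);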

// Computes the set of enabled, disabled, and default optimizations based on
// the given options. An optimization must be a graph optimizer name that has
// been registered with Grappler.
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default);

// Creates graph rewrite configs based on the given options. The configs will
// only be used if their corresponding optimizers registered with Grappler are
// enabled.
// A config is a string with the following format:
// <optimizer name>:<attribute name>:<attribute value>
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options);

// Determines whether max intra-op parallelism should be configured.
bool ShouldConfigureMaxIntraOpParallelism(const Options& options);

// Determines whether a private threadpool should be used.
bool ShouldUsePrivateThreadPool(const Options& options);

// Determines whether autotuning should be used.
bool ShouldUseAutotuning(const Options& options);

// Determines whether optimizations should be applied.
bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default);

// Returns the default CPU budget. If the "tune_cpu_budget" experiment is
// enabled, the budget is 1.2x the number of schedulable CPUs (e.g. 9 for 8
// CPUs); otherwise it is exactly the number of schedulable CPUs.
inline int GetCpuBudget() {
  static bool in_experiment = GetExperiments().contains("tune_cpu_budget");
  return (in_experiment ? 1.2 : 1.0) * port::NumSchedulableCPUs();
}

// Returns the initial value for the parallelism parameter before the first
// Autotune optimization.
int64_t GetAutotuneDefaultParallelism(IteratorContext* ctx);

// Registry of tf.data experiments.
class DatasetExperimentRegistry {
 public:
  using JobSelector = std::function<bool(
      std::function<uint64_t(const string&)> hash_func,
      const std::string& experiment_name, const std::string& job_name)>;
  using TaskSelector = std::function<bool(int64_t task_id)>;

  struct ExperimentSelector {
    JobSelector job_selector;
    TaskSelector task_selector;
  };

  // Registers the experiment.
  static void Register(const string& experiment, JobSelector job_selector,
                       TaskSelector task_selector);

  // Returns all registered experiments.
  static absl::flat_hash_map<string, ExperimentSelector> Experiments();
};

// Helper class to register a dataset experiment.
class DatasetExperimentRegistrar {
 public:
  explicit DatasetExperimentRegistrar(
      const string& experiment,
      DatasetExperimentRegistry::JobSelector job_selector,
      DatasetExperimentRegistry::TaskSelector task_selector) {
    DatasetExperimentRegistry::Register(experiment, job_selector,
                                        task_selector);
  }
};

// Macro that can be used to register a dataset experiment.
#define REGISTER_DATASET_EXPERIMENT(experiment, job_selector, task_selector)  \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, job_selector, \
                                       task_selector)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, job_selector, \
                                             task_selector)                 \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, job_selector, task_selector)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, job_selector, \
                                      task_selector)                 \
  static ::tensorflow::data::DatasetExperimentRegistrar              \
      registrar__body__##ctr##__object(experiment, job_selector,     \
                                       task_selector)
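
// A minimal sketch of registering an experiment with hypothetical selectors
// (real registrations live in the corresponding .cc file):
//
//   bool AllJobs(std::function<uint64_t(const string&)> hash_func,
//                const std::string& experiment_name,
//                const std::string& job_name) {
//     return true;
//   }
//   bool AllTasks(int64_t task_id) { return true; }
//
//   REGISTER_DATASET_EXPERIMENT("my_experiment", AllJobs, AllTasks);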

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_