1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ |
16 | #define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ |
17 | |
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
28 | |
29 | namespace tensorflow { |
30 | namespace data { |
31 | |
// Sentinel for the argument of tf.data.Dataset.shard: it indicates that the
// actual value should be supplied by the auto-sharding rewrite.
constexpr int kShardHint{-1};

// Parallelism value used initially, i.e. before Autotune has had a chance to
// optimize it.
constexpr int kAutotuneDefaultParallelism{16};
38 | |
39 | // Creates a resource handle with a unique name for the given resource where |
40 | // the resource is managed by the Resource Manager. |
41 | template <typename T> |
42 | Status CreateWeakHandle(OpKernelContext* ctx, T* resource, |
43 | const string& container_name, ResourceHandle* handle) { |
44 | static std::atomic<int64_t> resource_id_counter(0); |
45 | string unique_name = |
46 | strings::StrCat(container_name, resource_id_counter.fetch_add(1)); |
47 | ResourceMgr* mgr = ctx->resource_manager(); |
48 | TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource)); |
49 | |
50 | *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(), |
51 | TypeIndex::Make<T>()); |
52 | return OkStatus(); |
53 | } |
54 | |
55 | // Creates a ref-counting resource handle for the given resource, where the |
56 | // resource is owned by the handle. |
57 | template <typename T> |
58 | Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) { |
59 | ResourceMgr* mgr = ctx->resource_manager(); |
60 | *handle = |
61 | ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name()); |
62 | TF_RETURN_IF_ERROR( |
63 | mgr->CreateUnowned<T>(handle->container(), handle->name(), resource)); |
64 | return OkStatus(); |
65 | } |
66 | |
67 | // TODO(b/198162355): Merge this class with ResourceOpKernel. |
68 | template <typename T> |
69 | class AnonymousResourceOp : public OpKernel { |
70 | public: |
71 | // Creates an AnonymousResourceOp. |
72 | // ref_counting: Determines if the Op returns a ref-counting ResourceHandle. |
73 | // ResourceHandle. See go/tf-resource-handle-ref-count. |
74 | // return_deleter: Determines if the Op outputs a deleter tensor in addition |
75 | // to the resource handle tensor. |
76 | // If the resource handle is ref-counting, a no-op deleter is returned. |
77 | explicit AnonymousResourceOp(OpKernelConstruction* context, bool ref_counting, |
78 | bool return_deleter) |
79 | : OpKernel(context), |
80 | ref_counting_(ref_counting), |
81 | return_deleter_(return_deleter) {} |
82 | |
83 | void Compute(OpKernelContext* ctx) override { |
84 | FunctionLibraryRuntime* lib; |
85 | std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); |
86 | std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); |
87 | OP_REQUIRES_OK( |
88 | ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true)); |
89 | T* resource; |
90 | OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def), |
91 | std::move(pflr), lib, &resource)); |
92 | |
93 | ResourceHandle handle; |
94 | if (ref_counting_) { |
95 | OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle)); |
96 | } else { |
97 | OP_REQUIRES_OK(ctx, CreateWeakHandle(ctx, resource, name(), &handle)); |
98 | } |
99 | Tensor* handle_t; |
100 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t)); |
101 | handle_t->scalar<ResourceHandle>()() = handle; |
102 | |
103 | if (return_deleter_) { |
104 | Tensor* deleter_t; |
105 | AllocatorAttributes attr; |
106 | attr.set_on_host(true); |
107 | OP_REQUIRES_OK( |
108 | ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr)); |
109 | // TODO(feyu): Consider returning an OptionalVariant. |
110 | if (!ref_counting_) { |
111 | // A deleter output that deletes the resource when destroyed. |
112 | deleter_t->scalar<Variant>()() = |
113 | ResourceDeleter(handle, ctx->resource_manager()); |
114 | } |
115 | } |
116 | } |
117 | |
118 | protected: |
119 | virtual string name() = 0; |
120 | |
121 | virtual Status CreateResource( |
122 | OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def, |
123 | std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, |
124 | FunctionLibraryRuntime* lib, T** resource) = 0; |
125 | |
126 | private: |
127 | const bool ref_counting_; |
128 | const bool return_deleter_; |
129 | }; |
130 | |
131 | // Returns OkStatus() if `expected` and `received` types match, |
132 | // errors::InvalidArgument otherwise. |
133 | Status VerifyTypesMatch(const DataTypeVector& expected, |
134 | const DataTypeVector& received); |
135 | |
136 | Status VerifyTypesMatch(const DataTypeVector& expected, |
137 | const std::vector<Tensor>& received); |
138 | |
139 | // Returns OkStatus() if `expected` and `received` shapes are compatible, |
140 | // errors::InvalidArgument otherwise. |
141 | Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
142 | const std::vector<PartialTensorShape>& received); |
143 | |
144 | Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, |
145 | const std::vector<Tensor>& received); |
146 | |
147 | // Dataset op level determinism policy. |
148 | class DeterminismPolicy { |
149 | public: |
150 | enum class Type : int { |
151 | // The op must produce elements deterministically. |
152 | kDeterministic, |
153 | // The op may relax determinism to improve performance. |
154 | kNondeterministic, |
155 | // The determinism policy is not specified at the op level. In this case we |
156 | // use the experimental_deterministic dataset option to determine the |
157 | // determinism policy. |
158 | kDefault, |
159 | }; |
160 | static constexpr const char* const kDeterministic = "true" ; |
161 | static constexpr const char* const kNondeterministic = "false" ; |
162 | static constexpr const char* const kDefault = "default" ; |
163 | |
164 | DeterminismPolicy() : determinism_(Type::kDefault) {} |
165 | explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {} |
166 | // Creates a DeterminismPolicy with Type kDeterministic or |
167 | // kNondeterministic, depending on the values of `is_deterministic`. |
168 | explicit DeterminismPolicy(bool is_deterministic); |
169 | |
170 | static Status FromString(const std::string& s, DeterminismPolicy* out); |
171 | |
172 | // Returns the string representing the determinism policy. This will be one of |
173 | // the string constants defined above. |
174 | std::string String() const; |
175 | |
176 | /// Convenience methods for checking the DeterminismPolicy::Type. |
177 | bool IsDeterministic() const { return determinism_ == Type::kDeterministic; } |
178 | bool IsNondeterministic() const { |
179 | return determinism_ == Type::kNondeterministic; |
180 | } |
181 | bool IsDefault() const { return determinism_ == Type::kDefault; } |
182 | |
183 | private: |
184 | Type determinism_; |
185 | }; |
186 | |
// Returns the seeds to use: either `seeds` unchanged, or resolved seeds when
// non-deterministic seeds have to be resolved.
//
// By TensorFlow convention, a pair of seeds that are both 0 should be replaced
// with non-deterministically chosen seeds.
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds);
194 | |
195 | // Adds the functions in `to_add` to `base`. If a function with a matching |
196 | // signature already exists in `base`, replaces it with the function from |
197 | // `to_add`. |
198 | Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
199 | const FunctionLibraryDefinition& to_add); |
200 | Status AddToFunctionLibrary(FunctionLibraryDefinition* base, |
201 | const FunctionDefLibrary& to_add); |
202 | |
203 | // Determines whether the given function is stateful. |
204 | Status IsFunctionStateful(const FunctionLibraryDefinition& library, |
205 | const FunctionDef& function_def); |
206 | |
207 | // Determines whether the given node is stateful. |
208 | Status IsNodeStateful(const FunctionLibraryDefinition& library, |
209 | const NodeDef& node); |
210 | |
// Wraps `runner` into a runner that runs functions with limited parallelism,
// the limit being given by `max_parallelism`.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
214 | |
215 | // Op for creating a typed dummy resource. |
216 | // |
217 | // This op is used to provide a resource "placeholder" for ops such as |
218 | // `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input. |
219 | // Originally, the lifetime of the resources passed into these ops was managed |
220 | // externally. After the implementation changed to manage the lifetime of the |
221 | // resources (including creation) by the ops themselves, the resource input is |
222 | // only needed to pass a resource handle through graph rewrites. When they are |
223 | // invoked from user code, the implementation passes in a dummy resource. |
224 | template <typename ResourceType> |
225 | class DummyResourceOp : public OpKernel { |
226 | public: |
227 | explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
228 | |
229 | void Compute(OpKernelContext* ctx) override { |
230 | Tensor* tensor; |
231 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor)); |
232 | tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>( |
233 | ctx, /*container=*/"" , /*name=*/"dummy_resource" ); |
234 | } |
235 | }; |
236 | |
237 | // Given an op prefix and an op to match, returns whether the op to match |
238 | // is a match for any version of the op prefix. For example, |
239 | // MatchesAnyVersion("BatchDataset", "BatchDataset") == true |
240 | // MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true |
241 | // MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true |
242 | // MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false |
243 | bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match); |
244 | |
245 | // Returns the index-th slice of a given tensor. If the index-th slice of |
246 | // the tensor is not aligned, returns a deep copy of the tensor. |
247 | Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index); |
248 | |
249 | // Removes device placements from the ops of all functions in `library`. |
250 | void StripDevicePlacement(FunctionDefLibrary* library); |
251 | |
252 | // Copies partial of the batch output. |
253 | Status CopyPartialBatch(int64_t num_elements, const Tensor& value, |
254 | Tensor* output); |
255 | |
256 | // Reads a batch when restoring the iterator. |
257 | Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, |
258 | int64_t batch_size, const string& iterator_prefix, |
259 | const string& batch_prefix, std::vector<Tensor>* batch); |
260 | |
261 | // Writes a batch when saving the iterator. |
262 | Status WriteBatch(int64_t batch_size, int64_t num_elements, |
263 | const string& iterator_prefix, const string& batch_prefix, |
264 | IteratorStateWriter* writer, std::vector<Tensor>* batch); |
265 | |
266 | // Reads a status when restoring the iterator. |
267 | Status ReadStatus(const string& iterator_prefix, const string& prefix, |
268 | IteratorStateReader* reader, Status* status); |
269 | |
270 | // Writes a status when saving the iterator. |
271 | Status WriteStatus(const string& iterator_prefix, const string& prefix, |
272 | const Status& status, IteratorStateWriter* writer); |
273 | |
274 | // Processes a batch to output. In the case a partial batch is encountered, copy |
275 | // only partial of the batch. |
276 | Status ProcessBatch(int64_t batch_size, int64_t num_elements, |
277 | bool drop_remainder, const Status& status, |
278 | IteratorContext* ctx, std::vector<Tensor>* output, |
279 | bool* end_of_sequence, std::vector<Tensor>* batch); |
280 | |
281 | // Constructs and stores the parameters for the CopyBatch function. |
282 | struct CopyBatchParams { |
283 | Allocator* allocator; |
284 | std::function<void(std::function<void()>)>* runner; |
285 | int64 runner_threadpool_size; |
286 | |
287 | explicit CopyBatchParams(IteratorContext* ctx) { |
288 | allocator = ctx->allocator({}); |
289 | runner = ctx->runner(); |
290 | runner_threadpool_size = ctx->runner_threadpool_size(); |
291 | } |
292 | |
293 | explicit CopyBatchParams(OpKernelContext* ctx) { |
294 | allocator = ctx->get_allocator({}); |
295 | runner = ctx->runner(); |
296 | runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx); |
297 | } |
298 | }; |
299 | |
300 | // Copies the input elements to a batch. |
301 | // |
302 | // The `batch_elements` argument contains the individual elements to copy into a |
303 | // batch. The `parallel_copy` argument indicates whether to parallelize the |
304 | // copy. The `allocation_callback` argument can be used to pass a callback to |
305 | // invoke upon successful allocation of the memory for the batch. The |
306 | // `out_tensors` argument will be used to store the resulting batch (one for |
307 | // each component of the input). |
308 | Status CopyBatch(CopyBatchParams params, |
309 | const std::vector<std::vector<Tensor>>& batch_elements, |
310 | bool parallel_copy, |
311 | std::function<Status()> allocation_callback, |
312 | std::vector<Tensor>* out_tensors); |
313 | |
314 | // Computes the set of experiments to apply based on the job name, task id, |
315 | // rollout percentage of registered experiments, and the |
316 | // TF_DATA_EXPERIMENT_OPT_IN and TF_DATA_EXPERIMENT_OPT_OUT environment |
317 | // variables. |
318 | absl::flat_hash_set<string> GetExperiments(); |
319 | absl::flat_hash_set<string> GetExperiments( |
320 | const std::string& job_name, int64_t task_id, |
321 | std::function<uint64_t(const string&)> hash_func); |
322 | |
323 | // Logs and records the experiments that will be applied. |
324 | void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments); |
325 | |
326 | // Computes the set of enabled, disabled, and default optimizations based on the |
327 | // given options. An optimization must be a graph optimizer name that has been |
328 | // registered with Grappler. |
329 | void GetOptimizations(const Options& options, |
330 | absl::flat_hash_set<tstring>* optimizations_enabled, |
331 | absl::flat_hash_set<tstring>* optimizations_disabled, |
332 | absl::flat_hash_set<tstring>* optimizations_default); |
333 | |
334 | // Creates graph rewrite configs based on the given options. The configs will |
335 | // only be used if their corresponding optimizers registered with Grappler are |
336 | // enabled. |
337 | // A config is a string with the following format: |
338 | // <optimizer name>:<attribute name>:<attribute value> |
339 | absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options); |
340 | |
341 | // Determines whether max intra-op parallelism should be configured. |
342 | bool ShouldConfigureMaxIntraOpParallelism(const Options& options); |
343 | |
344 | // Determines whether private threadpool should be used. |
345 | bool ShouldUsePrivateThreadPool(const Options& options); |
346 | |
347 | // Determines whether autotuning should be used. |
348 | bool ShouldUseAutotuning(const Options& options); |
349 | |
350 | // Determines whether optimizations should be applied. |
351 | bool ShouldApplyOptimizations( |
352 | const Options& options, |
353 | const absl::flat_hash_set<tstring>& optimizations_enabled, |
354 | const absl::flat_hash_set<tstring>& optimizations_default); |
355 | |
356 | // Returns the default CPU budget. |
357 | inline int GetCpuBudget() { |
358 | static bool in_experiment = GetExperiments().contains("tune_cpu_budget" ); |
359 | return (in_experiment ? 1.2 : 1.0) * port::NumSchedulableCPUs(); |
360 | } |
361 | |
362 | // Returns the initial value for parallelism parameter before the first Autotune |
363 | // optimization. |
364 | int64 GetAutotuneDefaultParallelism(IteratorContext* ctx); |
365 | |
366 | // Registry of tf.data experiments. |
367 | class DatasetExperimentRegistry { |
368 | public: |
369 | using JobSelector = std::function<bool( |
370 | std::function<uint64_t(const string&)> hash_func, |
371 | const std::string& experiment_name, const std::string& job_name)>; |
372 | using TaskSelector = std::function<bool(int64_t task_id)>; |
373 | |
374 | struct ExperimentSelector { |
375 | JobSelector job_selector; |
376 | TaskSelector task_selector; |
377 | }; |
378 | |
379 | // Registers the experiment. |
380 | static void Register(const string& experiment, JobSelector job_selector, |
381 | TaskSelector task_selector); |
382 | |
383 | // Returns all registered experiments. |
384 | static absl::flat_hash_map<string, ExperimentSelector> Experiments(); |
385 | }; |
386 | |
387 | // Helper class to register a dataset experiment. |
388 | class DatasetExperimentRegistrar { |
389 | public: |
390 | explicit DatasetExperimentRegistrar( |
391 | const string& experiment, |
392 | DatasetExperimentRegistry::JobSelector job_selector, |
393 | DatasetExperimentRegistry::TaskSelector task_selector) { |
394 | DatasetExperimentRegistry::Register(experiment, job_selector, |
395 | task_selector); |
396 | } |
397 | }; |
398 | |
// Macro that can be used to register a dataset experiment. It expands to a
// `static` ::tensorflow::data::DatasetExperimentRegistrar object whose
// constructor performs the registration; __COUNTER__ is used to give every
// such object a distinct name.
//
// The two helper macros below are implementation details. They carry
// experiment-specific names so that they cannot be confused with, or
// redefine, the helper macros of the dataset op name registration.
#define REGISTER_DATASET_EXPERIMENT(experiment, job_selector, task_selector) \
  REGISTER_DATASET_EXPERIMENT_UNIQ_HELPER(__COUNTER__, experiment,           \
                                          job_selector, task_selector)

// Extra level of indirection so that `ctr` (__COUNTER__) is expanded before it
// is pasted into the object name.
#define REGISTER_DATASET_EXPERIMENT_UNIQ_HELPER(ctr, experiment,      \
                                                job_selector,         \
                                                task_selector)        \
  REGISTER_DATASET_EXPERIMENT_UNIQ(ctr, experiment, job_selector,     \
                                   task_selector)

#define REGISTER_DATASET_EXPERIMENT_UNIQ(ctr, experiment, job_selector, \
                                         task_selector)                 \
  static ::tensorflow::data::DatasetExperimentRegistrar                 \
      registrar__body__##ctr##__object(experiment, job_selector,        \
                                       task_selector)
413 | |
414 | } // namespace data |
415 | } // namespace tensorflow |
416 | |
417 | #endif // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ |
418 | |