1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/dataset.h" |
16 | |
17 | #include <unordered_map> |
18 | |
19 | #include "tensorflow/core/framework/device_base.h" |
20 | #include "tensorflow/core/framework/function.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/resource_mgr.h" |
23 | #include "tensorflow/core/framework/variant_encode_decode.h" |
24 | #include "tensorflow/core/framework/variant_op_registry.h" |
25 | #include "tensorflow/core/framework/versions.pb.h" |
26 | #include "tensorflow/core/graph/graph_def_builder.h" |
27 | #include "tensorflow/core/graph/node_builder.h" |
28 | #include "tensorflow/core/platform/errors.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/core/platform/mutex.h" |
31 | #include "tensorflow/core/platform/refcount.h" |
32 | #include "tensorflow/core/platform/resource.h" |
33 | #include "tensorflow/core/platform/status.h" |
34 | #include "tensorflow/core/platform/strcat.h" |
35 | #include "tensorflow/core/profiler/lib/traceme.h" |
36 | #include "tensorflow/core/public/version.h" |
37 | |
38 | // On Windows, disable some macros that would break compile |
39 | #if defined(PLATFORM_WINDOWS) |
40 | #undef GetMessage |
41 | #endif |
42 | |
43 | namespace tensorflow { |
44 | namespace data { |
45 | namespace { |
46 | |
47 | static mutex* get_dataset_op_registry_lock() { |
48 | static mutex dataset_op_registry_lock(LINKER_INITIALIZED); |
49 | return &dataset_op_registry_lock; |
50 | } |
51 | |
52 | static std::unordered_set<string>* get_dataset_op_registry() { |
53 | static std::unordered_set<string>* names = new std::unordered_set<string>; |
54 | return names; |
55 | } |
56 | |
57 | std::string UniqueNodeName(const std::string& base) { |
58 | static std::atomic<int64_t> counter(0); |
59 | return strings::StrCat(base, "/" , counter.fetch_add(1)); |
60 | } |
61 | |
62 | // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. |
63 | // Objects of the wrapper class own a reference on an instance of `DatasetBase`, |
64 | // and the wrapper's copy constructor and destructor take care of managing the |
65 | // reference count. |
66 | // |
67 | // NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT |
68 | // specification. In particular, we cannot currently serialize an arbitrary |
69 | // `DatasetBase` object, so the `Encode()` and `Decode()` methods are not |
70 | // implemented. |
class DatasetVariantWrapper {
 public:
  // Default-constructed wrapper owns nothing.
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this` (no extra Ref is taken).
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  // Copying shares the underlying dataset and takes one more reference.
  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  // Move assignment swaps payloads; the reference previously held by `*this`
  // is released when `other` is destroyed.
  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  // Releases the owned reference, if any.
  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  // Returns a borrowed pointer; the wrapper retains ownership.
  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  // Serialization of arbitrary datasets is unsupported (see NOTE above the
  // class); both methods log an error instead.
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};
118 | |
119 | const char kWrappedDatasetVariantTypeName[] = |
120 | "tensorflow::data::WrappedDatasetVariant" ; |
121 | |
122 | class WrappedDatasetVariantWrapper { |
123 | public: |
124 | WrappedDatasetVariantWrapper() {} |
125 | |
126 | explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor) |
127 | : ds_tensor_(ds_tensor) {} |
128 | |
129 | Tensor get() const { return ds_tensor_; } |
130 | |
131 | string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper" ; } |
132 | |
133 | string DebugString() const { |
134 | return "tensorflow::WrappedDatasetVariantWrapper::DebugString" ; |
135 | } |
136 | |
137 | void Encode(VariantTensorData* data) const { |
138 | *(data->add_tensors()) = ds_tensor_; |
139 | } |
140 | |
141 | bool Decode(const VariantTensorData& data) { |
142 | ds_tensor_ = data.tensors(0); |
143 | return true; |
144 | } |
145 | |
146 | private: |
147 | Tensor ds_tensor_; |
148 | }; |
149 | |
// Op kernel that re-wraps a scalar DT_VARIANT dataset tensor in a
// `WrappedDatasetVariantWrapper` so the handle can be encoded/copied.
class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    // Validate that the input really contains a dataset; the extracted
    // pointer itself is not used.
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    // Wrap the entire input tensor (not just the dataset pointer).
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};
168 | |
// Kernel registrations for WrapDatasetVariant. The GPU registration pins
// both handles to host memory.
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);
176 | |
// Op kernel that extracts the original dataset tensor from a
// `WrappedDatasetVariantWrapper` produced by WrapDatasetVariant.
class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    // Forward the wrapped tensor directly as the output.
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};
198 | |
// Kernel registrations for UnwrapDatasetVariant. The GPU registration pins
// both handles to host memory.
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);
206 | |
// Device-copy function registered for WrappedDatasetVariantWrapper. The
// wrapper only holds a tensor handle, so a plain wrapper copy suffices; the
// async tensor-copy callback is intentionally unused.
static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return OkStatus();
}
213 | |
214 | #define REGISTER_OPTIONAL_COPY(DIRECTION) \ |
215 | INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ |
216 | WrappedDatasetVariantWrapper, DIRECTION, \ |
217 | WrappedDatasetVariantDeviceCopy) |
218 | |
219 | REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); |
220 | REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); |
221 | REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); |
222 | |
223 | REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper, |
224 | kWrappedDatasetVariantTypeName); |
225 | |
226 | } // namespace |
227 | |
// Convenience overload: adds `dataset` with positional `inputs` and no
// extra attrs. See the primary AddDataset overload for details.
Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset,
                                          const std::vector<Node*>& inputs,
                                          Node** output) {
  return AddDataset(dataset, inputs, {}, output);
}
233 | |
234 | Status GraphDefBuilderWrapper::AddDataset( |
235 | const DatasetBase* dataset, const std::vector<Node*>& inputs, |
236 | const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
237 | Node** output) { |
238 | std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size()); |
239 | for (size_t i = 0; i < inputs.size(); i++) { |
240 | enumerated_inputs[i] = std::make_pair(i, inputs[i]); |
241 | } |
242 | return AddDataset(dataset, enumerated_inputs, {}, attrs, output); |
243 | } |
244 | |
// Convenience overload: forwards to the primary overload with
// `use_dataset_name=false` (i.e. a builder-generated node name).
Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  return AddDataset(dataset, inputs, list_inputs, attrs,
                    /*use_dataset_name=*/false, output);
}
254 | |
255 | Status GraphDefBuilderWrapper::AddDataset( |
256 | const DatasetBase* dataset, |
257 | const std::vector<std::pair<size_t, Node*>>& inputs, |
258 | const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, |
259 | const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
260 | bool use_dataset_name, Node** output) { |
261 | auto& type_string = dataset->type_string(); |
262 | auto opts = absl::make_unique<GraphDefBuilder::Options>(b_->opts()); |
263 | // TODO(srbs|mrry): Not all datasets have output_types and output_shapes |
264 | // attributes defined. It will be nice to have a consistent pattern. |
265 | bool has_output_types_attr = HasAttr(type_string, "output_types" ); |
266 | bool has_output_shapes_attr = HasAttr(type_string, "output_shapes" ); |
267 | if (has_output_shapes_attr) { |
268 | opts = absl::make_unique<GraphDefBuilder::Options>( |
269 | opts->WithAttr("output_shapes" , dataset->output_shapes())); |
270 | } |
271 | if (has_output_types_attr) { |
272 | opts = absl::make_unique<GraphDefBuilder::Options>( |
273 | opts->WithAttr("output_types" , dataset->output_dtypes())); |
274 | } |
275 | bool has_metadata_attr = HasAttr(type_string, "metadata" ); |
276 | if (has_metadata_attr) { |
277 | std::string serialized_metadata; |
278 | dataset->metadata().SerializeToString(&serialized_metadata); |
279 | opts = absl::make_unique<GraphDefBuilder::Options>( |
280 | opts->WithAttr("metadata" , serialized_metadata)); |
281 | } |
282 | for (const auto& attr : attrs) { |
283 | opts = absl::make_unique<GraphDefBuilder::Options>( |
284 | opts->WithAttr(attr.first, attr.second)); |
285 | } |
286 | if (opts->HaveError()) { |
287 | return errors::Internal("AddDataset: Failed to build Options with error " , |
288 | opts->StatusToString()); |
289 | } |
290 | NodeBuilder node_builder( |
291 | use_dataset_name ? dataset->node_name() : opts->GetNameForOp(type_string), |
292 | type_string, opts->op_registry()); |
293 | { |
294 | size_t total_size = inputs.size() + list_inputs.size(); |
295 | auto inputs_iter = inputs.begin(); |
296 | auto list_inputs_iter = list_inputs.begin(); |
297 | for (int i = 0; i < total_size; i++) { |
298 | if (inputs_iter != inputs.end() && inputs_iter->first == i) { |
299 | node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second)); |
300 | inputs_iter++; |
301 | } else if (list_inputs_iter != list_inputs.end() && |
302 | list_inputs_iter->first == i) { |
303 | std::vector<NodeBuilder::NodeOut> nodeout_inputs; |
304 | nodeout_inputs.reserve(list_inputs_iter->second.size()); |
305 | for (Node* n : list_inputs_iter->second) { |
306 | nodeout_inputs.emplace_back(n); |
307 | } |
308 | node_builder.Input(nodeout_inputs); |
309 | list_inputs_iter++; |
310 | } else { |
311 | return errors::InvalidArgument("No input found for index " , i); |
312 | } |
313 | } |
314 | } |
315 | *output = opts->FinalizeBuilder(&node_builder); |
316 | if (*output == nullptr) { |
317 | return errors::Internal("AddDataset: Failed to build " , type_string, |
318 | " op with error " , opts->StatusToString()); |
319 | } |
320 | return OkStatus(); |
321 | } |
322 | |
323 | Status GraphDefBuilderWrapper::AddFunction( |
324 | SerializationContext* ctx, const string& function_name, |
325 | const FunctionLibraryDefinition& lib_def) { |
326 | if (b_->HasFunction(function_name)) { |
327 | VLOG(1) << "Function with name " << function_name << "already exists in" |
328 | << " the graph. It will not be added again." ; |
329 | return OkStatus(); |
330 | } |
331 | const FunctionDef* f_def = lib_def.Find(function_name); |
332 | if (f_def == nullptr) { |
333 | return errors::InvalidArgument("Unable to find FunctionDef for " , |
334 | function_name, " in the registry." ); |
335 | } |
336 | FunctionDefLibrary def; |
337 | *def.add_function() = *f_def; |
338 | const string gradient_func = lib_def.FindGradient(function_name); |
339 | if (!gradient_func.empty()) { |
340 | GradientDef* g_def = def.add_gradient(); |
341 | g_def->set_function_name(function_name); |
342 | g_def->set_gradient_func(gradient_func); |
343 | } |
344 | TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); |
345 | |
346 | // Recursively add functions in inputs of function_name. |
347 | for (const NodeDef& node_def : f_def->node_def()) { |
348 | const OpRegistrationData* op_reg_data = nullptr; |
349 | TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data)); |
350 | if (op_reg_data->is_function_op) { |
351 | TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def)); |
352 | } |
353 | // Recursively add functions in attrs of this NodeDef. |
354 | for (const auto& pair : node_def.attr()) { |
355 | TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def)); |
356 | } |
357 | } |
358 | |
359 | // Recursively add functions in attrs of function_name. |
360 | for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { |
361 | TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def)); |
362 | } |
363 | return OkStatus(); |
364 | } |
365 | |
// Adds a Placeholder node with the dtype and shape of `val`. The value is
// expected to be fed at runtime (see AddInputDataset's rewrite path).
void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}
372 | |
// Adds a Const node embedding `val` directly into the graph.
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}
379 | |
380 | bool GraphDefBuilderWrapper::HasAttr(const string& name, |
381 | const string& attr_name) const { |
382 | const OpDef* op_def = nullptr; |
383 | Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def); |
384 | if (!s.ok() || op_def == nullptr) { |
385 | return false; |
386 | } |
387 | return HasAttr(op_def, attr_name); |
388 | } |
389 | |
390 | int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx) { |
391 | thread::ThreadPool* thread_pool = |
392 | ctx->device()->tensorflow_device_thread_pool(); |
393 | if (thread_pool) { |
394 | return thread_pool->NumThreads(); |
395 | } else { |
396 | static const int32_t kDefaultRunnerThreadpoolSize = port::MaxParallelism(); |
397 | return kDefaultRunnerThreadpoolSize; |
398 | } |
399 | } |
400 | |
// Records this iterator's identity (and its parent's) and, when autotuning
// is enabled, registers a node for it in the model, scheduling removal via
// `cleanup_fns_`.
Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  // Ids combine the prefix hash with the object address so distinct live
  // iterators with the same prefix get distinct ids.
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    // NOTE(review): `parent` is dereferenced unconditionally here even though
    // the guard above treats it as nullable — confirm callers always pass a
    // non-null parent when a model is attached.
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return OkStatus();
}
419 | |
420 | int64_t GetAllocatedBytes(const std::vector<Tensor>& element) { |
421 | int64_t allocated_bytes = 0; |
422 | DatasetBase* dataset; |
423 | for (auto& tensor : element) { |
424 | if (tensor.dtype() == DT_VARIANT && |
425 | GetDatasetFromVariantTensor(tensor, &dataset).ok()) { |
426 | allocated_bytes += dataset->AllocatedBytes(); |
427 | } else { |
428 | allocated_bytes += tensor.AllocatedBytes(); |
429 | } |
430 | } |
431 | return allocated_bytes; |
432 | } |
433 | |
434 | int64_t GetTotalBytes(const std::vector<Tensor>& element) { |
435 | int64_t total_bytes = 0; |
436 | DatasetBase* dataset; |
437 | for (auto& tensor : element) { |
438 | if (tensor.dtype() == DT_VARIANT && |
439 | GetDatasetFromVariantTensor(tensor, &dataset).ok()) { |
440 | total_bytes += dataset->TotalBytes(); |
441 | } else { |
442 | total_bytes += tensor.TotalBytes(); |
443 | } |
444 | } |
445 | return total_bytes; |
446 | } |
447 | |
// Builds the canonical checkpoint key "<hex>|prefix:name". `name` must not
// contain the colon separator; violations are logged but not rejected.
std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}
455 | |
// Extracts the DatasetBase stored in a scalar DT_VARIANT tensor into
// `*out_dataset`. The tensor retains ownership; the returned pointer is
// borrowed and only valid while the tensor is alive. Returns
// InvalidArgument for non-dataset tensors and Internal for an empty wrapper.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return OkStatus();
}
474 | |
// Stores `dataset` in a scalar DT_VARIANT tensor. Ownership of the caller's
// reference on `dataset` is transferred to the tensor (the wrapper's
// destructor will Unref it).
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return OkStatus();
}
484 | |
485 | namespace internal { |
486 | |
487 | #define WARN_PROTO_FIELD_CONFLICT(reflection, field, field_type, src, dst) \ |
488 | { \ |
489 | auto source_value = reflection->Get##field_type(src, field); \ |
490 | auto destination_value = reflection->Get##field_type(*dst, field); \ |
491 | if (source_value != destination_value) { \ |
492 | LOG(WARNING) << "Changing the value of option field " << field->name() \ |
493 | << " from " << destination_value << " to " << source_value; \ |
494 | } \ |
495 | } |
496 | |
497 | #define WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst) \ |
498 | { \ |
499 | auto source_value = reflection->GetEnum(src, field); \ |
500 | auto destination_value = reflection->GetEnum(*dst, field); \ |
501 | if (source_value != destination_value) { \ |
502 | LOG(WARNING) << "Changing the value of option enum field " \ |
503 | << field->name() << " from " \ |
504 | << destination_value->full_name() << " to " \ |
505 | << source_value->full_name(); \ |
506 | } \ |
507 | } |
508 | |
// Logs a warning for every scalar field that is set in both `src` and `*dst`
// with differing values (message-typed fields are compared recursively).
// Used by MergeOptions to surface options that a merge will overwrite.
void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) {
  std::vector<const protobuf::FieldDescriptor*> set_src;
  std::vector<const protobuf::FieldDescriptor*> set_dst;
  const protobuf::Reflection* reflection = src.GetReflection();
  reflection->ListFields(src, &set_src);
  reflection->ListFields(*dst, &set_dst);
  // Sort by descriptor pointer so set_intersection can find fields present
  // in both messages.
  std::sort(set_src.begin(), set_src.end());
  std::sort(set_dst.begin(), set_dst.end());

  std::vector<const protobuf::FieldDescriptor*> in_both;
  std::set_intersection(set_src.begin(), set_src.end(), set_dst.begin(),
                        set_dst.end(), std::back_inserter(in_both));

  for (auto field : in_both) {
    if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) {
      // Recurse into sub-messages rather than warning on the whole message.
      WarnProtoConflicts(reflection->GetMessage(src, field),
                         reflection->MutableMessage(dst, field));
    } else {
      // Dispatch on the field's C++ type to read and compare the values.
      switch (field->cpp_type()) {
        case protobuf::FieldDescriptor::CPPTYPE_INT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_INT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Double, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Float, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_BOOL:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Bool, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_ENUM:
          WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst);
          break;
        default: {
          LOG(ERROR) << "Unrecognized proto type for field "
                     << field->full_name();
        }
      }
    }
  }
}
560 | |
561 | #undef WARN_PROTO_ENUM_FIELD_CONFLICT |
562 | #undef WARN_PROTO_FIELD_CONFLICT |
563 | |
// Merges `source` into `*destination`, warning about any fields whose values
// will change. Source values take precedence (MergeFrom semantics).
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination) {
  WarnProtoConflicts(source, destination);
  destination->MergeFrom(source);
}
569 | |
// Lite-runtime overload: MessageLite has no reflection, so conflicts cannot
// be detected; merge directly.
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination) {
  destination->CheckTypeAndMergeFrom(source);
}
574 | |
575 | } // namespace internal |
576 | |
577 | void DatasetBase::Initialize(const Metadata& metadata) { |
578 | Status s = ComputeNumSources(); |
579 | if (!s.ok()) { |
580 | LOG(ERROR) << s; |
581 | } |
582 | s = MergeOptionsFromInputs(); |
583 | if (!s.ok()) { |
584 | LOG(ERROR) << s; |
585 | } |
586 | metadata_ = metadata; |
587 | if (metadata_.name() == "" ) { |
588 | static std::atomic<int64_t> id_counter(0); |
589 | *metadata_.mutable_name() = |
590 | strings::StrCat(type_string(), ":" , id_counter.fetch_add(1)); |
591 | } |
592 | } |
593 | |
594 | Status DatasetBase::ComputeNumSources() { |
595 | std::vector<const DatasetBase*> inputs; |
596 | Status s = InputDatasets(&inputs); |
597 | if (errors::IsUnimplemented(s)) { |
598 | return errors::Unimplemented( |
599 | "Cannot compute input sources for dataset of type " , type_string(), |
600 | ", because the dataset does not implement `InputDatasets`." ); |
601 | } |
602 | if (num_sources_ >= 0) { |
603 | // Already computed. |
604 | return OkStatus(); |
605 | } |
606 | num_sources_ = 0; |
607 | if (inputs.empty()) { |
608 | num_sources_ = 1; |
609 | return OkStatus(); |
610 | } |
611 | for (const auto& input : inputs) { |
612 | if (input->num_sources() < 0) { |
613 | return errors::FailedPrecondition( |
614 | "Cannot compute input sources for dataset of type " , type_string(), |
615 | ", because sources could not be computed for input dataset of type " , |
616 | input->type_string()); |
617 | } |
618 | num_sources_ += input->num_sources(); |
619 | } |
620 | return OkStatus(); |
621 | } |
622 | |
// Verifies that this dataset supports random access at `index`: cardinality
// must be finite and known, and `index` must lie in [0, cardinality).
Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const {
  CardinalityOptions options;
  // MODERATE allows a bounded amount of computation to determine cardinality.
  options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_MODERATE);
  int64 cardinality = Cardinality(options);
  if (cardinality == kInfiniteCardinality ||
      cardinality == kUnknownCardinality) {
    return tensorflow::errors::FailedPrecondition(
        "Dataset of type ", this->DebugString(), " has ",
        cardinality == kInfiniteCardinality ? "infinite" : "unknown",
        " cardinality, which does not support random access.");
  }
  if (index < 0 || index >= cardinality) {
    return errors::OutOfRange("Index out of range [0, ", cardinality,
                              "):", index);
  }
  return OkStatus();
}
640 | |
// Default implementation of random-access lookup; subclasses that support
// random access override this.
Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
                        std::vector<Tensor>* out_tensors) const {
  return errors::Unimplemented(
      "Random access is not implemented for this dataset.");
}
646 | |
// Returns the finalized version of this dataset, creating and caching it on
// first call via `make_finalized_dataset`. Subsequent calls return the
// cached instance (the factory is not invoked again). Guarded by `mu_`.
StatusOr<DatasetBase*> DatasetBase::Finalize(
    OpKernelContext* ctx,
    std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
        make_finalized_dataset) const {
  mutex_lock l(mu_);
  if (!finalized_dataset_) {
    TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
  }
  return finalized_dataset_.get();
}
657 | |
658 | Status DatasetBase::MergeOptionsFromInputs() { |
659 | std::vector<const DatasetBase*> inputs; |
660 | Status s = InputDatasets(&inputs); |
661 | if (errors::IsUnimplemented(s)) { |
662 | return errors::Unimplemented( |
663 | "Cannot merge options for dataset of type " , type_string(), |
664 | ", because the dataset does not implement `InputDatasets`." ); |
665 | } |
666 | if (inputs.empty()) { |
667 | return OkStatus(); |
668 | } |
669 | // Merge options from inputs sequentially before merging options from dataset. |
670 | // Since the last options merged takes precedence, the options that may be set |
671 | // for the current dataset through OptionsDataset takes precedence over those |
672 | // set on the input datasets. |
673 | Options merged_options = inputs[0]->options_; |
674 | for (int i = 1; i < inputs.size(); ++i) { |
675 | internal::MergeOptions(inputs[i]->options_, &merged_options); |
676 | } |
677 | internal::MergeOptions(options_, &merged_options); |
678 | options_ = merged_options; |
679 | return OkStatus(); |
680 | } |
681 | |
682 | Status DatasetBase::MakeIterator( |
683 | IteratorContext* ctx, const IteratorBase* parent, |
684 | const string& output_prefix, |
685 | std::unique_ptr<IteratorBase>* iterator) const { |
686 | if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset" ) { |
687 | std::vector<const DatasetBase*> inputs; |
688 | Status s = InputDatasets(&inputs); |
689 | return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator); |
690 | } |
691 | profiler::TraceMe traceme( |
692 | [&] { |
693 | return profiler::TraceMeEncode( |
694 | strings::StrCat("MakeIterator::" , type_string()), {}); |
695 | }, |
696 | profiler::TraceMeLevel::kInfo); |
697 | *iterator = MakeIteratorInternal(output_prefix); |
698 | Status s = (*iterator)->InitializeBase(ctx, parent); |
699 | if (s.ok()) { |
700 | s.Update((*iterator)->Initialize(ctx)); |
701 | } |
702 | if (!s.ok()) { |
703 | // Reset the iterator to avoid returning an uninitialized iterator. |
704 | iterator->reset(); |
705 | } |
706 | return s; |
707 | } |
708 | |
// Default split-provider implementation: delegates to the single input
// dataset. Datasets with zero or multiple inputs must override
// `MakeSplitProvider` themselves.
Status DatasetBase::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* split_providers) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset is not unary (instead having arity ",
        inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProviders(split_providers);
}
728 | |
// Returns this dataset's cardinality, computing and caching it on first
// call. Guarded by `cardinality_mu_`.
int64_t DatasetBase::Cardinality() const {
  mutex_lock l(cardinality_mu_);
  if (cardinality_ == kUnknownCardinality) {
    cardinality_ = CardinalityInternal();
  }
  return cardinality_;
}
736 | |
// Overload taking explicit CardinalityOptions. NOTE: the cache is shared
// with the no-arg overload, so `options` only affects the first computation
// that populates it.
int64_t DatasetBase::Cardinality(CardinalityOptions options) const {
  mutex_lock l(cardinality_mu_);
  if (cardinality_ == kUnknownCardinality) {
    cardinality_ = CardinalityInternal(options);
  }
  return cardinality_;
}
744 | |
// Default implementation; subclasses override this to expose their input
// datasets. Callers use IsUnimplemented() to detect the default.
Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}
750 | |
// Serializes `dataset` into the graph being built. During graph rewrites,
// datasets that cannot be serialized (AsGraphDefInternal unimplemented) are
// represented by a Placeholder fed with a variant tensor owning the dataset.
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (ctx->is_graph_rewrite()) {
    if (status.ok()) {
      // Record cardinality in an unregistered attributes so that rewrites have
      // this information.
      (*output)->AddAttr(kCardinalityAttrForRewrite, dataset->Cardinality());
    } else if (errors::IsUnimplemented(status)) {
      Tensor t(DT_VARIANT, TensorShape({}));
      // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
      // increment the refcount of `dataset` here to retain ownership.
      dataset->Ref();
      TF_RETURN_IF_ERROR(
          StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
      TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
      DCHECK_NE(ctx->input_list(), nullptr);
      ctx->input_list()->emplace_back((*output)->name(), std::move(t));
      LOG_EVERY_N_SEC(WARNING, 30)
          << "Input of " << dataset->DebugString()
          << " will not be optimized because the dataset does not implement "
             "the "
             "AsGraphDefInternal() method needed to apply optimizations.";
      return OkStatus();
    }
  }
  return status;
}
779 | |
// Serializes a tensor that may hold datasets or resources, falling back to
// plain tensor serialization when specialized handling does not apply.
Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fallback to using its Variant::Encode() based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  if (t.dtype() == DT_RESOURCE && !ctx->is_graph_rewrite()) {
    Status s = AddResourceHelper(ctx, t, output);
    if (!errors::IsUnimplemented(s)) {
      // Fall through to AddTensor if AsGraphDef is not implemented for this
      // resource.
      return s;
    }
  }
  return AddTensor(t, output);
}
805 | |
806 | Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( |
807 | SerializationContext* ctx, const std::string& name_prefix, Node** input, |
808 | Node** output) { |
809 | *output = |
810 | ops::UnaryOp("Identity" , *input, |
811 | builder()->opts().WithName(UniqueNodeName(name_prefix))); |
812 | return OkStatus(); |
813 | } |
814 | |
815 | Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( |
816 | SerializationContext* ctx, const Tensor& t, Node** output) { |
817 | if (t.dims() == 0) { |
818 | DatasetBase* dataset; |
819 | TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset)); |
820 | return AddInputDataset(ctx, dataset, output); |
821 | } |
822 | std::vector<NodeBuilder::NodeOut> nodes; |
823 | for (int i = 0; i < t.dim_size(0); ++i) { |
824 | Node* node; |
825 | TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node)); |
826 | nodes.emplace_back(node); |
827 | } |
828 | auto op_name = "Pack" ; |
829 | auto opts = builder()->opts(); |
830 | NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, |
831 | opts.op_registry()); |
832 | node_builder.Input(std::move(nodes)); |
833 | *output = opts.FinalizeBuilder(&node_builder); |
834 | return OkStatus(); |
835 | } |
836 | |
837 | Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( |
838 | SerializationContext* ctx, const Tensor& t, Node** output) { |
839 | const ResourceHandle& handle = t.flat<ResourceHandle>()(0); |
840 | if (ctx->device_name() != handle.device()) { |
841 | return errors::InvalidArgument("Trying to access resource " , handle.name(), |
842 | " located in device " , handle.device(), |
843 | " from device " , ctx->device_name()); |
844 | } |
845 | ResourceBase* resource; |
846 | TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource)); |
847 | core::ScopedUnref unref(resource); |
848 | return resource->AsGraphDef(builder(), output); |
849 | } |
850 | |
851 | DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params) |
852 | : params_(params) { |
853 | params_.dataset->Ref(); |
854 | VLOG(2) << prefix() << " constructor" ; |
855 | strings::StrAppend(&traceme_metadata_, "name=" , dataset()->metadata().name()); |
856 | strings::StrAppend(&traceme_metadata_, ",shapes=" ); |
857 | auto& shapes = output_shapes(); |
858 | for (int i = 0; i < shapes.size(); ++i) { |
859 | if (i > 0) { |
860 | strings::StrAppend(&traceme_metadata_, " " ); |
861 | } |
862 | strings::StrAppend(&traceme_metadata_, shapes.at(i).DebugString()); |
863 | } |
864 | strings::StrAppend(&traceme_metadata_, ",types=" ); |
865 | auto& types = output_dtypes(); |
866 | for (int i = 0; i < types.size(); ++i) { |
867 | if (i > 0) { |
868 | strings::StrAppend(&traceme_metadata_, " " ); |
869 | } |
870 | strings::StrAppend(&traceme_metadata_, DataTypeString(types.at(i))); |
871 | } |
872 | } |
873 | |
DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor" ;
  // Release the reference taken in the constructor; this may destroy the
  // dataset if this iterator was its last owner.
  params_.dataset->Unref();
}
878 | |
879 | string DatasetBaseIterator::BuildTraceMeName() { |
880 | string result = |
881 | strings::StrCat(params_.prefix, "#" , traceme_metadata_, ",id=" , id_); |
882 | if (parent_) { |
883 | strings::StrAppend(&result, ",parent_id=" , parent_id_); |
884 | } |
885 | TraceMeMetadata metadata = GetTraceMeMetadata(); |
886 | for (const auto& pair : metadata) { |
887 | strings::StrAppend(&result, "," , pair.first, "=" , pair.second); |
888 | } |
889 | strings::StrAppend(&result, "#" ); |
890 | return result; |
891 | } |
892 | |
// Fetches the next element, bracketing the work with profiler/model
// accounting. On success with an element, records it for autotuning; at end
// of sequence, clears `out_tensors`. `OutOfRange` from GetNextInternal is an
// implementation error and is rewrapped as `Internal`.
Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter" ;
  // NOTE(review): `model` appears unused below; presumably it is held to keep
  // the autotuning model alive across GetNextInternal — confirm before
  // removing.
  auto model = ctx->model();
  if (collect_resource_usage(ctx)) {
    // Transfer the "active" accounting from the downstream (output) node to
    // this node for the duration of GetNextInternal.
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  out_tensors->clear();
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok())) {
    if (TF_PREDICT_TRUE(!*end_of_sequence)) {
      DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
      // Feed element count / timing into the autotuner.
      RecordElement(ctx, out_tensors);
    } else {
      // Do not leak partial outputs at end of sequence.
      out_tensors->clear();
    }
  }
  if (collect_resource_usage(ctx)) {
    // Mirror image of the bracket above: hand the "active" accounting back to
    // the output node.
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"" , params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: " ,
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit" ;
  return s;
}
937 | |
// Skips up to `num_to_skip` elements, with the same profiler/model bracketing
// and `OutOfRange` rewrapping as GetNext. `*num_skipped` reports how many
// elements were actually skipped.
Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter" ;
  // NOTE(review): `model` appears unused below; presumably it is held to keep
  // the autotuning model alive across SkipInternal — confirm before removing.
  auto model = ctx->model();
  if (collect_resource_usage(ctx)) {
    // Transfer the "active" accounting from the downstream (output) node to
    // this node for the duration of SkipInternal.
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (collect_resource_usage(ctx)) {
    // Hand the "active" accounting back to the output node.
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"" , params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: " ,
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit" ;
  return s;
}
972 | |
973 | Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, |
974 | bool* end_of_sequence, |
975 | int* num_skipped) { |
976 | *num_skipped = 0; |
977 | for (int i = 0; i < num_to_skip; ++i) { |
978 | std::vector<Tensor> out_tensors; |
979 | TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence)); |
980 | if (*end_of_sequence) { |
981 | return OkStatus(); |
982 | } |
983 | // RecordElement is used to count the number of element computed and |
984 | // help calculate the CPU time spent on a given iterator to do the |
985 | // autotuning. |
986 | // Here we only call RecordElement in the default implementation of |
987 | // SkipInternal (which trivially calls GetNextInternal) and assume |
988 | // that the overridden SkipInternal in the derived class will have |
989 | // negligible cost compare to its GetNextInternal. |
990 | RecordElement(ctx, &out_tensors); |
991 | (*num_skipped)++; |
992 | } |
993 | return OkStatus(); |
994 | } |
995 | |
996 | void DatasetOpKernel::Compute(OpKernelContext* ctx) { |
997 | DatasetBase* dataset = nullptr; |
998 | MakeDataset(ctx, &dataset); |
999 | if (ctx->status().ok()) { |
1000 | Tensor* output = nullptr; |
1001 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); |
1002 | OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output)); |
1003 | dataset->Initialize(metadata_); |
1004 | } |
1005 | } |
1006 | |
1007 | string DatasetOpKernel::TraceString(const OpKernelContext& ctx, |
1008 | bool verbose) const { |
1009 | return profiler::TraceMeOp(name_view(), type_string_view()); |
1010 | } |
1011 | |
1012 | // static |
1013 | bool DatasetOpKernel::IsDatasetOp(const OpDef& op_def) { |
1014 | if (op_def.output_arg_size() != 1) return false; |
1015 | if (op_def.output_arg(0).type() != DT_VARIANT) return false; |
1016 | absl::string_view op_name = op_def.name(); |
1017 | if (op_name == "DatasetFromGraph" ) return true; |
1018 | if (absl::EndsWith(op_name, "Dataset" )) return true; |
1019 | // Check if the suffix matches "DatasetV[0-9]+". |
1020 | size_t index = op_name.length() - 1; |
1021 | while (index >= 0 && isdigit(op_name[index])) { |
1022 | index--; |
1023 | } |
1024 | constexpr absl::string_view kDatasetPrefix = "DatasetV" ; |
1025 | constexpr absl::string_view::size_type kPrefixLength = kDatasetPrefix.size(); |
1026 | if (index < kPrefixLength - 1 || index == op_name.length() - 1) return false; |
1027 | return op_name.substr(index - kPrefixLength + 1, kPrefixLength) == |
1028 | kDatasetPrefix; |
1029 | } |
1030 | |
1031 | void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, |
1032 | DatasetBase** output) { |
1033 | DatasetBase* input; |
1034 | OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input)); |
1035 | MakeDataset(ctx, input, output); |
1036 | } |
1037 | |
1038 | void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, |
1039 | DatasetBase** output) { |
1040 | DatasetBase* input; |
1041 | OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input)); |
1042 | DatasetBase* another_input; |
1043 | OP_REQUIRES_OK(ctx, |
1044 | GetDatasetFromVariantTensor(ctx->input(1), &another_input)); |
1045 | MakeDataset(ctx, input, another_input, output); |
1046 | } |
1047 | |
// Out-of-line definitions for the serialization key constants declared on
// DatasetBase.
const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH" ;
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE" ;
1051 | |
// The worker thread itself is created lazily, on the first call to
// `Schedule()`.
BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}
1054 | |
BackgroundWorker::~BackgroundWorker() {
  {
    // Set the cancellation flag under the lock so `WorkerLoop()` observes it
    // when it re-checks its wait predicate.
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}
1068 | |
// Enqueues `work_item` for execution on the background thread, starting the
// thread on first use.
void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    // Lazily start the worker thread the first time work is scheduled.
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  // Notify outside the lock so the woken worker does not immediately block
  // on `mu_`.
  cond_var_.notify_one();
}
1080 | |
// Body of the background thread: dequeues and runs work items until
// cancellation is requested by the destructor.
void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background" );
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      // Standard condition-variable predicate loop: re-check after every
      // wakeup to guard against spurious wakeups.
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    // Run the item outside the lock so new work can be scheduled while it
    // executes.
    work_item();
  }
}
1101 | |
1102 | namespace { |
// Default Runner implementation: runs `f` synchronously on the calling
// thread, tagged for tf.data resource accounting.
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner" );
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  // Deliberately empty; `virtual` prevents the compiler from optimizing the
  // call away and tail-calling `f` above.
  virtual void PreventTailCall() {}
};
1118 | } // namespace |
1119 | |
1120 | /* static */ |
1121 | Runner* Runner::get() { |
1122 | static Runner* singleton = new RunnerImpl; |
1123 | return singleton; |
1124 | } |
1125 | |
1126 | } // namespace data |
1127 | } // namespace tensorflow |
1128 | |