1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/data/captured_function.h" |
16 | |
17 | #include <utility> |
18 | |
19 | #include "absl/time/clock.h" |
20 | #include "tensorflow/core/common_runtime/function.h" |
21 | #include "tensorflow/core/common_runtime/step_stats_collector.h" |
22 | #include "tensorflow/core/data/dataset_utils.h" |
23 | #include "tensorflow/core/data/stats_utils.h" |
24 | #include "tensorflow/core/framework/attr_value.pb.h" |
25 | #include "tensorflow/core/framework/cancellation.h" |
26 | #include "tensorflow/core/framework/function.h" |
27 | #include "tensorflow/core/framework/function_handle_cache.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/stats_aggregator.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/lib/gtl/optional.h" |
32 | #include "tensorflow/core/lib/random/random.h" |
33 | #include "tensorflow/core/lib/strings/strcat.h" |
34 | #include "tensorflow/core/platform/errors.h" |
35 | #include "tensorflow/core/platform/notification.h" |
36 | #include "tensorflow/core/profiler/lib/traceme.h" |
37 | |
38 | #if !defined(IS_MOBILE_PLATFORM) |
39 | #include "tensorflow/core/grappler/grappler_item.h" |
40 | #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" |
41 | #endif // !IS_MOBILE_PLATFORM |
42 | |
43 | namespace tensorflow { |
44 | namespace data { |
45 | namespace { |
46 | |
// Experiment name checked by `CapturedFunction::Instantiate()`: when it is
// listed in `GetExperiments()`, small-function optimizations are allowed for
// the instantiated function instead of forcing a single-threaded executor.
constexpr char kAllowSmallFunctionOptimizations[]{
    "allow_small_function_optimizations"};
49 | |
50 | // Simplistic implementation of the `StepStatsCollectorInterface` that only |
51 | // cares about collecting the CPU time needed to execute a captured function. |
52 | class SimpleStepStatsCollector : public StepStatsCollectorInterface { |
53 | public: |
54 | void IncrementProcessingTime(int64_t delta) { |
55 | mutex_lock l(mu_); |
56 | processing_time_ += delta; |
57 | } |
58 | |
59 | NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override { |
60 | return new SimpleNodeExecStats(this); |
61 | } |
62 | |
63 | string ReportAllocsOnResourceExhausted(const string& err) override { |
64 | return "" ; |
65 | } |
66 | |
67 | int64_t processing_time() { |
68 | tf_shared_lock l(mu_); |
69 | return processing_time_; |
70 | } |
71 | |
72 | private: |
73 | class SimpleNodeExecStats : public NodeExecStatsInterface { |
74 | public: |
75 | explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector) |
76 | : step_stats_collector_(step_stats_collector) {} |
77 | |
78 | void Done(const string& device) override { |
79 | step_stats_collector_->IncrementProcessingTime(end_time_ns_ - |
80 | start_time_ns_); |
81 | delete this; |
82 | } |
83 | |
84 | void RecordExecutorStarted() override { |
85 | start_time_ns_ = absl::GetCurrentTimeNanos(); |
86 | } |
87 | |
88 | void RecordComputeStarted() override {} |
89 | |
90 | void RecordComputeEnded() override {} |
91 | |
92 | void RecordExecutorEnded() override { |
93 | end_time_ns_ = absl::GetCurrentTimeNanos(); |
94 | } |
95 | |
96 | bool TrackAllocations() const override { return false; } |
97 | |
98 | void SetMemory(OpKernelContext* ctx) override {} |
99 | |
100 | void SetOutput(int slot, const Tensor* tensor) override {} |
101 | |
102 | void SetScheduled(int64_t nanos) override {} |
103 | |
104 | private: |
105 | int64_t start_time_ns_ = 0; |
106 | int64_t end_time_ns_ = 0; |
107 | SimpleStepStatsCollector* step_stats_collector_; // Not owned. |
108 | }; |
109 | |
110 | mutex mu_; |
111 | int64_t processing_time_ TF_GUARDED_BY(mu_) = 0; |
112 | }; |
113 | |
114 | Status GetCapturedInput(const CapturedFunction* const func, int index, |
115 | const Tensor** out) { |
116 | if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) { |
117 | return errors::OutOfRange( |
118 | "Out of range access to captured inputs for function " , |
119 | func->func().name(), ". Index: " , index, |
120 | ". Num captured inputs: " , func->captured_inputs().size()); |
121 | } |
122 | *out = &func->captured_inputs()[index]; |
123 | return OkStatus(); |
124 | } |
125 | |
126 | Status RunShortCircuit(const ShortCircuitInfo& info, |
127 | const std::vector<Tensor>& args, |
128 | const CapturedFunction* const func, |
129 | std::vector<Tensor>* rets) { |
130 | VLOG(3) << "Running function " << func->func().name() << " short circuit" ; |
131 | const int num_args = args.size(); |
132 | rets->reserve(info.indices.size()); |
133 | for (size_t i = 0; i < info.indices.size(); ++i) { |
134 | if (info.indices[i] < num_args) { |
135 | rets->push_back(args[info.indices[i]]); |
136 | } else { |
137 | const Tensor* captured_input; |
138 | TF_RETURN_IF_ERROR( |
139 | GetCapturedInput(func, info.indices[i] - num_args, &captured_input)); |
140 | rets->push_back(*captured_input); |
141 | } |
142 | } |
143 | return OkStatus(); |
144 | } |
145 | |
146 | Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args, |
147 | const CapturedFunction* const func, |
148 | std::vector<Tensor>* rets) { |
149 | VLOG(3) << "Running function " << func->func().name() << " short circuit" ; |
150 | const int num_args = args.size(); |
151 | rets->reserve(info.indices.size()); |
152 | for (size_t i = 0; i < info.indices.size(); ++i) { |
153 | if (info.indices[i] < num_args) { |
154 | if (info.can_move[i]) { |
155 | rets->push_back(std::move(args[info.indices[i]])); |
156 | } else { |
157 | rets->push_back(args[info.indices[i]]); |
158 | } |
159 | } else { |
160 | const Tensor* captured_input; |
161 | TF_RETURN_IF_ERROR( |
162 | GetCapturedInput(func, info.indices[i] - num_args, &captured_input)); |
163 | rets->push_back(*captured_input); |
164 | } |
165 | } |
166 | return OkStatus(); |
167 | } |
168 | |
169 | Status CreateShortCircuitInfo(OpKernelConstruction* ctx, |
170 | const NameAttrList& func, |
171 | ShortCircuitInfo* info) { |
172 | auto& indices = info->indices; |
173 | |
174 | FunctionLibraryRuntime::Handle fn_handle; |
175 | TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( |
176 | func.name(), AttrSlice(&func.attr()), &fn_handle)); |
177 | auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { |
178 | Status s = ctx->function_library()->ReleaseHandle(fn_handle); |
179 | if (!s.ok()) { |
180 | LOG(WARNING) << "Failed to release handle: " << s.error_message(); |
181 | } |
182 | }); |
183 | |
184 | // If the function contains any stateful operations, we conservatively execute |
185 | // the entire function. |
186 | if (ctx->function_library()->IsStateful(func.name())) { |
187 | return OkStatus(); |
188 | } |
189 | |
190 | const FunctionBody* fn_body = |
191 | ctx->function_library()->GetFunctionBody(fn_handle); |
192 | indices.resize(fn_body->ret_nodes.size()); |
193 | |
194 | for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { |
195 | Node* ret_node = fn_body->ret_nodes[i]; |
196 | Node* ret_input_node; |
197 | TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); |
198 | |
199 | while (ret_input_node->def().op() == "Identity" ) { |
200 | TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node)); |
201 | } |
202 | |
203 | if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { |
204 | TF_RETURN_IF_ERROR( |
205 | GetNodeAttr(ret_input_node->def(), "index" , &(indices[i]))); |
206 | } else { |
207 | indices.clear(); |
208 | break; |
209 | } |
210 | } |
211 | |
212 | // Compute the `can_move` vector. |
213 | if (!indices.empty()) { |
214 | auto& can_move = info->can_move; |
215 | std::map<int, int> last_use; |
216 | for (size_t i = 0; i < indices.size(); ++i) { |
217 | last_use[indices[i]] = i; |
218 | } |
219 | can_move.resize(indices.size()); |
220 | for (int i = 0, end = indices.size(); i < end; ++i) { |
221 | can_move[i] = last_use[indices[i]] == i; |
222 | } |
223 | } |
224 | |
225 | return OkStatus(); |
226 | } |
227 | |
228 | Status CreateFunctionLibraryDefinition( |
229 | const FunctionLibraryDefinition* lib_def, const string& func_name, |
230 | std::unique_ptr<FunctionLibraryDefinition>* result) { |
231 | DCHECK(lib_def != nullptr); |
232 | const FunctionDef* fdef = lib_def->Find(func_name); |
233 | if (TF_PREDICT_FALSE(fdef == nullptr)) { |
234 | return errors::FailedPrecondition(strings::StrCat( |
235 | "Could not find required function definition " , func_name)); |
236 | } |
237 | *result = std::make_unique<FunctionLibraryDefinition>( |
238 | lib_def->ReachableDefinitions(*fdef)); |
239 | return (*result)->CopyFunctionDefFrom(func_name, *lib_def); |
240 | } |
241 | |
242 | Status LookupFunction(const FunctionLibraryDefinition& lib_def, |
243 | const string& name, const FunctionDef** fdef) { |
244 | *fdef = lib_def.Find(name); |
245 | if (*fdef == nullptr) { |
246 | return errors::InvalidArgument( |
247 | "Failed to find function " , name, |
248 | " in function library: " , lib_def.ToProto().DebugString()); |
249 | } |
250 | return OkStatus(); |
251 | } |
252 | |
253 | class CallFrameBase : public CallFrameInterface { |
254 | public: |
255 | explicit CallFrameBase(DataTypeSlice ret_types) |
256 | : ret_types_(ret_types), retvals_(ret_types.size()) {} |
257 | |
258 | // Caller methods. |
259 | Status ConsumeRetvals(std::vector<Tensor>* retvals) { |
260 | retvals->reserve(retvals_.size()); |
261 | int i = 0; |
262 | for (auto&& val : retvals_) { |
263 | if (!val) { |
264 | return errors::Internal("No return value for index " , i, "." ); |
265 | } |
266 | retvals->emplace_back(std::move(val.value())); |
267 | ++i; |
268 | } |
269 | return OkStatus(); |
270 | } |
271 | |
272 | size_t num_retvals() const override { return retvals_.size(); } |
273 | |
274 | // Callee methods. |
275 | Status SetRetval(int index, const Tensor& val) override { |
276 | const int retvals_size = retvals_.size(); |
277 | if (index < retvals_size && val.dtype() == ret_types_[index] && |
278 | !retvals_[index]) { |
279 | retvals_[index] = val; |
280 | return OkStatus(); |
281 | } else if (index >= retvals_size) { |
282 | return errors::InvalidArgument("Return value " , index, |
283 | " is out of range." ); |
284 | } else if (val.dtype() != ret_types_[index]) { |
285 | return errors::InvalidArgument("Expected type " , |
286 | DataTypeString(ret_types_[index]), |
287 | " for return value " , index, " but got " , |
288 | DataTypeString(val.dtype()), "." ); |
289 | } else { |
290 | return errors::Internal("Attempted to set return value " , index, |
291 | " more than once." ); |
292 | } |
293 | } |
294 | |
295 | private: |
296 | DataTypeSlice ret_types_; |
297 | std::vector<gtl::optional<Tensor>> retvals_; |
298 | TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase); |
299 | }; |
300 | |
301 | class OwnedArgsCallFrame : public CallFrameBase { |
302 | public: |
303 | OwnedArgsCallFrame(std::vector<Tensor>&& args, |
304 | const std::vector<Tensor>* captured_inputs, |
305 | DataTypeSlice ret_types) |
306 | : CallFrameBase(ret_types), |
307 | args_(std::move(args)), |
308 | captured_inputs_(captured_inputs) {} |
309 | |
310 | size_t num_args() const override { |
311 | return args_.size() + captured_inputs_->size(); |
312 | } |
313 | |
314 | // Callee methods. |
315 | Status GetArg(int index, const Tensor** val) override { |
316 | const int args_size = args_.size(); |
317 | const int captured_inputs_size = captured_inputs_->size(); |
318 | if (index < args_size) { |
319 | *val = &args_[index]; |
320 | return OkStatus(); |
321 | } else if (index < args_size + captured_inputs_size) { |
322 | *val = &(*captured_inputs_)[index - args_.size()]; |
323 | return OkStatus(); |
324 | } else { |
325 | return errors::InvalidArgument("Argument " , index, " is out of range." ); |
326 | } |
327 | } |
328 | |
329 | // Since we own the argument tensors in `args_`, we can implement |
330 | // `ConsumeArg()` for those arguments. |
331 | void ConsumeArg(int index, Tensor* val) override { |
332 | DCHECK_GE(index, 0); |
333 | DCHECK_LT(index, args_.size()); |
334 | *val = std::move(args_[index]); |
335 | } |
336 | bool CanConsumeArg(int index) const override { |
337 | return index >= 0 && index < static_cast<int>(args_.size()); |
338 | } |
339 | |
340 | private: |
341 | std::vector<Tensor> args_; |
342 | const std::vector<Tensor>* const captured_inputs_; // Not owned. |
343 | }; |
344 | |
345 | class BorrowedArgsCallFrame : public CallFrameBase { |
346 | public: |
347 | BorrowedArgsCallFrame(const std::vector<Tensor>& args, |
348 | const std::vector<Tensor>* captured_inputs, |
349 | DataTypeSlice ret_types) |
350 | : CallFrameBase(ret_types), |
351 | args_(args), |
352 | captured_inputs_(captured_inputs) {} |
353 | |
354 | size_t num_args() const override { |
355 | return args_.size() + captured_inputs_->size(); |
356 | } |
357 | |
358 | // Callee methods. |
359 | Status GetArg(int index, const Tensor** val) override { |
360 | const int args_size = args_.size(); |
361 | const int captured_inputs_size = captured_inputs_->size(); |
362 | if (index < args_size) { |
363 | *val = &args_[index]; |
364 | return OkStatus(); |
365 | } else if (index < args_size + captured_inputs_size) { |
366 | *val = &(*captured_inputs_)[index - args_size]; |
367 | return OkStatus(); |
368 | } else { |
369 | return errors::InvalidArgument("Argument " , index, " is out of range." ); |
370 | } |
371 | } |
372 | |
373 | private: |
374 | const std::vector<Tensor>& args_; // Not owned. |
375 | const std::vector<Tensor>* const captured_inputs_; // Not owned. |
376 | }; |
377 | |
378 | } // namespace |
379 | |
380 | Status MakeIteratorFromInputElement( |
381 | IteratorContext* ctx, const IteratorBase* parent, |
382 | const std::vector<Tensor>& input_element, int64_t thread_index, |
383 | const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, |
384 | std::unique_ptr<IteratorBase>* out_iterator) { |
385 | return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index, |
386 | inst_captured_func, prefix, out_iterator, |
387 | /*node=*/nullptr); |
388 | } |
389 | |
390 | Status MakeIteratorFromInputElement( |
391 | IteratorContext* ctx, const IteratorBase* parent, |
392 | const std::vector<Tensor>& input_element, int64_t thread_index, |
393 | const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, |
394 | std::unique_ptr<IteratorBase>* out_iterator, |
395 | const std::shared_ptr<model::Node>& node) { |
396 | std::vector<Tensor> return_values; |
397 | |
398 | TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs( |
399 | ctx, input_element, &return_values, node)); |
400 | |
401 | if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && |
402 | TensorShapeUtils::IsScalar(return_values[0].shape()))) { |
403 | return errors::InvalidArgument( |
404 | "Function must return a single scalar of dtype DT_VARIANT." ); |
405 | } |
406 | |
407 | // Retrieve the dataset that was created in `f`. |
408 | DatasetBase* returned_dataset; |
409 | TF_RETURN_IF_ERROR( |
410 | GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); |
411 | |
412 | // Create an iterator for the dataset that was returned by `f`. |
413 | std::string iterator_prefix = strings::StrCat(prefix, "[" , thread_index, "]" ); |
414 | |
415 | return returned_dataset->MakeIterator(MakeNestedIteratorContext(ctx), parent, |
416 | iterator_prefix, out_iterator); |
417 | } |
418 | |
419 | IteratorContext MakeNestedIteratorContext(IteratorContext* ctx) { |
420 | // Strip out any split providers so that they don't apply to sub-iterators. |
421 | if (ctx->split_providers().empty()) { |
422 | return *ctx; |
423 | } |
424 | IteratorContext::Params params(ctx); |
425 | params.split_providers.clear(); |
426 | return IteratorContext(std::move(params)); |
427 | } |
428 | |
429 | /* static */ |
430 | Status FunctionMetadata::Create( |
431 | OpKernelConstruction* ctx, const string& func_name, Params params, |
432 | std::shared_ptr<FunctionMetadata>* out_metadata) { |
433 | NameAttrList func; |
434 | TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func)); |
435 | return Create(ctx, std::move(func), params, out_metadata); |
436 | } |
437 | |
438 | Status FunctionMetadata::Create( |
439 | OpKernelConstruction* ctx, NameAttrList&& func, Params params, |
440 | std::shared_ptr<FunctionMetadata>* out_metadata) { |
441 | out_metadata->reset(new FunctionMetadata(std::move(func), params)); |
442 | TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition( |
443 | ctx->function_library()->GetFunctionLibraryDefinition(), |
444 | (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_)); |
445 | TF_RETURN_IF_ERROR(CreateShortCircuitInfo( |
446 | ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_)); |
447 | const FunctionDef* fdef; |
448 | TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(), |
449 | (*out_metadata)->func().name(), &fdef)); |
450 | |
451 | auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr); |
452 | if (attr != fdef->attr().end() && attr->second.b()) { |
453 | VLOG(1) << "Disabling multi-device execution for a function that uses the " |
454 | << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute." ; |
455 | (*out_metadata)->use_multi_device_function_ = false; |
456 | return OkStatus(); |
457 | } |
458 | auto validate_arg = [](const OpDef::ArgDef& arg) { |
459 | if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) { |
460 | VLOG(1) << "Disabling multi-device execution for a function with " |
461 | << "a vector argument " << arg.name() << "." ; |
462 | return false; |
463 | } |
464 | return true; |
465 | }; |
466 | for (const auto& arg : fdef->signature().input_arg()) { |
467 | if (!validate_arg(arg)) { |
468 | (*out_metadata)->use_multi_device_function_ = false; |
469 | return OkStatus(); |
470 | } |
471 | } |
472 | for (const auto& arg : fdef->signature().output_arg()) { |
473 | if (!validate_arg(arg)) { |
474 | (*out_metadata)->use_multi_device_function_ = false; |
475 | return OkStatus(); |
476 | } |
477 | } |
478 | return OkStatus(); |
479 | } |
480 | |
481 | /* static */ |
482 | Status CapturedFunction::Create( |
483 | OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata, |
484 | const string& argument_name, |
485 | std::unique_ptr<CapturedFunction>* out_function) { |
486 | OpInputList inputs; |
487 | TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs)); |
488 | std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end()); |
489 | return Create(ctx, std::move(metadata), std::move(captured_inputs), |
490 | out_function); |
491 | } |
492 | |
493 | /* static */ |
494 | Status CapturedFunction::Create( |
495 | OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata, |
496 | std::vector<Tensor>&& captured_inputs, |
497 | std::unique_ptr<CapturedFunction>* out_function) { |
498 | *out_function = absl::WrapUnique( |
499 | new CapturedFunction(std::move(metadata), std::move(captured_inputs))); |
500 | return OkStatus(); |
501 | } |
502 | |
503 | Status CapturedFunction::AddToGraph( |
504 | SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b, |
505 | std::vector<Node*>* other_arguments, |
506 | DataTypeVector* other_arguments_types) const { |
507 | other_arguments->reserve(captured_inputs_.size()); |
508 | other_arguments_types->reserve(captured_inputs_.size()); |
509 | for (const Tensor& t : captured_inputs_) { |
510 | Node* node; |
511 | if (!ctx->is_graph_rewrite()) { |
512 | TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node)); |
513 | } else { |
514 | TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node)); |
515 | DCHECK_NE(ctx->input_list(), nullptr); |
516 | ctx->input_list()->emplace_back(node->name(), t); |
517 | } |
518 | other_arguments->emplace_back(node); |
519 | other_arguments_types->emplace_back(t.dtype()); |
520 | } |
521 | TF_RETURN_IF_ERROR( |
522 | b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def())); |
523 | return OkStatus(); |
524 | } |
525 | |
526 | Status CapturedFunction::Instantiate( |
527 | IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>* |
528 | instantiated_captured_function) { |
529 | return CapturedFunction::Instantiate(InstantiateCapturedFunctionParams(ctx), |
530 | instantiated_captured_function); |
531 | } |
532 | |
533 | // TODO(b/190831948): Check whether the function creates a resource and if so, |
534 | // produce a warning. |
535 | Status CapturedFunction::Instantiate( |
536 | InstantiateCapturedFunctionParams params, |
537 | std::unique_ptr<InstantiatedCapturedFunction>* |
538 | instantiated_captured_function) { |
539 | // The context's runtime will be used for all subsequent calls. |
540 | FunctionLibraryRuntime* lib = params.flr; |
541 | FunctionLibraryRuntime::InstantiateOptions inst_opts; |
542 | inst_opts.lib_def = metadata_->lib_def(); |
543 | inst_opts.create_kernels_eagerly = true; |
544 | inst_opts.default_device_to_target = metadata_->use_default_device(); |
545 | inst_opts.config_proto = |
546 | lib->config_proto() ? *lib->config_proto() : ConfigProto(); |
547 | if (GetExperiments().contains(kAllowSmallFunctionOptimizations)) { |
548 | inst_opts.allow_small_function_optimizations = true; |
549 | } else { |
550 | if (!metadata_->use_inter_op_parallelism()) { |
551 | inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR" ; |
552 | } |
553 | } |
554 | inst_opts.is_multi_device_function = metadata_->use_multi_device_function(); |
555 | if (!params.function_handle_cache) { |
556 | // If the caller does not provide a cache, we use the FLR cache. |
557 | inst_opts.use_function_cache = true; |
558 | } |
559 | |
560 | // We infer the target device from the function library runtime. |
561 | DCHECK(lib->device() != nullptr); |
562 | inst_opts.target = lib->device()->name(); |
563 | |
564 | // Maps from a CompositeDevice name to underlying physical device names. |
565 | absl::flat_hash_map<string, std::vector<string>> composite_devices; |
566 | |
567 | if (inst_opts.is_multi_device_function) { |
568 | // Compute devices of non-captured inputs. |
569 | // |
570 | // We infer the number of non-captured inputs by subtracting the number |
571 | // of captured inputs from the number of input arguments and we infer the |
572 | // input devices from the function library runtime. |
573 | const FunctionDef* fdef; |
574 | TF_RETURN_IF_ERROR( |
575 | LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); |
576 | size_t num_non_captured_inputs = |
577 | fdef->signature().input_arg_size() - captured_inputs_.size(); |
578 | for (size_t i = 0; i < num_non_captured_inputs; ++i) { |
579 | inst_opts.input_devices.push_back(inst_opts.target); |
580 | } |
581 | // Compute devices of captured inputs. |
582 | // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0. |
583 | Device* cpu_device; |
584 | TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0" , &cpu_device)); |
585 | std::unordered_map<int, DtypeAndPartialTensorShape>& |
586 | input_resource_variable_dtypes_and_shapes = |
587 | inst_opts.input_resource_dtypes_and_shapes; |
588 | for (size_t i = 0; i < captured_inputs_.size(); ++i) { |
589 | const auto& input = captured_inputs_[i]; |
590 | DataType dtype = input.dtype(); |
591 | if (dtype == DT_RESOURCE) { |
592 | const auto& handles = input.flat<ResourceHandle>(); |
593 | const ResourceHandle& handle0 = handles(0); |
594 | string composite_device; |
595 | auto iter = fdef->arg_attr().find(num_non_captured_inputs + i); |
596 | if (iter != fdef->arg_attr().end()) { |
597 | auto arg_attr = iter->second.attr().find("_composite_device" ); |
598 | if (arg_attr != iter->second.attr().end()) { |
599 | composite_device = arg_attr->second.s(); |
600 | } |
601 | } |
602 | if (!composite_device.empty()) { |
603 | if (composite_devices.find(composite_device) == |
604 | composite_devices.end()) { |
605 | for (int i = 0; i < handles.size(); ++i) { |
606 | composite_devices[composite_device].push_back( |
607 | handles(i).device()); |
608 | } |
609 | } |
610 | inst_opts.input_devices.push_back(composite_device); |
611 | } else { |
612 | inst_opts.input_devices.push_back(handle0.device()); |
613 | } |
614 | const auto& dtypes_and_shapes = handle0.dtypes_and_shapes(); |
615 | // Set dtypes and shapes for resource variable inputs. |
616 | if (!dtypes_and_shapes.empty()) { |
617 | input_resource_variable_dtypes_and_shapes[num_non_captured_inputs + |
618 | i] = |
619 | dtypes_and_shapes.at(0); |
620 | } |
621 | } else if (MTypeFromDType(dtype) == HOST_MEMORY) { |
622 | inst_opts.input_devices.push_back(cpu_device->name()); |
623 | } else { |
624 | // Fall back to using the function library runtime device. |
625 | inst_opts.input_devices.push_back(inst_opts.target); |
626 | } |
627 | } |
628 | |
629 | for (const auto& it : composite_devices) { |
630 | inst_opts.composite_devices[it.first] = &it.second; |
631 | } |
632 | |
633 | for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) { |
634 | inst_opts.output_devices.push_back(inst_opts.target); |
635 | } |
636 | |
637 | #if !defined(IS_MOBILE_PLATFORM) |
638 | grappler::GrapplerItem::OptimizationOptions optimization_options; |
639 | optimization_options.allow_pruning_stateful_and_dataset_ops = false; |
640 | ConfigProto config_proto = inst_opts.config_proto; |
641 | // Layout optimizations are excluded because they assume that ops without |
642 | // explicit device assignment will be placed on GPU (if available) but |
643 | // that's not the case for operations within tf.data functions. |
644 | config_proto.mutable_graph_options() |
645 | ->mutable_rewrite_options() |
646 | ->set_layout_optimizer(RewriterConfig::OFF); |
647 | // TODO(b/120437209): Re-enable constant folding. |
648 | config_proto.mutable_graph_options() |
649 | ->mutable_rewrite_options() |
650 | ->set_constant_folding(RewriterConfig::OFF); |
651 | inst_opts.optimize_graph_fn = |
652 | std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1, |
653 | std::placeholders::_2, std::placeholders::_3, |
654 | std::placeholders::_4, std::placeholders::_5, |
655 | std::move(config_proto), fdef->signature().name(), |
656 | std::move(optimization_options), std::placeholders::_6); |
657 | #endif // !IS_MOBILE_PLATFORM |
658 | } |
659 | |
660 | FunctionLibraryRuntime::Handle f_handle; |
661 | if (params.function_handle_cache) { |
662 | TF_RETURN_IF_ERROR(params.function_handle_cache->Instantiate( |
663 | metadata_->func().name(), AttrSlice(&metadata_->func().attr()), |
664 | inst_opts, &f_handle)); |
665 | } else { |
666 | TF_RETURN_IF_ERROR(lib->Instantiate(metadata_->func().name(), |
667 | AttrSlice(&metadata_->func().attr()), |
668 | inst_opts, &f_handle)); |
669 | } |
670 | |
671 | DataTypeVector ret_types; |
672 | TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types)); |
673 | |
674 | bool is_multi_device; |
675 | TF_RETURN_IF_ERROR(IsMultiDevice(lib, &is_multi_device)); |
676 | *instantiated_captured_function = absl::WrapUnique( |
677 | new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types), |
678 | *params.runner, this, is_multi_device)); |
679 | return OkStatus(); |
680 | } |
681 | |
682 | Status CapturedFunction::CheckExternalState() const { |
683 | for (const auto& name : lib_def()->ListFunctionNames()) { |
684 | TF_RETURN_IF_ERROR( |
685 | IsFunctionStateful(*lib_def(), *(lib_def()->Find(name)))); |
686 | } |
687 | return OkStatus(); |
688 | } |
689 | |
690 | CapturedFunction::CapturedFunction( |
691 | std::shared_ptr<const FunctionMetadata> metadata, |
692 | std::vector<Tensor> captured_inputs) |
693 | : metadata_(std::move(metadata)), |
694 | captured_inputs_(std::move(captured_inputs)) {} |
695 | |
696 | Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, |
697 | bool* is_multi_device) const { |
698 | if (!metadata_->use_multi_device_function()) { |
699 | *is_multi_device = false; |
700 | return OkStatus(); |
701 | } |
702 | |
703 | const FunctionDef* fdef; |
704 | TF_RETURN_IF_ERROR( |
705 | LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); |
706 | |
707 | Device* current_device = flr->device(); |
708 | DeviceType current_device_type(current_device->device_type()); |
709 | DeviceNameUtils::ParsedName current_device_name; |
710 | if (!DeviceNameUtils::ParseFullName(current_device->name(), |
711 | ¤t_device_name)) { |
712 | return errors::InvalidArgument("Failed to parse device name: " , |
713 | current_device->name()); |
714 | } |
715 | |
716 | // Check if any of the captured inputs are placed on a device not compatible |
717 | // with the current device. For non-captured inputs, we assume they are placed |
718 | // on the current device. |
719 | for (const auto& input : captured_inputs_) { |
720 | DataType dtype = input.dtype(); |
721 | if (dtype == DT_RESOURCE) { |
722 | const ResourceHandle& handle = input.flat<ResourceHandle>()(0); |
723 | DeviceNameUtils::ParsedName resource_device_name; |
724 | if (!DeviceNameUtils::ParseFullName(handle.device(), |
725 | &resource_device_name)) { |
726 | return errors::InvalidArgument("Failed to parse device name: " , |
727 | handle.device()); |
728 | } |
729 | if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, |
730 | resource_device_name)) { |
731 | *is_multi_device = true; |
732 | return OkStatus(); |
733 | } |
734 | } |
735 | } |
736 | |
737 | // Check if all ops could be placed on the current device. |
738 | for (const auto& name : metadata_->lib_def()->ListFunctionNames()) { |
739 | const FunctionDef* fdef; |
740 | TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef)); |
741 | for (const auto& node : fdef->node_def()) { |
742 | // Check if the op has a kernel available for the current device. |
743 | if (!KernelDefAvailable(current_device_type, node)) { |
744 | *is_multi_device = true; |
745 | return OkStatus(); |
746 | } |
747 | // If the op has a requested device, check if the requested device is |
748 | // compatible with the current device. |
749 | if (!node.device().empty()) { |
750 | DeviceNameUtils::ParsedName node_device_name; |
751 | if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) { |
752 | return errors::InvalidArgument("Failed to parse device name: " , |
753 | node.device()); |
754 | } |
755 | if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, |
756 | node_device_name)) { |
757 | *is_multi_device = true; |
758 | return OkStatus(); |
759 | } |
760 | } |
761 | } |
762 | } |
763 | |
764 | *is_multi_device = false; |
765 | return OkStatus(); |
766 | } |
767 | |
768 | InstantiatedCapturedFunction::InstantiatedCapturedFunction( |
769 | FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, |
770 | DataTypeVector ret_types, std::function<void(std::function<void()>)> runner, |
771 | CapturedFunction* captured_func, bool is_multi_device) |
772 | : lib_(lib), |
773 | f_handle_(f_handle), |
774 | ret_types_(std::move(ret_types)), |
775 | captured_runner_(std::move(runner)), |
776 | captured_func_(captured_func), |
777 | is_multi_device_(is_multi_device) {} |
778 | |
779 | Status InstantiatedCapturedFunction::Run(IteratorContext* ctx, |
780 | std::vector<Tensor>&& args, |
781 | std::vector<Tensor>* rets) const { |
782 | return Run(ctx, std::move(args), rets, /*node=*/nullptr); |
783 | } |
784 | |
785 | Status InstantiatedCapturedFunction::Run( |
786 | IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, |
787 | const std::shared_ptr<model::Node>& node) const { |
788 | auto& info = captured_func_->short_circuit_info(); |
789 | if (!info.indices.empty()) { |
790 | return RunShortCircuit(info, std::move(args), captured_func_, rets); |
791 | } |
792 | |
793 | FunctionLibraryRuntime::Options f_opts; |
794 | ScopedStepContainer step_container( |
795 | f_opts.step_id, [this](const string& name) { |
796 | lib_->device()->resource_manager()->Cleanup(name).IgnoreError(); |
797 | }); |
798 | f_opts.step_container = &step_container; |
799 | f_opts.runner = ctx->runner(); |
800 | f_opts.create_rendezvous = ShouldCreateRendezvous(); |
801 | CancellationManager cancellation_manager(ctx->cancellation_manager()); |
802 | f_opts.cancellation_manager = &cancellation_manager; |
803 | f_opts.collective_executor = ctx->collective_executor(); |
804 | |
805 | std::shared_ptr<SimpleStepStatsCollector> stats_collector; |
806 | if (node || ctx->stats_aggregator()) { |
807 | stats_collector = std::make_shared<SimpleStepStatsCollector>(); |
808 | } |
809 | const bool collect_usage = node && ctx->model(); |
810 | f_opts.stats_collector = stats_collector.get(); |
811 | |
812 | OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(), |
813 | ret_types_); |
814 | profiler::TraceMe activity( |
815 | [&] { |
816 | return profiler::TraceMeEncode("InstantiatedCapturedFunction::Run" , |
817 | {{"id" , f_opts.step_id}}); |
818 | }, |
819 | profiler::TraceMeLevel::kInfo); |
820 | if (node) { |
821 | // Resource usage for function execution is gathered from the executor. |
822 | // TODO(jsimsa): Factor out common code for Run, RunAsync, and |
823 | // RunWithBorrowedArguments |
824 | if (collect_usage) node->record_stop(EnvTime::NowNanos()); |
825 | TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); |
826 | if (ctx->stats_aggregator()) { |
827 | string prefix_with_func_name = strings::StrCat( |
828 | node->name(), stats_utils::kDelimiter, captured_func_->func().name()); |
829 | ctx->stats_aggregator()->AddToHistogram( |
830 | stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), |
831 | {static_cast<float>(stats_collector->processing_time())}, |
832 | node->num_elements()); |
833 | } |
834 | node->add_processing_time(stats_collector->processing_time()); |
835 | if (collect_usage) node->record_start(EnvTime::NowNanos()); |
836 | } else { |
837 | TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); |
838 | } |
839 | return frame.ConsumeRetvals(rets); |
840 | } |
841 | |
842 | Status InstantiatedCapturedFunction::RunWithBorrowedArgs( |
843 | IteratorContext* ctx, const std::vector<Tensor>& args, |
844 | std::vector<Tensor>* ret) const { |
845 | return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr); |
846 | } |
847 | |
848 | Status InstantiatedCapturedFunction::RunWithBorrowedArgs( |
849 | IteratorContext* ctx, const std::vector<Tensor>& args, |
850 | std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const { |
851 | auto& info = captured_func_->short_circuit_info(); |
852 | if (!info.indices.empty()) { |
853 | return RunShortCircuit(info, args, captured_func_, rets); |
854 | } |
855 | |
856 | FunctionLibraryRuntime::Options f_opts; |
857 | ScopedStepContainer step_container( |
858 | f_opts.step_id, [this](const string& name) { |
859 | lib_->device()->resource_manager()->Cleanup(name).IgnoreError(); |
860 | }); |
861 | f_opts.step_container = &step_container; |
862 | f_opts.runner = ctx->runner(); |
863 | f_opts.create_rendezvous = ShouldCreateRendezvous(); |
864 | CancellationManager cancellation_manager(ctx->cancellation_manager()); |
865 | f_opts.cancellation_manager = &cancellation_manager; |
866 | f_opts.collective_executor = ctx->collective_executor(); |
867 | |
868 | std::shared_ptr<SimpleStepStatsCollector> stats_collector; |
869 | if (node || ctx->stats_aggregator()) { |
870 | stats_collector = std::make_shared<SimpleStepStatsCollector>(); |
871 | } |
872 | const bool collect_usage = node && ctx->model(); |
873 | f_opts.stats_collector = stats_collector.get(); |
874 | |
875 | BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(), |
876 | ret_types_); |
877 | profiler::TraceMe activity( |
878 | [&] { |
879 | return profiler::TraceMeEncode( |
880 | "InstantiatedCapturedFunction::RunWithBorrowedArgs" , |
881 | {{"id" , f_opts.step_id}}); |
882 | }, |
883 | profiler::TraceMeLevel::kInfo); |
884 | if (node) { |
885 | // Resource usage for function execution is gathered from the executor. |
886 | if (collect_usage) node->record_stop(EnvTime::NowNanos()); |
887 | TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); |
888 | if (ctx->stats_aggregator()) { |
889 | string prefix_with_func_name = strings::StrCat( |
890 | node->name(), stats_utils::kDelimiter, captured_func_->func().name()); |
891 | ctx->stats_aggregator()->AddToHistogram( |
892 | stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), |
893 | {static_cast<float>(stats_collector->processing_time())}, |
894 | node->num_elements()); |
895 | } |
896 | node->add_processing_time(stats_collector->processing_time()); |
897 | if (collect_usage) node->record_start(EnvTime::NowNanos()); |
898 | } else { |
899 | TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); |
900 | } |
901 | return frame.ConsumeRetvals(rets); |
902 | } |
903 | |
904 | Status InstantiatedCapturedFunction::RunInstantiated( |
905 | const std::vector<Tensor>& args, std::vector<Tensor>* rets) { |
906 | auto& info = captured_func_->short_circuit_info(); |
907 | if (!info.indices.empty()) { |
908 | return RunShortCircuit(info, args, captured_func_, rets); |
909 | } |
910 | |
911 | FunctionLibraryRuntime::Options f_opts; |
912 | ScopedStepContainer step_container( |
913 | f_opts.step_id, [this](const string& name) { |
914 | lib_->device()->resource_manager()->Cleanup(name).IgnoreError(); |
915 | }); |
916 | f_opts.step_container = &step_container; |
917 | f_opts.runner = &captured_runner_; |
918 | f_opts.create_rendezvous = ShouldCreateRendezvous(); |
919 | CancellationManager cancellation_manager; |
920 | f_opts.cancellation_manager = &cancellation_manager; |
921 | |
922 | BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(), |
923 | ret_types_); |
924 | profiler::TraceMe activity( |
925 | [&] { |
926 | return profiler::TraceMeEncode( |
927 | "InstantiatedCapturedFunction::RunInstantiated" , |
928 | {{"id" , f_opts.step_id}}); |
929 | }, |
930 | profiler::TraceMeLevel::kInfo); |
931 | TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); |
932 | return frame.ConsumeRetvals(rets); |
933 | } |
934 | |
935 | void InstantiatedCapturedFunction::RunAsync( |
936 | IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, |
937 | FunctionLibraryRuntime::DoneCallback done, |
938 | const std::shared_ptr<model::Node>& node) const { |
939 | auto& info = captured_func_->short_circuit_info(); |
940 | if (!info.indices.empty()) { |
941 | // Run the `done` callback on a threadpool thread, because it will |
942 | // potentially do a non-trivial amount of (e.g. copying) work, and we may |
943 | // want to run that concurrently with the next invocation. |
944 | Status s = RunShortCircuit(info, std::move(args), captured_func_, rets); |
945 | (*ctx->runner())( |
946 | std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); }, |
947 | std::move(done))); |
948 | return; |
949 | } |
950 | |
951 | // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may |
952 | // be deleted before `done` is called. Take care not to capture `ctx` in any |
953 | // code that may execute asynchronously in this function. |
954 | OwnedArgsCallFrame* frame = new OwnedArgsCallFrame( |
955 | std::move(args), &captured_func_->captured_inputs(), ret_types_); |
956 | |
957 | FunctionLibraryRuntime::Options f_opts; |
958 | ResourceMgr* resource_mgr = lib_->device()->resource_manager(); |
959 | ScopedStepContainer* step_container = new ScopedStepContainer( |
960 | f_opts.step_id, [resource_mgr](const string& name) { |
961 | resource_mgr->Cleanup(name).IgnoreError(); |
962 | }); |
963 | f_opts.step_container = step_container; |
964 | f_opts.runner = ctx->runner(); |
965 | f_opts.create_rendezvous = ShouldCreateRendezvous(); |
966 | auto cancellation_manager = |
967 | std::make_unique<CancellationManager>(ctx->cancellation_manager()); |
968 | f_opts.cancellation_manager = cancellation_manager.get(); |
969 | f_opts.collective_executor = ctx->collective_executor(); |
970 | |
971 | std::shared_ptr<SimpleStepStatsCollector> stats_collector; |
972 | if (node || ctx->stats_aggregator()) { |
973 | stats_collector = std::make_shared<SimpleStepStatsCollector>(); |
974 | } |
975 | const bool collect_usage = node && ctx->model(); |
976 | f_opts.stats_collector = stats_collector.get(); |
977 | |
978 | // Transfer ownership of the cancellation manager to `callback`. |
979 | CancellationManager* raw_cancellation_manager = |
980 | cancellation_manager.release(); |
981 | auto callback = std::bind( |
982 | [this, rets, step_container, raw_cancellation_manager, frame, node, |
983 | collect_usage]( |
984 | const FunctionLibraryRuntime::DoneCallback& done, |
985 | IteratorContext* ctx, |
986 | const std::shared_ptr<SimpleStepStatsCollector>& stats_collector, |
987 | // Begin unbound arguments. |
988 | Status s) { |
989 | delete step_container; |
990 | delete raw_cancellation_manager; |
991 | if (s.ok()) { |
992 | s = frame->ConsumeRetvals(rets); |
993 | } |
994 | delete frame; |
995 | if (node) { |
996 | // TODO(b/129085499) Utilize the `node_name` which would be unique |
997 | // than the prefix for the function execution time statistics. |
998 | // prefix_with_func_name would then be node_name + func_name. |
999 | if (ctx->stats_aggregator()) { |
1000 | string prefix_with_func_name = |
1001 | strings::StrCat(node->name(), stats_utils::kDelimiter, |
1002 | captured_func_->func().name()); |
1003 | ctx->stats_aggregator()->AddToHistogram( |
1004 | stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), |
1005 | {static_cast<float>(stats_collector->processing_time())}, |
1006 | node->num_elements()); |
1007 | } |
1008 | node->add_processing_time(stats_collector->processing_time()); |
1009 | } |
1010 | if (collect_usage) { |
1011 | node->record_start(EnvTime::NowNanos()); |
1012 | } |
1013 | done(s); |
1014 | if (collect_usage) { |
1015 | node->record_stop(EnvTime::NowNanos()); |
1016 | } |
1017 | }, |
1018 | std::move(done), ctx, std::move(stats_collector), std::placeholders::_1); |
1019 | |
1020 | profiler::TraceMe activity( |
1021 | [&] { |
1022 | return profiler::TraceMeEncode("InstantiatedCapturedFunction::RunAsync" , |
1023 | {{"id" , f_opts.step_id}}); |
1024 | }, |
1025 | profiler::TraceMeLevel::kInfo); |
1026 | // Stop the usage collection before calling `Run()` because `callback` may |
1027 | // be executed synchronously, and so the `node->record_start()` call within |
1028 | // `callback` would violate nesting. |
1029 | if (collect_usage) node->record_stop(EnvTime::NowNanos()); |
1030 | lib_->Run(f_opts, f_handle_, frame, std::move(callback)); |
1031 | if (collect_usage) node->record_start(EnvTime::NowNanos()); |
1032 | } |
1033 | |
1034 | bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const { |
1035 | // Rendezvous should only be created by the FLR for non-CPU single-device |
1036 | // functions. For multi-device functions the appropriate rendezvous will be |
1037 | // created by the process FLR. |
1038 | return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_; |
1039 | } |
1040 | |
1041 | } // namespace data |
1042 | } // namespace tensorflow |
1043 | |