1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/eager/c_api.h"
17
18#include <algorithm>
19#include <cstddef>
20#include <memory>
21#include <string>
22#include <vector>
23
24#include "absl/algorithm/container.h"
25#include "absl/memory/memory.h"
26#include "tensorflow/c/c_api.h"
27#include "tensorflow/c/c_api_internal.h"
28#include "tensorflow/c/eager/abstract_tensor_handle.h"
29#include "tensorflow/c/eager/c_api_experimental.h"
30#include "tensorflow/c/eager/c_api_internal.h"
31#include "tensorflow/c/eager/immediate_execution_operation.h"
32#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33#include "tensorflow/c/eager/tfe_context_internal.h"
34#include "tensorflow/c/eager/tfe_op_internal.h"
35#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36#include "tensorflow/c/tf_buffer_internal.h"
37#include "tensorflow/c/tf_status.h"
38#include "tensorflow/c/tf_tensor_internal.h"
39#include "tensorflow/core/common_runtime/copy_tensor.h"
40#include "tensorflow/core/common_runtime/device.h"
41#include "tensorflow/core/common_runtime/device_factory.h"
42#include "tensorflow/core/common_runtime/device_mgr.h"
43#include "tensorflow/core/common_runtime/eager/attr_builder.h"
44#include "tensorflow/core/common_runtime/eager/context.h"
45#include "tensorflow/core/common_runtime/eager/custom_device.h"
46#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
47#include "tensorflow/core/common_runtime/eager/execute.h"
48#include "tensorflow/core/common_runtime/eager/placement_utils.h"
49#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
50#include "tensorflow/core/common_runtime/function.h"
51#include "tensorflow/core/framework/attr_value.pb.h"
52#include "tensorflow/core/framework/device_attributes.pb.h"
53#include "tensorflow/core/framework/function.h"
54#include "tensorflow/core/framework/node_def_util.h"
55#include "tensorflow/core/framework/rendezvous.h"
56#include "tensorflow/core/framework/tensor_shape.pb.h"
57#include "tensorflow/core/framework/types.h"
58#include "tensorflow/core/platform/casts.h"
59#include "tensorflow/core/platform/errors.h"
60#include "tensorflow/core/platform/platform.h"
61#include "tensorflow/core/platform/status.h"
62#include "tensorflow/core/profiler/lib/traceme.h"
63#include "tensorflow/core/protobuf/error_codes.pb.h"
64#include "tensorflow/core/public/version.h"
65
66// "tensorflow/core/platform/platform.h" must be included first before using
67// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
68#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
69#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
70#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
71#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
72
73#if !defined(IS_MOBILE_PLATFORM)
74#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
75#endif // !IS_MOBILE_PLATFORM
76
77using tensorflow::string;
78
79namespace {
80
81string DeviceName(const tensorflow::Device* d) {
82 return (d == nullptr) ? "cpu:0" : d->name();
83}
84
85// Annotate eager runtime construction context to the given `function_def` as
86// an attribute.
void AnnotateEagerRuntimeConstructionContext(
    tensorflow::FunctionDef& function_def) {
  tensorflow::AttrValue value;
  // NOTE(review): the literal "kEagerRuntime" (a constant-style name) is the
  // attribute's string *value* as written here; consumers must match it
  // exactly.
  SetAttrValue("kEagerRuntime", &value);
  (*function_def.mutable_attr())["_construction_context"] = value;
}
93
94} // namespace
95
96extern "C" {
97
98TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
99
// Parses `proto` (`proto_len` bytes of a serialized ConfigProto) into the
// session options held by `options`; parse failures are reported in `status`.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}
104
// Records whether contexts built from `options` execute ops asynchronously
// (any non-zero `enable` means async).
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}
109
// Records the default device placement policy in `options`.
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}
114
115void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
116
// Creates a new eager context configured by `opts`. Returns nullptr with
// `status` set on failure. The TFRT branch is only compiled for internal
// (PLATFORM_GOOGLE) non-LIBTPU builds; requesting TFRT elsewhere yields an
// Unimplemented error.
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async, opts->use_tfrt_distributed_runtime);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  // Default TF runtime: enumerate the local devices under the localhost job.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DynamicDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  // The EagerContext takes ownership of the device manager
  // (/*device_mgr_owned*/ true after release()) and of the rendezvous `r`.
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      /*cluster_flr=*/nullptr,
      /*collective_executor_mgr=*/nullptr,
      /*run_eager_op_as_function=*/opts->run_eager_op_as_function,
      /*jit_compile_rewrite=*/opts->jit_compile_rewrite);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}
163
// Drops the caller's reference on `ctx`; a null context is a no-op.
void TFE_DeleteContext(TFE_Context* ctx) {
  if (ctx == nullptr) {
    return;
  }

  // ctx->RefCountIsOne() should be true here.
  tensorflow::unwrap(ctx)->Release();
}
172
// Returns a freshly allocated device list for `ctx`; the caller owns it.
// Note: `status` is never written by this implementation.
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&l->response);
  return l;
}
178
// Clears the context's cached kernels/thread executors.
void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}
182
183// Set server_def on the context, possibly updating it.
// Installs a ServerDef parsed from `proto` on `ctx`, resetting the context
// (reset_context=true). Unimplemented on mobile builds.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}
204
205TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
206 int keep_alive_secs,
207 const void* proto,
208 size_t proto_len,
209 TF_Status* status) {
210#if defined(IS_MOBILE_PLATFORM)
211 status->status = tensorflow::errors::Unimplemented(
212 "TFE_ContextSetServerDef not supported on mobile");
213#else // !defined(IS_MOBILE_PLATFORM)
214 tensorflow::ServerDef server_def;
215 tensorflow::EagerContext* context =
216 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
217 if (!server_def.ParseFromArray(proto, proto_len)) {
218 status->status = tensorflow::errors::InvalidArgument(
219 "Invalid tensorflow.ServerDef protocol buffer");
220 return;
221 } else if (context->GetContextId() ==
222 tensorflow::EagerContext::kInvalidContextId) {
223 status->status = tensorflow::errors::InvalidArgument(
224 "Trying to update a context with invalid context id.");
225 }
226 status->status =
227 tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
228 server_def, /*reset_context=*/false, keep_alive_secs);
229#endif // !IS_MOBILE_PLATFORM
230}
231
232TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
233 const char* worker_name,
234 TF_Status* status) {
235#if defined(IS_MOBILE_PLATFORM)
236 status->status = tensorflow::errors::Unimplemented(
237 "TFE_ContextSetServerDef not supported on mobile");
238 return false;
239#else // !defined(IS_MOBILE_PLATFORM)
240 bool is_alive;
241 status->status =
242 tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
243 worker_name, &is_alive);
244 return is_alive;
245#endif // !IS_MOBILE_PLATFORM
246}
247
// Blocks until all pending async operations in `ctx` finish; trivially OK on
// mobile builds where there is nothing to wait on.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
                                                TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::OkStatus();
#else   // !defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif  // !IS_MOBILE_PLATFORM
}
256
// Overrides the device placement policy for the calling thread only.
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
262
263// Note: this function looks up a thread local policy. So it should be called in
264// the appropriate client thread. In particular, in async mode, it may not be
265// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  // The C and C++ policy enums are kept value-compatible, so a static_cast
  // round-trips the thread-local (or default) policy.
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}
271
// Creates a local tensor handle holding a copy of `t`'s contents; returns
// nullptr with `status` set if the TF_Tensor cannot be converted.
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;

  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
279
280void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
281 if (h == nullptr) return;
282
283 tensorflow::profiler::TraceMe activity(
284 "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
285 if (h) {
286 tensorflow::unwrap(h)->Release();
287 }
288}
289
// Returns the element dtype of `h` (no null check; `h` must be valid).
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}
293
294int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
295 if (h == nullptr) {
296 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
297 return -1;
298 }
299
300 int num_dims = -1;
301 status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
302 return num_dims;
303}
304
305int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
306 if (h == nullptr) {
307 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
308 return -1;
309 }
310
311 int64_t num_elements = -1;
312 status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
313 return num_elements;
314}
315
316int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
317 TF_Status* status) {
318 if (h == nullptr) {
319 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
320 return -1;
321 }
322
323 int64_t dim = -1;
324 status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
325 return dim;
326}
327
// Returns the name of the device this handle is placed on; nullptr with
// `status` set for a null handle.
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->DeviceName(&status->status);
}
335
// Returns the name of the device that actually holds the tensor's memory
// (which can differ from DeviceName); nullptr with `status` set on null input.
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
}
344
// Returns a new handle sharing the same underlying tensor as `h` (via the
// handle's Copy()); nullptr with `status` set on null input.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}
354
// Materializes the tensor behind `h` as a TF_Tensor owned by the caller.
// Returns nullptr with `status` set on a null handle or resolve failure.
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  tensorflow::AbstractTensorInterface* t =
      tensorflow::unwrap(h)->Resolve(&status->status);
  if (t == nullptr) {
    return nullptr;
  }

  // The TF_Tensor takes ownership of the resolved interface.
  return new TF_Tensor{t};
}
369
// Returns a raw pointer to the tensor's device memory. Custom-device handles
// answer via DevicePointer(); otherwise only LOCAL TensorHandles are valid.
// The handle's device is synced first so the data is safe to read. Returns
// nullptr with `status` set on any failure.
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
               unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(unwrapped_handle);

  // Remote/packed handles have no local device memory to point at.
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  // A null device means the tensor lives on the host CPU; nothing to sync.
  tensorflow::Device* device(handle->device());
  if (device != nullptr) {
    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
412
413namespace tensorflow {
414namespace {
// Adapter from the C TFE_CustomDevice callback table to the C++
// tensorflow::CustomDevice interface. Owns `info`, which is released through
// device_.delete_device in the destructor.
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  // Copies `handle` onto this custom device via the C callback. The Ref/
  // Release pair keeps `handle` alive across the callback; the result is
  // re-Ref'ed before the C wrapper's reference is dropped, so `*result`
  // remains valid for the caller.
  tensorflow::Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* handle,
      ImmediateExecutionTensorHandle** result) override {
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  // Copies `handle` off this custom device to `target_device_name`, with the
  // same refcount choreography as CopyTensorToDevice.
  tensorflow::Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* handle,
      const tensorflow::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  // Executes `op` via the C callback. On success each output is moved into
  // `retvals`, taking a fresh reference before releasing the C wrapper's one.
  tensorflow::Status Execute(const ImmediateExecutionOperation* op,
                             ImmediateExecutionTensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::unwrap(outputs[i]);
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

  // Packs `handles` into a single handle via the C callback.
  tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                          ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    *result = tensorflow::unwrap(device_.pack(context_,
                                              tensorflow::wrap(handles.data()),
                                              handles.size(), &status, info_));
    return status.status;
  }

  tensorflow::StatusOr<bool> ShallPinToThisDevice(
      const ImmediateExecutionOperation* op) override {
    TF_Status status;
    // Let this custom device choose the device to pin this op on if it
    // implements the pinning function.
    if (device_.shall_pin_to_this_device != nullptr) {
      // NOTE(review): any error the callback records in `status` is dropped
      // here — only the returned bool propagates. Confirm this is intended.
      return device_.shall_pin_to_this_device(tensorflow::wrap(op), &status);
    }
    return errors::Unimplemented("No custom device pinning implementation.");
  }

 private:
  TFE_Context* context_;     // Borrowed, not owned.
  TFE_CustomDevice device_;  // C callback table (copied by value).
  void* info_;               // Opaque per-device state; owned (see dtor).
  string name_;
};
500
501// An adapter which wraps the shape/data produced by C custom devices and uses
502// it to implement custom device methods.
class CAPICustomDeviceTensorHandle
    : public tensorflow::CustomDeviceTensorHandle {
 public:
  // Takes ownership of `data`, which is released with methods.deallocator in
  // the destructor. `methods` is a table of C callbacks supplying shape and
  // summary information for the opaque `data`.
  CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
                               tensorflow::CustomDevice* device,
                               tensorflow::DataType dtype, void* data,
                               TFE_CustomDeviceTensorHandleMethods methods)
      : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
        data_(data),
        methods_(methods) {}

  ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); }
  void* DevicePointer() const override { return data_; }
  // Rank via the C callback; any error is carried back in the status.
  Status NumDims(int* num_dims) const override {
    TF_Status s;
    *num_dims = methods_.num_dims(data_, &s);
    return s.status;
  }
  // Size of one dimension via the C callback.
  Status Dim(int dim_index, int64_t* dim) const override {
    TF_Status s;
    *dim = methods_.dim(data_, dim_index, &s);
    return s.status;
  }

  // Prefer the device-supplied summary only when the callback was provided.
  bool PreferCustomSummarizer() const override {
    return methods_.summarize != nullptr;
  }

  // Fills `summary` from the device's summarize callback, falling back to
  // the base-class implementation when the callback is absent.
  Status SummarizeValue(std::string& summary) const override {
    if (methods_.summarize == nullptr) {
      return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary);
    }
    TF_Status c_status;
    // unique_ptr with TF_DeleteBuffer guarantees the buffer is freed on all
    // paths, including the error return below.
    std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer(
        methods_.summarize(data_, &c_status), TF_DeleteBuffer);
    if (!c_status.status.ok()) {
      return c_status.status;
    }
    summary = std::string(reinterpret_cast<const char*>(summary_buffer->data),
                          summary_buffer->length);
    return OkStatus();
  }

 private:
  void* const data_;  // Opaque device data; owned (freed by deallocator).
  const TFE_CustomDeviceTensorHandleMethods methods_;
};
550
551} // namespace
552} // namespace tensorflow
553
554TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
555 TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
556 TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) {
557 tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx);
558 tensorflow::CustomDevice* device = nullptr;
559 if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
560 &device)) {
561 methods.deallocator(data);
562 status->status =
563 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
564 return nullptr;
565 }
566 return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
567 context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
568 methods));
569}
570
// Wraps caller-provided device memory (`data`, `len` bytes) as a local tensor
// handle on `device_name`. `deallocator` is invoked when the buffer is no
// longer needed — including immediately on lookup failure.
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::Device* device = nullptr;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    deallocator(data, len, deallocator_arg);
    // NOTE(review): the FindDeviceFromName error detail is replaced by this
    // generic message — confirm that is intended.
    status->status =
        tensorflow::errors::InvalidArgument(device_name, " unknown device.");
    return nullptr;
  }
  std::vector<int64_t> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<int64_t>(dims[i]);
  }

  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
                           /*owns_memory=*/false);

  // The Tensor takes its own reference on `buf`, so the local one is dropped.
  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  buf->Unref();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
      std::move(t), device, device, context));
}
603
604// This function will block till the operation that produces `h` has
605// completed. This is only valid on local TFE_TensorHandles. Returns the size in
606// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                        TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return 0;
  }
  // NOTE(review): unlike TFE_TensorHandleDevicePointer, there is no
  // CustomDeviceTensorHandle check before this cast — confirm custom-device
  // handles cannot reach here.
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDeviceMemorySize may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return 0;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return 0;
  }
  return tensor->TotalBytes();
}
628
// Creates a new operation (or function call) named `op_or_function_name` on
// `ctx`. On Reset failure the partially built op is released and nullptr is
// returned with `status` set.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  tensorflow::ImmediateExecutionOperation* new_op =
      tensorflow::unwrap(ctx)->CreateOperation();
  status->status = new_op->Reset(op_or_function_name, nullptr);
  if (!status->status.ok()) {
    new_op->Release();
    new_op = nullptr;
  }
  return tensorflow::wrap(new_op);
}
640
// Drops the caller's reference on `op`; a null op is a no-op.
void TFE_DeleteOp(TFE_Op* op) {
  if (op == nullptr) {
    return;
  }

  tensorflow::unwrap(op)->Release();
}
648
// Returns the op's name; the pointer is owned by the op. `status` is unused.
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->Name().c_str();
}
652
// Returns the context the op was created on. `status` is unused.
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
}
656
// Requests that the op run on `device_name`; errors go to `status`.
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
660
// Returns the op's requested device name; pointer owned by the op. `status`
// is unused.
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->DeviceName().c_str();
}
664
// Appends a single input tensor to the op.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
668
// Appends `num_inputs` handles as an input list, reinterpreting the unwrapped
// handle array as the abstract base-class span AddInputList expects.
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInputList(
      {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
           tensorflow::unwrap(inputs)),
       static_cast<size_t>(num_inputs)});
}
676
// Returns the number of inputs currently added to the op. `status` is unused.
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->GetInputs().size();
}
680
// Returns input `index` of the op; no bounds check is performed and `status`
// is unused.
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
                                            TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}
685
// Looks up the declared type of attribute `attr_name` on the op's op-def,
// setting `*is_list` accordingly. Returns TF_ATTR_INT as a dummy value when
// lookup fails (with `status` set).
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret = TF_ATTR_INT;
  const tensorflow::AttrTypeMap* attr_types_;
  bool is_function;
  status->status = tensorflow::AttrTypeMapForOp(
      tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
  if (!status->status.ok()) {
    return ret;
  }
  status->status =
      tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
  return ret;
}
700
701TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
702 const char* op_or_function_name,
703 const char* attr_name, unsigned char* is_list,
704 TF_Status* status) {
705 TF_AttrType ret;
706 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
707 if (status->status.ok()) {
708 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
709 } else {
710 ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
711 }
712 TFE_DeleteOp(op);
713 return ret;
714}
715
// Sets a string attribute (`length` bytes at `value`); failures are only
// logged, matching the rest of the TFE_OpSetAttr* family.
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  auto s = tensorflow::unwrap(op)->SetAttrString(
      attr_name, static_cast<const char*>(value), length);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
724
// Sets an int attribute; failures are only logged.
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
731
// Sets a float attribute; failures are only logged.
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
738
739void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
740 auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
741 (value == 0) ? false : true);
742 if (!s.ok()) {
743 LOG(WARNING) << "Unable to set attribute: " << attr_name;
744 }
745}
746
// Sets a dtype attribute; failures are only logged.
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  auto s = tensorflow::unwrap(op)->SetAttrType(
      attr_name, static_cast<tensorflow::DataType>(value));
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
754
// Sets a shape attribute (`num_dims` entries of `dims`; num_dims of -1 means
// unknown rank per the underlying API). Errors go to `out_status`, unlike the
// log-only scalar setters.
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  out_status->status =
      tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}
760
// Sets a function-valued attribute from another op; failures are only logged.
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  auto s = tensorflow::unwrap(op)->SetAttrFunction(
      attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
769
// Sets a function-valued attribute by name (`length` bytes at `data`);
// failures are only logged.
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                               const char* data, size_t length) {
  auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
777
778void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
779 TF_Status* status) {
780 tensorflow::Tensor t;
781 status->status = TF_TensorToTensor(tensor, &t);
782 tensorflow::TensorInterface interface(t);
783 status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
784}
785
// Sets a list-of-strings attribute (`num_values` buffers with per-entry
// `lengths`); failures are only logged.
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                             const void* const* values, const size_t* lengths,
                             int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
                                                     num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
795
// Sets a list-of-floats attribute; failures are only logged.
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                            const float* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
804
// Sets a list-of-ints attribute; failures are only logged.
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
813
// Sets a list-of-dtypes attribute, reinterpreting the C enum array as the
// value-compatible C++ DataType array; failures are only logged.
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrTypeList(
      attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
      num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
823
// Sets a list-of-bools attribute (one byte per entry); failures are only
// logged.
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
832
// Sets a list-of-shapes attribute (`num_values` shapes, each with its own
// dims array and rank). Errors go to `out_status`.
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
      attr_name, dims, num_dims, num_values);
}
839
// Sets a list-of-functions attribute from `num_values` ops; failures are only
// logged.
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
      attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
                      tensorflow::unwrap(value)),
                  static_cast<size_t>(num_values)});
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
850
851void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
852 const void* proto, size_t proto_len,
853 TF_Status* status) {
854 tensorflow::AttrValue attr_value;
855 if (!attr_value.ParseFromArray(proto, proto_len)) {
856 status->status =
857 tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
858 return;
859 }
860 if (op == nullptr) {
861 status->status = tensorflow::errors::InvalidArgument(
862 "Got a null or uninitialized `op` argument");
863 return;
864 }
865 tensorflow::EagerOperation* operation =
866 OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
867 operation->MutableAttrs()->Set(attr_name, attr_value);
868}
869
// Returns the length of the named input list of the op, or -1 with `status`
// set on failure.
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
                                               const char* input_name,
                                               TF_Status* status) {
  int ret = -1;
  status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
  return ret;
}
877
// Returns the length of the named output list of the op, or -1 with `status`
// set on failure.
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
                                                const char* output_name,
                                                TF_Status* status) {
  int ret = -1;
  status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
  return ret;
}
885
886void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
887 TF_Status* status) {
888 tensorflow::ImmediateExecutionOperation* unwrapped_op =
889 tensorflow::unwrap(op);
890
891 status->status =
892 unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
893 unwrapped_op,
894 reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
895 retvals),
896 num_retvals);
897}
898
899TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
900 TFE_Context* ctx,
901 const char* device_name,
902 TF_Status* status) {
903 if (h == nullptr) {
904 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
905 return nullptr;
906 }
907
908 tensorflow::ImmediateExecutionContext* unwrapped_ctx =
909 tensorflow::unwrap(ctx);
910
911 auto* result =
912 unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice(
913 unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status);
914
915 if (status->status.ok()) {
916 return tensorflow::wrap(result);
917 }
918 return nullptr;
919}
920
921void TFE_ContextAddFunctionDef(TFE_Context* ctx,
922 const char* serialized_function_def, size_t size,
923 TF_Status* status) {
924 tensorflow::FunctionDef function_def;
925 if (!function_def.ParseFromArray(serialized_function_def, size)) {
926 status->status =
927 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
928 return;
929 }
930
931 AnnotateEagerRuntimeConstructionContext(function_def);
932 status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
933}
934
935void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
936 TF_Status* status) {
937 AnnotateEagerRuntimeConstructionContext(function->fdef);
938 status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
939 function->fdef, function->stack_traces);
940}
941
// Removes the function registered under `name` from `ctx`'s function
// library; errors (e.g. unknown name) are reported through `status`.
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                               TF_Status* status) {
  status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}
946
// Returns 1 if a function named `name` is registered in `ctx`, else 0.
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}
950
// Enables graph/run-metadata collection for subsequent executions;
// retrieve the result with TFE_ContextExportRunMetadata.
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
954
// Disables graph/run-metadata collection (counterpart of
// TFE_ContextEnableRunMetadata).
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
958
959} // extern "C"
960
// Wraps tensor `t` in a new local tensor handle. NOTE(review): `status`
// is never written by this implementation — creation of a local handle
// has no failure path here.
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
                                      TF_Status* status) {
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
}
965
966void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
967 TF_Status* status) {
968 auto* context = tensorflow::unwrap(ctx);
969 status->status = context->AsyncWait();
970 if (!status->status.ok()) return;
971 auto run_metadata = context->ExportRunMetadata();
972 status->status = MessageToBuffer(*run_metadata, buf);
973}
974
975namespace {
976TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
977 TF_Status* status) {
978 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
979 for (const auto& attr : func.attr()) {
980 if (!status->status.ok()) return nullptr;
981 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
982 if (!status->status.ok()) return nullptr;
983 }
984 return func_op;
985}
986} // namespace
987
// Marks the start of a step; pair with TFE_ContextEndStep.
void TFE_ContextStartStep(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->StartStep();
}
991
// Marks the end of a step started with TFE_ContextStartStep.
void TFE_ContextEndStep(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->EndStep();
}
995
// Returns the attributes currently set on `op`. NOTE(review): ownership
// appears to remain with the op (GetOpAttrs returns a wrapped internal
// pointer) — confirm before storing the result past the op's lifetime.
const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
}
999
// Adds the attributes in `attrs` to `op` (e.g. attrs obtained from
// another op via TFE_OpGetAttrs).
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
  tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
}
1003
1004void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
1005 TF_Status* status) {
1006 tensorflow::NameAttrList name_and_attrs;
1007 tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
1008 status->status = MessageToBuffer(name_and_attrs, buf);
1009}
1010
1011namespace tensorflow {
1012void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1013 const tensorflow::AttrValue& default_value,
1014 const char* attr_name, TF_Status* status) {
1015 switch (default_value.value_case()) {
1016 case tensorflow::AttrValue::kS: {
1017 const string& v = default_value.s();
1018 TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1019 break;
1020 }
1021 case tensorflow::AttrValue::kI:
1022 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1023 break;
1024 case tensorflow::AttrValue::kF:
1025 TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1026 break;
1027 case tensorflow::AttrValue::kB:
1028 TFE_OpSetAttrBool(op, attr_name, default_value.b());
1029 break;
1030 case tensorflow::AttrValue::kType:
1031 TFE_OpSetAttrType(op, attr_name,
1032 static_cast<TF_DataType>(default_value.type()));
1033 break;
1034 case tensorflow::AttrValue::kShape: {
1035 const auto& tensor_shape = default_value.shape();
1036 if (tensor_shape.unknown_rank()) {
1037 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1038 } else {
1039 const auto num_dims = tensor_shape.dim_size();
1040 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1041 for (int i = 0; i < num_dims; ++i) {
1042 dims[i] = tensor_shape.dim(i).size();
1043 }
1044 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1045 }
1046 } break;
1047 case tensorflow::AttrValue::kFunc: {
1048 const auto func_op = GetFunc(ctx, default_value.func(), status);
1049 if (!status->status.ok()) return;
1050 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1051 // require TFE_Op* and just convert it internally a NameAttrValue, so
1052 // consider adding an overload to the C API to make this case easier.
1053 TFE_OpSetAttrFunction(op, attr_name, func_op);
1054 TFE_DeleteOp(func_op);
1055 } break;
1056 case tensorflow::AttrValue::kList: {
1057 // String
1058 if (const int s_size = default_value.list().s_size()) {
1059 absl::InlinedVector<const void*, 4> values_vector;
1060 values_vector.reserve(s_size);
1061 absl::InlinedVector<size_t, 4> lengths_vector;
1062 lengths_vector.reserve(s_size);
1063 for (int i = 0; i < s_size; ++i) {
1064 const string& v = default_value.list().s(i);
1065 values_vector.push_back(v.data());
1066 lengths_vector.push_back(v.size());
1067 }
1068 TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1069 lengths_vector.data(), s_size);
1070 }
1071
1072 // Int
1073 if (const int i_size = default_value.list().i_size()) {
1074 absl::InlinedVector<int64_t, 4> i_vector;
1075 i_vector.reserve(i_size);
1076 for (int i = 0; i < i_size; ++i) {
1077 i_vector.push_back(default_value.list().i(i));
1078 }
1079 TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1080 }
1081 // Float
1082 if (const int f_size = default_value.list().f_size()) {
1083 absl::InlinedVector<float, 4> f_vector;
1084 f_vector.reserve(f_size);
1085 for (int i = 0; i < f_size; ++i) {
1086 f_vector.push_back(default_value.list().f(i));
1087 }
1088 TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1089 }
1090 // Bool
1091 if (const int b_size = default_value.list().b_size()) {
1092 absl::InlinedVector<unsigned char, 4> b_vector;
1093 b_vector.reserve(b_size);
1094 for (int i = 0; i < b_size; i++) {
1095 b_vector.push_back(default_value.list().b(i));
1096 }
1097 TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1098 }
1099 // Type
1100 if (const int type_size = default_value.list().type_size()) {
1101 absl::InlinedVector<unsigned int, 4> type_vector;
1102 type_vector.reserve(type_size);
1103 for (int i = 0; i < type_size; ++i) {
1104 type_vector.push_back(default_value.list().type(i));
1105 }
1106 TFE_OpSetAttrTypeList(
1107 op, attr_name,
1108 reinterpret_cast<const TF_DataType*>(type_vector.data()),
1109 type_size);
1110 }
1111
1112 // Rest are not supported.
1113 if (default_value.list().shape_size() > 0 ||
1114 default_value.list().func_size() > 0 ||
1115 default_value.list().tensor_size() > 0) {
1116 TF_SetStatus(
1117 status, TF_UNIMPLEMENTED,
1118 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1119 default_value.DebugString())
1120 .data());
1121 }
1122 } break;
1123 case tensorflow::AttrValue::kTensor:
1124 TF_FALLTHROUGH_INTENDED;
1125 case tensorflow::AttrValue::kPlaceholder:
1126 TF_FALLTHROUGH_INTENDED;
1127 case tensorflow::AttrValue::VALUE_NOT_SET:
1128 TF_SetStatus(
1129 status, TF_UNIMPLEMENTED,
1130 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1131 default_value.DebugString())
1132 .data());
1133 }
1134}
1135} // namespace tensorflow
1136
namespace {
// Fallback `pack` callback installed by TFE_RegisterCustomDevice when a
// custom device does not supply one: always fails with TF_UNIMPLEMENTED.
TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
                                          TFE_TensorHandle** handles,
                                          int num_handles, TF_Status* status,
                                          void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "This custom device does not support packing tensors.");
  return nullptr;
}
}  // namespace
1147
1148extern "C" {
1149
// Returns true if `device_name` names a custom device registered in `ctx`.
bool TFE_IsCustomDevice(TFE_Context* ctx, const char* device_name) {
  return tensorflow::unwrap(ctx)->IsCustomDevice(device_name);
}
1153
1154void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1155 const char* device_name, void* device_info,
1156 TF_Status* status) {
1157 // Fill in default values for optional functionality.
1158 if (device.pack == nullptr) {
1159 device.pack = &DefaultCustomDevicePack;
1160 }
1161 auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1162 ctx, device, device_info, device_name);
1163 status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1164 device_name, std::move(custom_device));
1165}
1166
1167} // extern "C"
1168