1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/eager/c_api.h" |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "absl/algorithm/container.h" |
25 | #include "absl/memory/memory.h" |
26 | #include "tensorflow/c/c_api.h" |
27 | #include "tensorflow/c/c_api_internal.h" |
28 | #include "tensorflow/c/eager/abstract_tensor_handle.h" |
29 | #include "tensorflow/c/eager/c_api_experimental.h" |
30 | #include "tensorflow/c/eager/c_api_internal.h" |
31 | #include "tensorflow/c/eager/immediate_execution_operation.h" |
32 | #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" |
33 | #include "tensorflow/c/eager/tfe_context_internal.h" |
34 | #include "tensorflow/c/eager/tfe_op_internal.h" |
35 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
36 | #include "tensorflow/c/tf_buffer_internal.h" |
37 | #include "tensorflow/c/tf_status.h" |
38 | #include "tensorflow/c/tf_tensor_internal.h" |
39 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
40 | #include "tensorflow/core/common_runtime/device.h" |
41 | #include "tensorflow/core/common_runtime/device_factory.h" |
42 | #include "tensorflow/core/common_runtime/device_mgr.h" |
43 | #include "tensorflow/core/common_runtime/eager/attr_builder.h" |
44 | #include "tensorflow/core/common_runtime/eager/context.h" |
45 | #include "tensorflow/core/common_runtime/eager/custom_device.h" |
46 | #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" |
47 | #include "tensorflow/core/common_runtime/eager/execute.h" |
48 | #include "tensorflow/core/common_runtime/eager/placement_utils.h" |
49 | #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
50 | #include "tensorflow/core/common_runtime/function.h" |
51 | #include "tensorflow/core/framework/attr_value.pb.h" |
52 | #include "tensorflow/core/framework/device_attributes.pb.h" |
53 | #include "tensorflow/core/framework/function.h" |
54 | #include "tensorflow/core/framework/node_def_util.h" |
55 | #include "tensorflow/core/framework/rendezvous.h" |
56 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
57 | #include "tensorflow/core/framework/types.h" |
58 | #include "tensorflow/core/platform/casts.h" |
59 | #include "tensorflow/core/platform/errors.h" |
60 | #include "tensorflow/core/platform/platform.h" |
61 | #include "tensorflow/core/platform/status.h" |
62 | #include "tensorflow/core/profiler/lib/traceme.h" |
63 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
64 | #include "tensorflow/core/public/version.h" |
65 | |
66 | // "tensorflow/core/platform/platform.h" must be included first before using |
67 | // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. |
68 | #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) |
69 | #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" |
70 | #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h" |
71 | #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE |
72 | |
73 | #if !defined(IS_MOBILE_PLATFORM) |
74 | #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" |
75 | #endif // !IS_MOBILE_PLATFORM |
76 | |
77 | using tensorflow::string; |
78 | |
79 | namespace { |
80 | |
81 | string DeviceName(const tensorflow::Device* d) { |
82 | return (d == nullptr) ? "cpu:0" : d->name(); |
83 | } |
84 | |
85 | // Annotate eager runtime construction context to the given `function_def` as |
86 | // an attribute. |
87 | void AnnotateEagerRuntimeConstructionContext( |
88 | tensorflow::FunctionDef& function_def) { |
89 | tensorflow::AttrValue value; |
90 | SetAttrValue("kEagerRuntime" , &value); |
91 | (*function_def.mutable_attr())["_construction_context" ] = value; |
92 | } |
93 | |
94 | } // namespace |
95 | |
96 | extern "C" { |
97 | |
98 | TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } |
99 | |
100 | void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, |
101 | size_t proto_len, TF_Status* status) { |
102 | TF_SetConfig(&options->session_options, proto, proto_len, status); |
103 | } |
104 | |
105 | void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, |
106 | unsigned char enable) { |
107 | options->async = enable; |
108 | } |
109 | |
110 | void TFE_ContextOptionsSetDevicePlacementPolicy( |
111 | TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { |
112 | options->device_placement_policy = policy; |
113 | } |
114 | |
115 | void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } |
116 | |
117 | TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { |
118 | if (opts->use_tfrt) { |
119 | #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) |
120 | tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface( |
121 | opts->session_options.options, |
122 | static_cast<tensorflow::ContextDevicePlacementPolicy>( |
123 | opts->device_placement_policy), |
124 | opts->async, opts->use_tfrt_distributed_runtime); |
125 | #if !defined(IS_MOBILE_PLATFORM) |
126 | tfrt_context->SetDistributedManager( |
127 | tfrt::tf::CreateDistributedManagerContext( |
128 | tfrt_context->GetCoreRuntime()->GetHostContext())); |
129 | #endif // !IS_MOBILE_PLATFORM |
130 | return tensorflow::wrap(tfrt_context); |
131 | #else |
132 | status->status = tensorflow::errors::Unimplemented("TFRT is not supported" ); |
133 | return nullptr; |
134 | #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE |
135 | } |
136 | std::vector<std::unique_ptr<tensorflow::Device>> devices; |
137 | status->status = tensorflow::DeviceFactory::AddDevices( |
138 | opts->session_options.options, "/job:localhost/replica:0/task:0" , |
139 | &devices); |
140 | if (!status->status.ok()) return nullptr; |
141 | std::unique_ptr<tensorflow::DeviceMgr> device_mgr( |
142 | new tensorflow::DynamicDeviceMgr(std::move(devices))); |
143 | |
144 | tensorflow::Rendezvous* r = |
145 | new tensorflow::IntraProcessRendezvous(device_mgr.get()); |
146 | tensorflow::EagerContext* eager_context = new tensorflow::EagerContext( |
147 | opts->session_options.options, |
148 | static_cast<tensorflow::ContextDevicePlacementPolicy>( |
149 | opts->device_placement_policy), |
150 | opts->async, device_mgr.release(), |
151 | /*device_mgr_owned*/ true, r, |
152 | /*cluster_flr=*/nullptr, |
153 | /*collective_executor_mgr=*/nullptr, |
154 | /*run_eager_op_as_function=*/opts->run_eager_op_as_function, |
155 | /*jit_compile_rewrite=*/opts->jit_compile_rewrite); |
156 | #if !defined(IS_MOBILE_PLATFORM) |
157 | eager_context->SetDistributedManager( |
158 | std::make_unique<tensorflow::EagerContextDistributedManager>( |
159 | eager_context)); |
160 | #endif // !IS_MOBILE_PLATFORM |
161 | return tensorflow::wrap(eager_context); |
162 | } |
163 | |
164 | void TFE_DeleteContext(TFE_Context* ctx) { |
165 | if (ctx == nullptr) { |
166 | return; |
167 | } |
168 | |
169 | // ctx->RefCountIsOne() should be true here. |
170 | tensorflow::unwrap(ctx)->Release(); |
171 | } |
172 | |
173 | TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { |
174 | TF_DeviceList* l = new TF_DeviceList; |
175 | tensorflow::unwrap(ctx)->ListDevices(&l->response); |
176 | return l; |
177 | } |
178 | |
179 | void TFE_ContextClearCaches(TFE_Context* ctx) { |
180 | tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors(); |
181 | } |
182 | |
183 | // Set server_def on the context, possibly updating it. |
184 | TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, |
185 | int keep_alive_secs, |
186 | const void* proto, |
187 | size_t proto_len, |
188 | TF_Status* status) { |
189 | #if defined(IS_MOBILE_PLATFORM) |
190 | status->status = tensorflow::errors::Unimplemented( |
191 | "TFE_ContextSetServerDef not supported on mobile" ); |
192 | #else // !defined(IS_MOBILE_PLATFORM) |
193 | tensorflow::ServerDef server_def; |
194 | if (!server_def.ParseFromArray(proto, proto_len)) { |
195 | status->status = tensorflow::errors::InvalidArgument( |
196 | "Invalid tensorflow.ServerDef protocol buffer" ); |
197 | return; |
198 | } |
199 | status->status = |
200 | tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( |
201 | server_def, /*reset_context=*/true, keep_alive_secs); |
202 | #endif // !IS_MOBILE_PLATFORM |
203 | } |
204 | |
205 | TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, |
206 | int keep_alive_secs, |
207 | const void* proto, |
208 | size_t proto_len, |
209 | TF_Status* status) { |
210 | #if defined(IS_MOBILE_PLATFORM) |
211 | status->status = tensorflow::errors::Unimplemented( |
212 | "TFE_ContextSetServerDef not supported on mobile" ); |
213 | #else // !defined(IS_MOBILE_PLATFORM) |
214 | tensorflow::ServerDef server_def; |
215 | tensorflow::EagerContext* context = |
216 | tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); |
217 | if (!server_def.ParseFromArray(proto, proto_len)) { |
218 | status->status = tensorflow::errors::InvalidArgument( |
219 | "Invalid tensorflow.ServerDef protocol buffer" ); |
220 | return; |
221 | } else if (context->GetContextId() == |
222 | tensorflow::EagerContext::kInvalidContextId) { |
223 | status->status = tensorflow::errors::InvalidArgument( |
224 | "Trying to update a context with invalid context id." ); |
225 | } |
226 | status->status = |
227 | tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( |
228 | server_def, /*reset_context=*/false, keep_alive_secs); |
229 | #endif // !IS_MOBILE_PLATFORM |
230 | } |
231 | |
232 | TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, |
233 | const char* worker_name, |
234 | TF_Status* status) { |
235 | #if defined(IS_MOBILE_PLATFORM) |
236 | status->status = tensorflow::errors::Unimplemented( |
237 | "TFE_ContextSetServerDef not supported on mobile" ); |
238 | return false; |
239 | #else // !defined(IS_MOBILE_PLATFORM) |
240 | bool is_alive; |
241 | status->status = |
242 | tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive( |
243 | worker_name, &is_alive); |
244 | return is_alive; |
245 | #endif // !IS_MOBILE_PLATFORM |
246 | } |
247 | |
248 | TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, |
249 | TF_Status* status) { |
250 | #if defined(IS_MOBILE_PLATFORM) |
251 | status->status = tensorflow::OkStatus(); |
252 | #else // !defined(IS_MOBILE_PLATFORM) |
253 | status->status = tensorflow::unwrap(ctx)->AsyncWait(); |
254 | #endif // !IS_MOBILE_PLATFORM |
255 | } |
256 | |
257 | void TFE_ContextSetThreadLocalDevicePlacementPolicy( |
258 | TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { |
259 | tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy( |
260 | static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); |
261 | } |
262 | |
263 | // Note: this function looks up a thread local policy. So it should be called in |
264 | // the appropriate client thread. In particular, in async mode, it may not be |
265 | // safe to call this function from the async EagerExecutor threads. |
266 | extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( |
267 | TFE_Context* ctx) { |
268 | return static_cast<TFE_ContextDevicePlacementPolicy>( |
269 | tensorflow::unwrap(ctx)->GetDevicePlacementPolicy()); |
270 | } |
271 | |
272 | TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { |
273 | tensorflow::Tensor tensor; |
274 | status->status = tensorflow::TF_TensorToTensor(t, &tensor); |
275 | if (!status->status.ok()) return nullptr; |
276 | |
277 | return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); |
278 | } |
279 | |
280 | void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { |
281 | if (h == nullptr) return; |
282 | |
283 | tensorflow::profiler::TraceMe activity( |
284 | "TFE_DeleteTensorHandle" , tensorflow::profiler::TraceMeLevel::kInfo); |
285 | if (h) { |
286 | tensorflow::unwrap(h)->Release(); |
287 | } |
288 | } |
289 | |
290 | TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { |
291 | return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType()); |
292 | } |
293 | |
294 | int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { |
295 | if (h == nullptr) { |
296 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
297 | return -1; |
298 | } |
299 | |
300 | int num_dims = -1; |
301 | status->status = tensorflow::unwrap(h)->NumDims(&num_dims); |
302 | return num_dims; |
303 | } |
304 | |
305 | int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { |
306 | if (h == nullptr) { |
307 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
308 | return -1; |
309 | } |
310 | |
311 | int64_t num_elements = -1; |
312 | status->status = tensorflow::unwrap(h)->NumElements(&num_elements); |
313 | return num_elements; |
314 | } |
315 | |
316 | int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, |
317 | TF_Status* status) { |
318 | if (h == nullptr) { |
319 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
320 | return -1; |
321 | } |
322 | |
323 | int64_t dim = -1; |
324 | status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim); |
325 | return dim; |
326 | } |
327 | |
328 | const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { |
329 | if (h == nullptr) { |
330 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
331 | return nullptr; |
332 | } |
333 | return tensorflow::unwrap(h)->DeviceName(&status->status); |
334 | } |
335 | |
336 | const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, |
337 | TF_Status* status) { |
338 | if (h == nullptr) { |
339 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
340 | return nullptr; |
341 | } |
342 | return tensorflow::unwrap(h)->BackingDeviceName(&status->status); |
343 | } |
344 | |
345 | TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( |
346 | TFE_TensorHandle* h, TF_Status* status) { |
347 | if (h == nullptr) { |
348 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
349 | return nullptr; |
350 | } |
351 | |
352 | return tensorflow::wrap(tensorflow::unwrap(h)->Copy()); |
353 | } |
354 | |
355 | TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { |
356 | if (h == nullptr) { |
357 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
358 | return nullptr; |
359 | } |
360 | |
361 | tensorflow::AbstractTensorInterface* t = |
362 | tensorflow::unwrap(h)->Resolve(&status->status); |
363 | if (t == nullptr) { |
364 | return nullptr; |
365 | } |
366 | |
367 | return new TF_Tensor{t}; |
368 | } |
369 | |
370 | void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { |
371 | if (h == nullptr) { |
372 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
373 | return nullptr; |
374 | } |
375 | tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = |
376 | tensorflow::unwrap(h); |
377 | // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. |
378 | if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { |
379 | return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>( |
380 | unwrapped_handle) |
381 | ->DevicePointer(); |
382 | } |
383 | // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. |
384 | if (!tensorflow::TensorHandle::classof(unwrapped_handle)) { |
385 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
386 | return nullptr; |
387 | } |
388 | tensorflow::TensorHandle* handle = |
389 | tensorflow::TensorHandleFromInterface(unwrapped_handle); |
390 | |
391 | if (handle->Type() != tensorflow::TensorHandle::LOCAL) { |
392 | status->status = tensorflow::errors::InvalidArgument( |
393 | "TFE_TensorHandleDevicePointer may not be called on a " , |
394 | handle->TypeString(), " tensor handle." ); |
395 | return nullptr; |
396 | } |
397 | tensorflow::Device* device(handle->device()); |
398 | if (device != nullptr) { |
399 | status->status = device->Sync(); |
400 | if (!status->status.ok()) { |
401 | return nullptr; |
402 | } |
403 | } |
404 | const tensorflow::Tensor* tensor; |
405 | status->status = handle->Tensor(&tensor); |
406 | if (!status->status.ok()) { |
407 | return nullptr; |
408 | } |
409 | return const_cast<void*>( |
410 | static_cast<const void*>(tensor->tensor_data().data())); |
411 | } |
412 | |
413 | namespace tensorflow { |
414 | namespace { |
415 | class CustomDeviceAPI : public tensorflow::CustomDevice { |
416 | public: |
417 | CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, |
418 | string name) |
419 | : context_(context), device_(device), info_(info), name_(name) {} |
420 | |
421 | ~CustomDeviceAPI() override { device_.delete_device(info_); } |
422 | |
423 | const string& name() override { return name_; } |
424 | |
425 | tensorflow::Status CopyTensorToDevice( |
426 | ImmediateExecutionTensorHandle* handle, |
427 | ImmediateExecutionTensorHandle** result) override { |
428 | handle->Ref(); |
429 | TF_Status status; |
430 | TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( |
431 | context_, tensorflow::wrap(handle), &status, info_); |
432 | handle->Release(); |
433 | if (!status.status.ok()) return status.status; |
434 | *result = tensorflow::unwrap(result_handle); |
435 | (*result)->Ref(); |
436 | TFE_DeleteTensorHandle(result_handle); |
437 | return status.status; |
438 | } |
439 | |
440 | tensorflow::Status CopyTensorFromDevice( |
441 | ImmediateExecutionTensorHandle* handle, |
442 | const tensorflow::string& target_device_name, |
443 | ImmediateExecutionTensorHandle** result) override { |
444 | TF_Status status; |
445 | handle->Ref(); |
446 | TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( |
447 | context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, |
448 | info_); |
449 | handle->Release(); |
450 | if (!status.status.ok()) return status.status; |
451 | *result = tensorflow::unwrap(result_handle); |
452 | (*result)->Ref(); |
453 | TFE_DeleteTensorHandle(result_handle); |
454 | return status.status; |
455 | } |
456 | |
457 | tensorflow::Status Execute(const ImmediateExecutionOperation* op, |
458 | ImmediateExecutionTensorHandle** retvals, |
459 | int* num_retvals) override { |
460 | std::vector<TFE_TensorHandle*> outputs(*num_retvals); |
461 | TF_Status status; |
462 | device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, |
463 | info_); |
464 | if (status.status.ok()) { |
465 | for (int i = 0; i < *num_retvals; ++i) { |
466 | retvals[i] = tensorflow::unwrap(outputs[i]); |
467 | retvals[i]->Ref(); |
468 | TFE_DeleteTensorHandle(outputs[i]); |
469 | } |
470 | } |
471 | return status.status; |
472 | } |
473 | |
474 | tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles, |
475 | ImmediateExecutionTensorHandle** result) override { |
476 | TF_Status status; |
477 | *result = tensorflow::unwrap(device_.pack(context_, |
478 | tensorflow::wrap(handles.data()), |
479 | handles.size(), &status, info_)); |
480 | return status.status; |
481 | } |
482 | |
483 | tensorflow::StatusOr<bool> ShallPinToThisDevice( |
484 | const ImmediateExecutionOperation* op) override { |
485 | TF_Status status; |
486 | // Let this custom device choose the device to pin this op on if it |
487 | // implements the pinning function. |
488 | if (device_.shall_pin_to_this_device != nullptr) { |
489 | return device_.shall_pin_to_this_device(tensorflow::wrap(op), &status); |
490 | } |
491 | return errors::Unimplemented("No custom device pinning implementation." ); |
492 | } |
493 | |
494 | private: |
495 | TFE_Context* context_; |
496 | TFE_CustomDevice device_; |
497 | void* info_; |
498 | string name_; |
499 | }; |
500 | |
501 | // An adapter which wraps the shape/data produced by C custom devices and uses |
502 | // it to implement custom device methods. |
503 | class CAPICustomDeviceTensorHandle |
504 | : public tensorflow::CustomDeviceTensorHandle { |
505 | public: |
506 | CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context, |
507 | tensorflow::CustomDevice* device, |
508 | tensorflow::DataType dtype, void* data, |
509 | TFE_CustomDeviceTensorHandleMethods methods) |
510 | : tensorflow::CustomDeviceTensorHandle(context, device, dtype), |
511 | data_(data), |
512 | methods_(methods) {} |
513 | |
514 | ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); } |
515 | void* DevicePointer() const override { return data_; } |
516 | Status NumDims(int* num_dims) const override { |
517 | TF_Status s; |
518 | *num_dims = methods_.num_dims(data_, &s); |
519 | return s.status; |
520 | } |
521 | Status Dim(int dim_index, int64_t* dim) const override { |
522 | TF_Status s; |
523 | *dim = methods_.dim(data_, dim_index, &s); |
524 | return s.status; |
525 | } |
526 | |
527 | bool PreferCustomSummarizer() const override { |
528 | return methods_.summarize != nullptr; |
529 | } |
530 | |
531 | Status SummarizeValue(std::string& summary) const override { |
532 | if (methods_.summarize == nullptr) { |
533 | return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary); |
534 | } |
535 | TF_Status c_status; |
536 | std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer( |
537 | methods_.summarize(data_, &c_status), TF_DeleteBuffer); |
538 | if (!c_status.status.ok()) { |
539 | return c_status.status; |
540 | } |
541 | summary = std::string(reinterpret_cast<const char*>(summary_buffer->data), |
542 | summary_buffer->length); |
543 | return OkStatus(); |
544 | } |
545 | |
546 | private: |
547 | void* const data_; |
548 | const TFE_CustomDeviceTensorHandleMethods methods_; |
549 | }; |
550 | |
551 | } // namespace |
552 | } // namespace tensorflow |
553 | |
554 | TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( |
555 | TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data, |
556 | TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) { |
557 | tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx); |
558 | tensorflow::CustomDevice* device = nullptr; |
559 | if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name, |
560 | &device)) { |
561 | methods.deallocator(data); |
562 | status->status = |
563 | tensorflow::errors::InvalidArgument(device_name, " unknown device." ); |
564 | return nullptr; |
565 | } |
566 | return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle( |
567 | context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data, |
568 | methods)); |
569 | } |
570 | |
571 | TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( |
572 | TFE_Context* ctx, const char* device_name, TF_DataType dtype, |
573 | const int64_t* dims, int num_dims, void* data, size_t len, |
574 | void (*deallocator)(void* data, size_t len, void* arg), |
575 | void* deallocator_arg, TF_Status* status) { |
576 | tensorflow::Device* device = nullptr; |
577 | tensorflow::EagerContext* context = |
578 | tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); |
579 | status->status = context->FindDeviceFromName(device_name, &device); |
580 | if (!status->status.ok()) { |
581 | deallocator(data, len, deallocator_arg); |
582 | status->status = |
583 | tensorflow::errors::InvalidArgument(device_name, " unknown device." ); |
584 | return nullptr; |
585 | } |
586 | std::vector<int64_t> dimvec(num_dims); |
587 | for (int i = 0; i < num_dims; ++i) { |
588 | dimvec[i] = static_cast<int64_t>(dims[i]); |
589 | } |
590 | |
591 | // TODO(apassos) do we need to wrap the deallocator here to make sure to sync |
592 | // the device? |
593 | TF_ManagedBuffer* buf = |
594 | new TF_ManagedBuffer(data, len, deallocator, deallocator_arg, |
595 | /*owns_memory=*/false); |
596 | |
597 | tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype), |
598 | tensorflow::TensorShape(dimvec), buf); |
599 | buf->Unref(); |
600 | return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( |
601 | std::move(t), device, device, context)); |
602 | } |
603 | |
604 | // This function will block till the operation that produces `h` has |
605 | // completed. This is only valid on local TFE_TensorHandles. Returns the size in |
606 | // bytes of the memory pointed to by the device pointer returned above. |
607 | size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, |
608 | TF_Status* status) { |
609 | if (h == nullptr) { |
610 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
611 | return 0; |
612 | } |
613 | tensorflow::TensorHandle* handle = |
614 | tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); |
615 | if (handle->Type() != tensorflow::TensorHandle::LOCAL) { |
616 | status->status = tensorflow::errors::InvalidArgument( |
617 | "TFE_TensorHandleDeviceMemorySize may not be called on a " , |
618 | handle->TypeString(), " tensor handle." ); |
619 | return 0; |
620 | } |
621 | const tensorflow::Tensor* tensor; |
622 | status->status = handle->Tensor(&tensor); |
623 | if (!status->status.ok()) { |
624 | return 0; |
625 | } |
626 | return tensor->TotalBytes(); |
627 | } |
628 | |
629 | TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, |
630 | TF_Status* status) { |
631 | tensorflow::ImmediateExecutionOperation* new_op = |
632 | tensorflow::unwrap(ctx)->CreateOperation(); |
633 | status->status = new_op->Reset(op_or_function_name, nullptr); |
634 | if (!status->status.ok()) { |
635 | new_op->Release(); |
636 | new_op = nullptr; |
637 | } |
638 | return tensorflow::wrap(new_op); |
639 | } |
640 | |
641 | void TFE_DeleteOp(TFE_Op* op) { |
642 | if (op == nullptr) { |
643 | return; |
644 | } |
645 | |
646 | tensorflow::unwrap(op)->Release(); |
647 | } |
648 | |
649 | const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { |
650 | return tensorflow::unwrap(op)->Name().c_str(); |
651 | } |
652 | |
653 | TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { |
654 | return tensorflow::wrap(tensorflow::unwrap(op)->GetContext()); |
655 | } |
656 | |
657 | void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { |
658 | status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); |
659 | } |
660 | |
661 | const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) { |
662 | return tensorflow::unwrap(op)->DeviceName().c_str(); |
663 | } |
664 | |
665 | void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { |
666 | status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); |
667 | } |
668 | |
669 | void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, |
670 | TF_Status* status) { |
671 | status->status = tensorflow::unwrap(op)->AddInputList( |
672 | {reinterpret_cast<tensorflow::AbstractTensorHandle**>( |
673 | tensorflow::unwrap(inputs)), |
674 | static_cast<size_t>(num_inputs)}); |
675 | } |
676 | |
677 | extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) { |
678 | return tensorflow::unwrap(op)->GetInputs().size(); |
679 | } |
680 | |
681 | extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index, |
682 | TF_Status* status) { |
683 | return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]); |
684 | } |
685 | |
686 | TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, |
687 | unsigned char* is_list, TF_Status* status) { |
688 | TF_AttrType ret = TF_ATTR_INT; |
689 | const tensorflow::AttrTypeMap* attr_types_; |
690 | bool is_function; |
691 | status->status = tensorflow::AttrTypeMapForOp( |
692 | tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function); |
693 | if (!status->status.ok()) { |
694 | return ret; |
695 | } |
696 | status->status = |
697 | tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list); |
698 | return ret; |
699 | } |
700 | |
701 | TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, |
702 | const char* op_or_function_name, |
703 | const char* attr_name, unsigned char* is_list, |
704 | TF_Status* status) { |
705 | TF_AttrType ret; |
706 | TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); |
707 | if (status->status.ok()) { |
708 | ret = TFE_OpGetAttrType(op, attr_name, is_list, status); |
709 | } else { |
710 | ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. |
711 | } |
712 | TFE_DeleteOp(op); |
713 | return ret; |
714 | } |
715 | |
716 | void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, |
717 | size_t length) { |
718 | auto s = tensorflow::unwrap(op)->SetAttrString( |
719 | attr_name, static_cast<const char*>(value), length); |
720 | if (!s.ok()) { |
721 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
722 | } |
723 | } |
724 | |
725 | void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { |
726 | auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value); |
727 | if (!s.ok()) { |
728 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
729 | } |
730 | } |
731 | |
732 | void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { |
733 | auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value); |
734 | if (!s.ok()) { |
735 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
736 | } |
737 | } |
738 | |
739 | void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { |
740 | auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name, |
741 | (value == 0) ? false : true); |
742 | if (!s.ok()) { |
743 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
744 | } |
745 | } |
746 | |
747 | void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { |
748 | auto s = tensorflow::unwrap(op)->SetAttrType( |
749 | attr_name, static_cast<tensorflow::DataType>(value)); |
750 | if (!s.ok()) { |
751 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
752 | } |
753 | } |
754 | |
755 | void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, |
756 | const int num_dims, TF_Status* out_status) { |
757 | out_status->status = |
758 | tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims); |
759 | } |
760 | |
761 | void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, |
762 | const TFE_Op* value) { |
763 | auto s = tensorflow::unwrap(op)->SetAttrFunction( |
764 | attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value))); |
765 | if (!s.ok()) { |
766 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
767 | } |
768 | } |
769 | |
770 | void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, |
771 | const char* data, size_t length) { |
772 | auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length); |
773 | if (!s.ok()) { |
774 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
775 | } |
776 | } |
777 | |
778 | void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, |
779 | TF_Status* status) { |
780 | tensorflow::Tensor t; |
781 | status->status = TF_TensorToTensor(tensor, &t); |
782 | tensorflow::TensorInterface interface(t); |
783 | status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface); |
784 | } |
785 | |
786 | void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, |
787 | const void* const* values, const size_t* lengths, |
788 | int num_values) { |
789 | auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths, |
790 | num_values); |
791 | if (!s.ok()) { |
792 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
793 | } |
794 | } |
795 | |
796 | void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, |
797 | const float* values, int num_values) { |
798 | auto s = |
799 | tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values); |
800 | if (!s.ok()) { |
801 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
802 | } |
803 | } |
804 | |
805 | void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, |
806 | const int64_t* values, int num_values) { |
807 | auto s = |
808 | tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values); |
809 | if (!s.ok()) { |
810 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
811 | } |
812 | } |
813 | |
814 | void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, |
815 | const TF_DataType* values, int num_values) { |
816 | auto s = tensorflow::unwrap(op)->SetAttrTypeList( |
817 | attr_name, reinterpret_cast<const tensorflow::DataType*>(values), |
818 | num_values); |
819 | if (!s.ok()) { |
820 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
821 | } |
822 | } |
823 | |
824 | void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, |
825 | const unsigned char* values, int num_values) { |
826 | auto s = |
827 | tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values); |
828 | if (!s.ok()) { |
829 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
830 | } |
831 | } |
832 | |
833 | void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, |
834 | const int64_t** dims, const int* num_dims, |
835 | int num_values, TF_Status* out_status) { |
836 | out_status->status = tensorflow::unwrap(op)->SetAttrShapeList( |
837 | attr_name, dims, num_dims, num_values); |
838 | } |
839 | |
840 | void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, |
841 | const TFE_Op** value, int num_values) { |
842 | auto s = tensorflow::unwrap(op)->SetAttrFunctionList( |
843 | attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>( |
844 | tensorflow::unwrap(value)), |
845 | static_cast<size_t>(num_values)}); |
846 | if (!s.ok()) { |
847 | LOG(WARNING) << "Unable to set attribute: " << attr_name; |
848 | } |
849 | } |
850 | |
851 | void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name, |
852 | const void* proto, size_t proto_len, |
853 | TF_Status* status) { |
854 | tensorflow::AttrValue attr_value; |
855 | if (!attr_value.ParseFromArray(proto, proto_len)) { |
856 | status->status = |
857 | tensorflow::errors::InvalidArgument("Unparseable AttrValue proto" ); |
858 | return; |
859 | } |
860 | if (op == nullptr) { |
861 | status->status = tensorflow::errors::InvalidArgument( |
862 | "Got a null or uninitialized `op` argument" ); |
863 | return; |
864 | } |
865 | tensorflow::EagerOperation* operation = |
866 | OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op))); |
867 | operation->MutableAttrs()->Set(attr_name, attr_value); |
868 | } |
869 | |
870 | TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, |
871 | const char* input_name, |
872 | TF_Status* status) { |
873 | int ret = -1; |
874 | status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret); |
875 | return ret; |
876 | } |
877 | |
878 | TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, |
879 | const char* output_name, |
880 | TF_Status* status) { |
881 | int ret = -1; |
882 | status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret); |
883 | return ret; |
884 | } |
885 | |
886 | void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, |
887 | TF_Status* status) { |
888 | tensorflow::ImmediateExecutionOperation* unwrapped_op = |
889 | tensorflow::unwrap(op); |
890 | |
891 | status->status = |
892 | unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute( |
893 | unwrapped_op, |
894 | reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>( |
895 | retvals), |
896 | num_retvals); |
897 | } |
898 | |
899 | TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, |
900 | TFE_Context* ctx, |
901 | const char* device_name, |
902 | TF_Status* status) { |
903 | if (h == nullptr) { |
904 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
905 | return nullptr; |
906 | } |
907 | |
908 | tensorflow::ImmediateExecutionContext* unwrapped_ctx = |
909 | tensorflow::unwrap(ctx); |
910 | |
911 | auto* result = |
912 | unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice( |
913 | unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status); |
914 | |
915 | if (status->status.ok()) { |
916 | return tensorflow::wrap(result); |
917 | } |
918 | return nullptr; |
919 | } |
920 | |
921 | void TFE_ContextAddFunctionDef(TFE_Context* ctx, |
922 | const char* serialized_function_def, size_t size, |
923 | TF_Status* status) { |
924 | tensorflow::FunctionDef function_def; |
925 | if (!function_def.ParseFromArray(serialized_function_def, size)) { |
926 | status->status = |
927 | tensorflow::errors::InvalidArgument("Invalid FunctionDef proto" ); |
928 | return; |
929 | } |
930 | |
931 | AnnotateEagerRuntimeConstructionContext(function_def); |
932 | status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def); |
933 | } |
934 | |
935 | void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, |
936 | TF_Status* status) { |
937 | AnnotateEagerRuntimeConstructionContext(function->fdef); |
938 | status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces( |
939 | function->fdef, function->stack_traces); |
940 | } |
941 | |
942 | void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, |
943 | TF_Status* status) { |
944 | status->status = tensorflow::unwrap(ctx)->RemoveFunction(name); |
945 | } |
946 | |
947 | unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { |
948 | return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr; |
949 | } |
950 | |
951 | void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { |
952 | tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); |
953 | } |
954 | |
955 | void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { |
956 | tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); |
957 | } |
958 | |
959 | } // extern "C" |
960 | |
961 | TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, |
962 | TF_Status* status) { |
963 | return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t)); |
964 | } |
965 | |
966 | void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, |
967 | TF_Status* status) { |
968 | auto* context = tensorflow::unwrap(ctx); |
969 | status->status = context->AsyncWait(); |
970 | if (!status->status.ok()) return; |
971 | auto run_metadata = context->ExportRunMetadata(); |
972 | status->status = MessageToBuffer(*run_metadata, buf); |
973 | } |
974 | |
975 | namespace { |
976 | TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, |
977 | TF_Status* status) { |
978 | TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); |
979 | for (const auto& attr : func.attr()) { |
980 | if (!status->status.ok()) return nullptr; |
981 | SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); |
982 | if (!status->status.ok()) return nullptr; |
983 | } |
984 | return func_op; |
985 | } |
986 | } // namespace |
987 | |
988 | void TFE_ContextStartStep(TFE_Context* ctx) { |
989 | tensorflow::unwrap(ctx)->StartStep(); |
990 | } |
991 | |
992 | void TFE_ContextEndStep(TFE_Context* ctx) { |
993 | tensorflow::unwrap(ctx)->EndStep(); |
994 | } |
995 | |
996 | const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) { |
997 | return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs()); |
998 | } |
999 | |
1000 | void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { |
1001 | tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs)); |
1002 | } |
1003 | |
1004 | void (const TFE_OpAttrs* attrs, TF_Buffer* buf, |
1005 | TF_Status* status) { |
1006 | tensorflow::NameAttrList name_and_attrs; |
1007 | tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs); |
1008 | status->status = MessageToBuffer(name_and_attrs, buf); |
1009 | } |
1010 | |
1011 | namespace tensorflow { |
1012 | void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, |
1013 | const tensorflow::AttrValue& default_value, |
1014 | const char* attr_name, TF_Status* status) { |
1015 | switch (default_value.value_case()) { |
1016 | case tensorflow::AttrValue::kS: { |
1017 | const string& v = default_value.s(); |
1018 | TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); |
1019 | break; |
1020 | } |
1021 | case tensorflow::AttrValue::kI: |
1022 | TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i())); |
1023 | break; |
1024 | case tensorflow::AttrValue::kF: |
1025 | TFE_OpSetAttrFloat(op, attr_name, default_value.f()); |
1026 | break; |
1027 | case tensorflow::AttrValue::kB: |
1028 | TFE_OpSetAttrBool(op, attr_name, default_value.b()); |
1029 | break; |
1030 | case tensorflow::AttrValue::kType: |
1031 | TFE_OpSetAttrType(op, attr_name, |
1032 | static_cast<TF_DataType>(default_value.type())); |
1033 | break; |
1034 | case tensorflow::AttrValue::kShape: { |
1035 | const auto& tensor_shape = default_value.shape(); |
1036 | if (tensor_shape.unknown_rank()) { |
1037 | TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); |
1038 | } else { |
1039 | const auto num_dims = tensor_shape.dim_size(); |
1040 | std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); |
1041 | for (int i = 0; i < num_dims; ++i) { |
1042 | dims[i] = tensor_shape.dim(i).size(); |
1043 | } |
1044 | TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); |
1045 | } |
1046 | } break; |
1047 | case tensorflow::AttrValue::kFunc: { |
1048 | const auto func_op = GetFunc(ctx, default_value.func(), status); |
1049 | if (!status->status.ok()) return; |
1050 | // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList |
1051 | // require TFE_Op* and just convert it internally a NameAttrValue, so |
1052 | // consider adding an overload to the C API to make this case easier. |
1053 | TFE_OpSetAttrFunction(op, attr_name, func_op); |
1054 | TFE_DeleteOp(func_op); |
1055 | } break; |
1056 | case tensorflow::AttrValue::kList: { |
1057 | // String |
1058 | if (const int s_size = default_value.list().s_size()) { |
1059 | absl::InlinedVector<const void*, 4> values_vector; |
1060 | values_vector.reserve(s_size); |
1061 | absl::InlinedVector<size_t, 4> lengths_vector; |
1062 | lengths_vector.reserve(s_size); |
1063 | for (int i = 0; i < s_size; ++i) { |
1064 | const string& v = default_value.list().s(i); |
1065 | values_vector.push_back(v.data()); |
1066 | lengths_vector.push_back(v.size()); |
1067 | } |
1068 | TFE_OpSetAttrStringList(op, attr_name, values_vector.data(), |
1069 | lengths_vector.data(), s_size); |
1070 | } |
1071 | |
1072 | // Int |
1073 | if (const int i_size = default_value.list().i_size()) { |
1074 | absl::InlinedVector<int64_t, 4> i_vector; |
1075 | i_vector.reserve(i_size); |
1076 | for (int i = 0; i < i_size; ++i) { |
1077 | i_vector.push_back(default_value.list().i(i)); |
1078 | } |
1079 | TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size); |
1080 | } |
1081 | // Float |
1082 | if (const int f_size = default_value.list().f_size()) { |
1083 | absl::InlinedVector<float, 4> f_vector; |
1084 | f_vector.reserve(f_size); |
1085 | for (int i = 0; i < f_size; ++i) { |
1086 | f_vector.push_back(default_value.list().f(i)); |
1087 | } |
1088 | TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size); |
1089 | } |
1090 | // Bool |
1091 | if (const int b_size = default_value.list().b_size()) { |
1092 | absl::InlinedVector<unsigned char, 4> b_vector; |
1093 | b_vector.reserve(b_size); |
1094 | for (int i = 0; i < b_size; i++) { |
1095 | b_vector.push_back(default_value.list().b(i)); |
1096 | } |
1097 | TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size); |
1098 | } |
1099 | // Type |
1100 | if (const int type_size = default_value.list().type_size()) { |
1101 | absl::InlinedVector<unsigned int, 4> type_vector; |
1102 | type_vector.reserve(type_size); |
1103 | for (int i = 0; i < type_size; ++i) { |
1104 | type_vector.push_back(default_value.list().type(i)); |
1105 | } |
1106 | TFE_OpSetAttrTypeList( |
1107 | op, attr_name, |
1108 | reinterpret_cast<const TF_DataType*>(type_vector.data()), |
1109 | type_size); |
1110 | } |
1111 | |
1112 | // Rest are not supported. |
1113 | if (default_value.list().shape_size() > 0 || |
1114 | default_value.list().func_size() > 0 || |
1115 | default_value.list().tensor_size() > 0) { |
1116 | TF_SetStatus( |
1117 | status, TF_UNIMPLEMENTED, |
1118 | tensorflow::strings::StrCat("Unable to get setfor default value: " , |
1119 | default_value.DebugString()) |
1120 | .data()); |
1121 | } |
1122 | } break; |
1123 | case tensorflow::AttrValue::kTensor: |
1124 | TF_FALLTHROUGH_INTENDED; |
1125 | case tensorflow::AttrValue::kPlaceholder: |
1126 | TF_FALLTHROUGH_INTENDED; |
1127 | case tensorflow::AttrValue::VALUE_NOT_SET: |
1128 | TF_SetStatus( |
1129 | status, TF_UNIMPLEMENTED, |
1130 | tensorflow::strings::StrCat("Unable to get setfor default value: " , |
1131 | default_value.DebugString()) |
1132 | .data()); |
1133 | } |
1134 | } |
1135 | } // namespace tensorflow |
1136 | |
1137 | namespace { |
1138 | TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context, |
1139 | TFE_TensorHandle** handles, |
1140 | int num_handles, TF_Status* status, |
1141 | void* device_info) { |
1142 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
1143 | "This custom device does not support packing tensors." ); |
1144 | return nullptr; |
1145 | } |
1146 | } // namespace |
1147 | |
1148 | extern "C" { |
1149 | |
1150 | bool TFE_IsCustomDevice(TFE_Context* ctx, const char* device_name) { |
1151 | return tensorflow::unwrap(ctx)->IsCustomDevice(device_name); |
1152 | } |
1153 | |
1154 | void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, |
1155 | const char* device_name, void* device_info, |
1156 | TF_Status* status) { |
1157 | // Fill in default values for optional functionality. |
1158 | if (device.pack == nullptr) { |
1159 | device.pack = &DefaultCustomDevicePack; |
1160 | } |
1161 | auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>( |
1162 | ctx, device, device_info, device_name); |
1163 | status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice( |
1164 | device_name, std::move(custom_device)); |
1165 | } |
1166 | |
1167 | } // extern "C" |
1168 | |