/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable` set to false, it is safe to keep these flag values as-is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
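
// Example usage (a minimal sketch of a C client enabling XLA before creating
// a session; `graph` and error handling are assumed to exist elsewhere):
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   TF_Status* status = TF_NewStatus();
//   TF_Session* session = TF_NewSession(graph, opts, status);
//   // ... use the session ...
//   TF_DeleteSessionOptions(opts);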

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}
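
// Example usage (sketch): these setters return the previous flag value, so a
// caller can save and restore a flag around a scoped region:
//
//   unsigned char prev = TF_SetTfXlaCpuGlobalJit(1);
//   // ... run with CPU global JIT enabled ...
//   TF_SetTfXlaCpuGlobalJit(prev);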

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaAutoJitEnabled() {
  tensorflow::XlaAutoJitFlag flag =
      tensorflow::GetMarkForCompilationPassFlags()->xla_auto_jit_flag;
  return static_cast<unsigned char>(flag.optimization_level_single_gpu > 0 ||
                                    flag.optimization_level_general > 0);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as-is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the
  // callback (as in the case of ConstOp kernel creation on GPU, which involves
  // copying a CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we
  // use a different TFE context for each thread of execution (for running
  // graph functions and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
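
// Example usage (sketch, assuming `opts` is a TF_SessionOptions* and `status`
// a live TF_Status*): the returned buffer holds a serialized ConfigProto and
// can be passed to TF_SetConfig; the caller then frees the buffer:
//
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/2);
//   TF_SetConfig(opts, config->data, config->length, status);
//   TF_DeleteBuffer(config);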

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}
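
// Example usage (sketch, assuming `session`, `output`, `out_tensor`, and
// `status` are set up elsewhere): pass the serialized RunOptions as the
// `run_options` argument of TF_SessionRun, then free the buffer:
//
//   TF_Buffer* run_opts = TF_CreateRunOptions(/*enable_full_trace=*/1);
//   TF_SessionRun(session, run_opts, /*inputs=*/nullptr,
//                 /*input_values=*/nullptr, /*ninputs=*/0,
//                 /*outputs=*/&output, /*output_values=*/&out_tensor,
//                 /*noutputs=*/1, /*targets=*/nullptr, /*ntargets=*/0,
//                 /*run_metadata=*/nullptr, status);
//   TF_DeleteBuffer(run_opts);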

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

// On success, returns a set of TF_Function instances parsed from
// `text_proto`, which must be a text-format GraphDef. These functions must be
// deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, it is run over each FunctionDef proto
// before a TF_Function is created from the (possibly mutated) proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}
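
// Example usage (sketch, assuming `session`, `tensor`, and `status` exist):
// these helpers assume the graph was built with ops named
// "fifo_queue_enqueue_<id>" / "fifo_queue_dequeue_<id>" and a placeholder
// "arg_tensor_enqueue_<id>". A hypothetical round trip:
//
//   TF_EnqueueNamedTensor(session, /*tensor_id=*/0, tensor, status);
//   TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out);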

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
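
// Example usage (a minimal sketch; "ckpt/model" is a hypothetical checkpoint
// prefix):
//
//   TF_Status* status = TF_NewStatus();
//   TF_CheckpointReader* reader = TF_NewCheckpointReader("ckpt/model", status);
//   if (TF_GetCode(status) == TF_OK) {
//     for (int i = 0; i < TF_CheckpointReaderSize(reader); ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       int num_dims = TF_CheckpointReaderGetVariableNumDims(reader, name);
//       // ... query dtype/shape or read the tensor by `name` ...
//     }
//     TF_DeleteCheckpointReader(reader);
//   }
//   TF_DeleteStatus(status);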

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // String buffers that ensure any `attr_name` passed into `builder->Set()`
  // outlives the subsequent `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on
  // the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          reinterpret_cast<const tensorflow::DataType*>(values),
                          num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
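
// Example usage (sketch, assuming `status` is a live TF_Status*): check
// whether a MatMul kernel with T=float is registered for device type "GPU":
//
//   TF_AttrBuilder* b = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(b, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(b, "GPU", status);
//   // TF_GetCode(status) == TF_OK iff a matching kernel is registered.
//   TF_DeleteAttrBuilder(b);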

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);

  status->status = ::tensorflow::OkStatus();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
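
// Example usage (sketch, assuming `status` exists): create a scalar float
// tensor handle; the caller releases the handle when done:
//
//   float value = 3.14f;
//   TFE_TensorHandle* h =
//       TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value), status);
//   // ... use `h` as an op input ...
//   TFE_DeleteTensorHandle(h);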

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = tensorflow::unwrap(ctx)->EnableCollectiveOps(server_def);
}

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  // `num_items` is an int, so loop with an int index to avoid a
  // signed/unsigned comparison.
  for (int i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}
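
// Example usage (sketch): build a one-element list describing a [2, 3] float
// shape, then release it:
//
//   int64_t dims[2] = {2, 3};
//   TF_ShapeAndTypeList* shapes = TF_NewShapeAndTypeList(/*num_items=*/1);
//   TF_ShapeAndTypeListSetShape(shapes, /*index=*/0, dims, /*num_dims=*/2);
//   TF_ShapeAndTypeListSetDtype(shapes, /*index=*/0, TF_FLOAT);
//   // ... e.g. pass `shapes` to TFE_InferShapes below ...
//   TF_DeleteShapeAndTypeList(shapes);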

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helper for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
  node_def.set_name(op->Name());
  node_def.set_op(op->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input-tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the addresses of the elements in `all_input_tensors`
    // below. Reserve enough space so that no reallocation happens, which would
    // invalidate those pointers.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated
  // later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    dims.reserve(input_shape.num_dims);
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (int j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
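
// Example usage (sketch, assuming `op` is a prepared TFE_Op and `shapes` was
// built as in the TF_ShapeAndTypeList example above): run shape inference and
// read back the inferred output ranks:
//
//   TF_ShapeAndTypeList* output_shapes = nullptr;
//   TFE_InferShapes(op, shapes, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr,
//                   &output_shapes,
//                   /*output_resource_shapes_and_types=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // items[i].num_dims is -1 (kUnknownRank) when the rank is unknown.
//     TF_DeleteShapeAndTypeList(output_shapes);
//   }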

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Load a PluggableDevice plugin library.
// On success, returns a handle to the library and sets `status` to OK;
// otherwise, returns nullptr and sets `status` to an error.
//
// If `library_filename` has already been loaded, a cached handle is returned.
// Devices and kernels/ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (status->status.ok()) {
        TF_CHECK_OK(
            tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
      } else {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}
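
// Example usage (sketch; "libmy_device_plugin.so" is a hypothetical plugin
// path, and `status` is assumed to exist):
//
//   TF_Library* lib =
//       TF_LoadPluggableDeviceLibrary("libmy_device_plugin.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... plugin devices/kernels are now registered ...
//     TF_DeletePluggableDeviceLibraryHandle(lib);
//   }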

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}