/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable` set to false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
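
// A minimal usage sketch for a non-Python client (the rest of the session
// setup and all error handling are elided; names other than the C API calls
// are illustrative):
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   // ... pass `opts` to TF_NewSession(), then:
//   TF_DeleteSessionOptions(opts);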

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaAutoJitEnabled() {
  tensorflow::XlaAutoJitFlag flag =
      tensorflow::GetMarkForCompilationPassFlags()->xla_auto_jit_flag;
  return static_cast<unsigned char>(flag.optimization_level_single_gpu > 0 ||
                                    flag.optimization_level_general > 0);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the callback
  // (as in the case of ConstOp kernel creation on GPU, which involves copying a
  // CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we use
  // a different TFE context for each thread of execution (for running graph
  // functions and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
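
// A minimal usage sketch (the serialized ConfigProto produced above is handed
// to the session options via TF_SetConfig; error handling is elided):
//
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TF_SetConfig(opts, config->data, config->length, status);
//   TF_DeleteBuffer(config);
//   // ... create the session with `opts`, then delete `opts` and `status`.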

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}
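
// A minimal usage sketch (both debug-string functions above return a buffer
// allocated with malloc(), so the caller is expected to release it with
// free(); `graph` is assumed to exist already):
//
//   size_t len = 0;
//   const char* str = TF_GraphDebugString(graph, &len);
//   fprintf(stderr, "%.*s\n", (int)len, str);
//   free((void*)str);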

// On success, returns a vector of TF_Function instances parsed from
// `text_proto`, which must be a text-format GraphDef. These functions must be
// deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}
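
// A minimal usage sketch (assumes the session's graph already contains nodes
// following the naming convention used above, i.e. "fifo_queue_enqueue_0",
// "fifo_queue_dequeue_0", and "arg_tensor_enqueue_0"; `session`, `tensor`,
// and `status` are created elsewhere):
//
//   TF_EnqueueNamedTensor(session, /*tensor_id=*/0, tensor, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
//     if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out);
//   }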

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
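
// A minimal usage sketch (iterates over all variables stored in a checkpoint;
// the checkpoint prefix "/tmp/model.ckpt" is illustrative only):
//
//   TF_Status* status = TF_NewStatus();
//   TF_CheckpointReader* reader =
//       TF_NewCheckpointReader("/tmp/model.ckpt", status);
//   if (TF_GetCode(status) == TF_OK) {
//     for (int i = 0; i < TF_CheckpointReaderSize(reader); ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       TF_Tensor* value = TF_CheckpointReaderGetTensor(reader, name, status);
//       if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(value);
//     }
//     TF_DeleteCheckpointReader(reader);
//   }
//   TF_DeleteStatus(status);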

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // String buffers that guarantee any `attr_name` we pass into
  // `builder->Set()` outlives the subsequent
  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          reinterpret_cast<const tensorflow::DataType*>(values),
                          num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
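
// A minimal usage sketch (checks whether a "MatMul" kernel with float inputs
// is registered for a given device type; "GPU" is illustrative only):
//
//   TF_Status* status = TF_NewStatus();
//   TF_AttrBuilder* builder = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(builder, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(builder, "GPU", status);
//   // TF_GetCode(status) == TF_OK iff a matching kernel was found.
//   TF_DeleteAttrBuilder(builder);
//   TF_DeleteStatus(status);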

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);

  status->status = ::tensorflow::OkStatus();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
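
// A minimal usage sketch (wraps a single float in a scalar tensor handle;
// error handling and handle cleanup are elided):
//
//   float value = 3.14f;
//   TF_Status* status = TF_NewStatus();
//   TFE_TensorHandle* handle = TFE_NewTensorHandleFromScalar(
//       TF_FLOAT, &value, sizeof(value), status);
//   // ... use `handle`, then release it with TFE_DeleteTensorHandle(handle).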

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = tensorflow::unwrap(ctx)->EnableCollectiveOps(server_def);
}
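
// A minimal usage sketch (assumes `server_def_text` holds a text-format
// tensorflow.ServerDef describing the cluster, and `ctx` is an existing
// TFE_Context; TFE_GetServerDef above is used to serialize it):
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* server_def_buf = TFE_GetServerDef(server_def_text, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_EnableCollectiveOps(ctx, server_def_buf->data,
//                             server_def_buf->length, status);
//     TF_DeleteBuffer(server_def_buf);
//   }
//   TF_DeleteStatus(status);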

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  for (size_t i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helper for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
  node_def.set_name(op->Name());
  node_def.set_op(op->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the address of the elements in `all_input_tensors`
    // below. Allocate enough space so that no reallocation happens, which
    // would invalidate those pointers.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    dims.reserve(input_shape.num_dims);
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (size_t j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
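
// A minimal usage sketch (infers the output shape of an "Identity" op given a
// [2, 3] input shape; `ctx` is an existing TFE_Context and error handling is
// elided):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_items=*/1);
//   int64_t dims[] = {2, 3};
//   TF_ShapeAndTypeListSetShape(input_shapes, /*index=*/0, dims, /*num_dims=*/2);
//   TF_ShapeAndTypeList* output_shapes = nullptr;
//   TFE_InferShapes(op, input_shapes, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr, &output_shapes,
//                   /*output_resource_shapes_and_types=*/nullptr, status);
//   // On success, output_shapes->items[0] describes the inferred [2, 3] shape.
//   TF_DeleteShapeAndTypeList(input_shapes);
//   TF_DeleteShapeAndTypeList(output_shapes);
//   TFE_DeleteOp(op);
//   TF_DeleteStatus(status);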

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Load a Pluggable Device library.
// On success, returns a handle to the library and leaves `status` OK.
// Otherwise, returns nullptr and sets `status` to an error.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (status->status.ok()) {
        TF_CHECK_OK(
            tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
      } else {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}
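
// A minimal usage sketch (the plugin path is illustrative only; on success the
// plugin's devices and kernels are registered globally):
//
//   TF_Status* status = TF_NewStatus();
//   TF_Library* lib =
//       TF_LoadPluggableDeviceLibrary("/path/to/libplugin.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... use TensorFlow as usual; the plugin stays loaded.
//     TF_DeletePluggableDeviceLibraryHandle(lib);
//   }
//   TF_DeleteStatus(status);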

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}