1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/kernels.h" |
17 | |
18 | #include <memory> |
19 | |
20 | #include "tensorflow/c/c_api_internal.h" |
21 | #include "tensorflow/c/c_api_macros.h" |
22 | #include "tensorflow/c/tf_buffer_internal.h" |
23 | #include "tensorflow/c/tf_status_helper.h" |
24 | #include "tensorflow/c/tf_tensor_internal.h" |
25 | #include "tensorflow/core/framework/attr_value.pb.h" |
26 | #include "tensorflow/core/framework/kernel_def_builder.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/framework/register_types.h" |
29 | #include "tensorflow/core/framework/resource_mgr.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | // Required for IS_MOBILE_PLATFORM definition |
32 | #include "tensorflow/core/platform/platform.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
35 | #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" |
36 | #include "tensorflow/compiler/xla/stream_executor/stream.h" |
37 | #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
38 | |
39 | using tensorflow::errors::InvalidArgument; |
40 | // This file forms the basis of a stable ABI for third-party kernel |
41 | // implementations. It is crucial that changes to this file are made cautiously |
42 | // and with a focus on maintaining both source and binary compatibility. |
43 | |
44 | struct TF_KernelBuilder { |
45 | ::tensorflow::KernelDefBuilder* cc_builder; |
46 | |
47 | void* (*create_function)(TF_OpKernelConstruction*); |
48 | void (*compute_function)(void*, TF_OpKernelContext*); |
49 | void (*delete_function)(void*); |
50 | }; |
51 | |
52 | TF_KernelBuilder* TF_NewKernelBuilder( |
53 | const char* op_name, const char* device_name, |
54 | void* (*create_func)(TF_OpKernelConstruction*), |
55 | void (*compute_func)(void*, TF_OpKernelContext*), |
56 | void (*delete_func)(void*)) { |
57 | TF_KernelBuilder* result = new TF_KernelBuilder; |
58 | result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name); |
59 | result->cc_builder->Device(device_name); |
60 | result->create_function = create_func; |
61 | result->compute_function = compute_func; |
62 | result->delete_function = delete_func; |
63 | return result; |
64 | } |
65 | |
66 | void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { |
67 | if (builder != nullptr) { |
68 | delete builder->cc_builder; |
69 | delete builder; |
70 | } |
71 | } |
72 | |
73 | namespace tensorflow { |
74 | namespace { |
75 | |
76 | #define CASE(type) \ |
77 | case DataTypeToEnum<type>::value: { \ |
78 | kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \ |
79 | break; \ |
80 | } |
81 | |
82 | void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, |
83 | const DataType dtype, TF_Status* status) { |
84 | // This needs to be under tensorflow:: namespace so that |
85 | // TF_CALL_ALL_TYPES macro can find tensorflow::string as string. |
86 | switch (dtype) { |
87 | TF_CALL_ALL_TYPES(CASE); |
88 | TF_CALL_QUANTIZED_TYPES(CASE); |
89 | TF_CALL_quint16(CASE); |
90 | TF_CALL_qint16(CASE); |
91 | default: |
92 | status->status = errors::Unimplemented("Unexpected type " , dtype); |
93 | return; |
94 | } |
95 | TF_SetStatus(status, TF_OK, "" ); |
96 | } |
97 | #undef CASE |
98 | |
99 | } // namespace |
100 | } // namespace tensorflow |
101 | |
102 | namespace { |
103 | const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx, |
104 | const char* attr_name, |
105 | TF_Status* status) { |
106 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
107 | const tensorflow::AttrValue* attr = |
108 | ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name); |
109 | if (attr == nullptr) { |
110 | status->status = InvalidArgument("Operation '" , cc_ctx->def().name(), |
111 | "' has no attr named '" , attr_name, "'." ); |
112 | } |
113 | return attr; |
114 | } |
115 | } // namespace |
116 | |
117 | void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder, |
118 | const char* attr_name, |
119 | const TF_DataType type, |
120 | TF_Status* status) { |
121 | tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type); |
122 | tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status); |
123 | } |
124 | |
125 | void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder, |
126 | const char* arg_name) { |
127 | kernel_builder->cc_builder->HostMemory(arg_name); |
128 | } |
129 | |
130 | void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder, |
131 | int32_t priority_number) { |
132 | kernel_builder->cc_builder->Priority(priority_number); |
133 | } |
134 | |
135 | void TF_KernelBuilder_Label(TF_KernelBuilder* kernel_builder, |
136 | const char* label) { |
137 | kernel_builder->cc_builder->Label(label); |
138 | } |
139 | |
140 | namespace tensorflow { |
141 | namespace { |
142 | |
143 | // An OpKernel whose methods delegate to C function pointers. |
144 | class COpKernel : public OpKernel { |
145 | public: |
146 | explicit COpKernel(OpKernelConstruction* ctx, |
147 | void* (*create_func)(TF_OpKernelConstruction*), |
148 | void (*compute_func)(void*, TF_OpKernelContext*), |
149 | void (*delete_func)(void*)) |
150 | : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) { |
151 | if (create_func != nullptr) { |
152 | c_kernel_ = |
153 | (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx)); |
154 | } else { |
155 | c_kernel_ = nullptr; |
156 | } |
157 | } |
158 | |
159 | void Compute(OpKernelContext* ctx) override { |
160 | (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx)); |
161 | } |
162 | |
163 | ~COpKernel() override { |
164 | if (delete_func_ != nullptr) { |
165 | (*delete_func_)(c_kernel_); |
166 | } |
167 | } |
168 | |
169 | private: |
170 | void (*compute_func_)(void*, TF_OpKernelContext* context); |
171 | void (*delete_func_)(void*); |
172 | void* c_kernel_; |
173 | }; |
174 | |
175 | // A KernelFactory that returns COpKernel instances. |
176 | class KernelBuilderFactory |
177 | : public ::tensorflow::kernel_factory::OpKernelFactory { |
178 | public: |
179 | explicit KernelBuilderFactory(TF_KernelBuilder* builder) |
180 | : builder_(builder) {} |
181 | ::tensorflow::OpKernel* Create( |
182 | ::tensorflow::OpKernelConstruction* context) override { |
183 | return new ::tensorflow::COpKernel(context, builder_->create_function, |
184 | builder_->compute_function, |
185 | builder_->delete_function); |
186 | } |
187 | ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); } |
188 | |
189 | private: |
190 | TF_KernelBuilder* builder_; |
191 | }; |
192 | } // namespace |
193 | } // namespace tensorflow |
194 | |
195 | void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, |
196 | TF_Status* status) { |
197 | using tensorflow::register_kernel::Name; |
198 | |
199 | TF_RegisterKernelBuilderWithKernelDef( |
200 | /*serialized_kernel_def=*/nullptr, name, builder, status); |
201 | } |
202 | |
203 | void TF_RegisterKernelBuilderWithKernelDef(const char* serialized_kernel_def, |
204 | const char* name, |
205 | TF_KernelBuilder* builder, |
206 | TF_Status* status) { |
207 | using tensorflow::register_kernel::Name; |
208 | if (serialized_kernel_def == nullptr) { |
209 | // If user doesn't provide a serialized KernelDef, use the kernel builder |
210 | // to build a new one. |
211 | tensorflow::kernel_factory::OpKernelRegistrar( |
212 | builder->cc_builder->Build(), name, |
213 | std::make_unique<tensorflow::KernelBuilderFactory>(builder)); |
214 | |
215 | TF_SetStatus(status, TF_OK, "" ); |
216 | return; |
217 | } |
218 | |
219 | tensorflow::KernelDef* kernel_def = new tensorflow::KernelDef(); |
220 | bool success = kernel_def->ParsePartialFromString(serialized_kernel_def); |
221 | if (!success) { |
222 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
223 | "Error parsing serialized KernelDef." ); |
224 | return; |
225 | } |
226 | |
227 | tensorflow::kernel_factory::OpKernelRegistrar( |
228 | kernel_def, name, |
229 | std::make_unique<tensorflow::KernelBuilderFactory>(builder)); |
230 | |
231 | TF_SetStatus(status, TF_OK, "" ); |
232 | } |
233 | |
234 | // This function is only for pluggable device. |
235 | // It will return nullptr in all other cases. |
236 | // This function is experimental and subject to change. |
237 | SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) { |
238 | #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) |
239 | status->status = tensorflow::errors::Unimplemented( |
240 | "Accessing device stream is not supported on mobile. File a bug at " |
241 | "https://github.com/tensorflow/tensorflow/issues if this feature is " |
242 | "important to you" ); |
243 | return nullptr; |
244 | #else |
245 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
246 | if (cc_ctx->op_device_context() == nullptr) { // CPU Device |
247 | status->status = tensorflow::errors::FailedPrecondition( |
248 | "Accessing device stream is not supported for a CPU device." ); |
249 | return nullptr; |
250 | } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) { |
251 | status->status = tensorflow::errors::FailedPrecondition( |
252 | "Accessing device stream is only supported for pluggable devices." ); |
253 | return nullptr; |
254 | } else { // Is a PluggableDevice |
255 | TF_SetStatus(status, TF_OK, "" ); |
256 | auto c_stream = static_cast<stream_executor::CStream*>( |
257 | cc_ctx->op_device_context()->stream()->implementation()); |
258 | return c_stream->Handle(); |
259 | } |
260 | #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) |
261 | } |
262 | |
263 | int TF_NumInputs(TF_OpKernelContext* ctx) { |
264 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
265 | return cc_ctx->num_inputs(); |
266 | } |
267 | |
268 | int TF_NumOutputs(TF_OpKernelContext* ctx) { |
269 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
270 | return cc_ctx->num_outputs(); |
271 | } |
272 | |
273 | void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, |
274 | TF_Status* status) { |
275 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
276 | if (i < 0 || i >= cc_ctx->num_inputs()) { |
277 | TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range" ); |
278 | return; |
279 | } |
280 | const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); |
281 | TF_Tensor* result = |
282 | ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status); |
283 | if (TF_GetCode(status) == TF_OK) { |
284 | *tensor = result; |
285 | } |
286 | } |
287 | |
288 | void TF_InputRange(TF_OpKernelContext* ctx, const char* name, |
289 | TF_InputRange_Args* args) { |
290 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
291 | int start = -1, stop = -1; |
292 | auto status = cc_ctx->op_kernel().InputRange(name, &start, &stop); |
293 | args->start = start; |
294 | args->stop = stop; |
295 | tensorflow::Set_TF_Status_from_Status(args->status, status); |
296 | } |
297 | |
298 | void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, |
299 | TF_Status* status) { |
300 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
301 | if (i < 0 || i >= cc_ctx->num_outputs()) { |
302 | TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range" ); |
303 | return; |
304 | } |
305 | ::tensorflow::Tensor cc_tensor; |
306 | ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); |
307 | TF_SetStatus(status, TF_OK, "" ); |
308 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
309 | if (s.ok()) { |
310 | cc_ctx->set_output(i, cc_tensor); |
311 | } |
312 | } |
313 | |
314 | TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, int i, |
315 | TF_Status* status) { |
316 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
317 | if (i < 0 || i >= cc_ctx->num_outputs()) { |
318 | TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range" ); |
319 | return nullptr; |
320 | } |
321 | const ::tensorflow::Tensor& cc_tensor = *(cc_ctx->mutable_output(i)); |
322 | TF_Tensor* result = |
323 | ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status); |
324 | if (TF_GetCode(status) == TF_OK) { |
325 | return result; |
326 | } else { |
327 | return nullptr; |
328 | } |
329 | } |
330 | |
331 | void TF_GetSerializedFunctionDefLibrary( |
332 | TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library, |
333 | TF_Status* status) { |
334 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
335 | auto fdef_lib = |
336 | cc_ctx->function_library()->GetFunctionLibraryDefinition()->ToProto(); |
337 | auto cc_status = |
338 | tensorflow::MessageToBuffer(fdef_lib, serialized_function_def_library); |
339 | tensorflow::Set_TF_Status_from_Status(status, cc_status); |
340 | } |
341 | |
342 | void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx, |
343 | TF_Buffer* serialized_config_proto, |
344 | TF_Status* status) { |
345 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
346 | const tensorflow::ConfigProto* config_proto_ptr = |
347 | cc_ctx->function_library()->config_proto(); |
348 | tensorflow::ConfigProto config_proto; |
349 | if (config_proto_ptr != nullptr) { |
350 | config_proto = *config_proto_ptr; |
351 | } |
352 | auto cc_status = |
353 | tensorflow::MessageToBuffer(config_proto, serialized_config_proto); |
354 | tensorflow::Set_TF_Status_from_Status(status, cc_status); |
355 | } |
356 | |
357 | void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx, |
358 | TF_Status* status) { |
359 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
360 | ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); |
361 | cc_ctx->CtxFailure(s); |
362 | } |
363 | |
364 | void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { |
365 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
366 | ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); |
367 | cc_ctx->CtxFailure(s); |
368 | } |
369 | |
370 | void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx, |
371 | const char* attr_name, |
372 | int32_t* list_size, |
373 | int32_t* total_size, |
374 | TF_Status* status) { |
375 | const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); |
376 | if (!status->status.ok()) { |
377 | *list_size = -1; |
378 | *total_size = -1; |
379 | return; |
380 | } |
381 | switch (attr->value_case()) { |
382 | #define SINGLE_CASE(kK, attr_type, size_expr) \ |
383 | case tensorflow::AttrValue::kK: \ |
384 | *list_size = -1; \ |
385 | *total_size = size_expr; \ |
386 | break; |
387 | |
388 | SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); |
389 | SINGLE_CASE(kI, TF_ATTR_INT, -1); |
390 | SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); |
391 | SINGLE_CASE(kB, TF_ATTR_BOOL, -1); |
392 | SINGLE_CASE(kType, TF_ATTR_TYPE, -1); |
393 | SINGLE_CASE(kShape, TF_ATTR_SHAPE, |
394 | attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); |
395 | SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); |
396 | #undef SINGLE_CASE |
397 | |
398 | case tensorflow::AttrValue::kList: |
399 | *list_size = 0; |
400 | *total_size = -1; |
401 | #define LIST_CASE(field, attr_type, ...) \ |
402 | if (attr->list().field##_size() > 0) { \ |
403 | *list_size = attr->list().field##_size(); \ |
404 | __VA_ARGS__; \ |
405 | break; \ |
406 | } |
407 | |
408 | LIST_CASE( |
409 | s, TF_ATTR_STRING, *total_size = 0; |
410 | for (int i = 0; i < attr->list().s_size(); |
411 | ++i) { *total_size += attr->list().s(i).size(); }); |
412 | LIST_CASE(i, TF_ATTR_INT); |
413 | LIST_CASE(f, TF_ATTR_FLOAT); |
414 | LIST_CASE(b, TF_ATTR_BOOL); |
415 | LIST_CASE(type, TF_ATTR_TYPE); |
416 | LIST_CASE( |
417 | shape, TF_ATTR_SHAPE, *total_size = 0; |
418 | for (int i = 0; i < attr->list().shape_size(); ++i) { |
419 | const auto& s = attr->list().shape(i); |
420 | *total_size += s.unknown_rank() ? 0 : s.dim_size(); |
421 | }); |
422 | LIST_CASE(tensor, TF_ATTR_TENSOR); |
423 | LIST_CASE(tensor, TF_ATTR_FUNC); |
424 | #undef LIST_CASE |
425 | break; |
426 | |
427 | case tensorflow::AttrValue::kPlaceholder: |
428 | *list_size = -1; |
429 | *total_size = -1; |
430 | break; |
431 | |
432 | case tensorflow::AttrValue::kFunc: |
433 | *list_size = -1; |
434 | *total_size = -1; |
435 | break; |
436 | |
437 | case tensorflow::AttrValue::VALUE_NOT_SET: |
438 | status->status = |
439 | InvalidArgument("Attribute '" , attr_name, "' has no value set" ); |
440 | break; |
441 | } |
442 | } |
443 | |
444 | #define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \ |
445 | void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ |
446 | const char* attr_name, \ |
447 | c_type* val, TF_Status* status) { \ |
448 | TF_SetStatus(status, TF_OK, ""); \ |
449 | cc_type v; \ |
450 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ |
451 | ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \ |
452 | ::tensorflow::Set_TF_Status_from_Status(status, s); \ |
453 | if (s.ok()) { \ |
454 | *val = static_cast<c_type>(v); \ |
455 | } \ |
456 | } \ |
457 | void TF_OpKernelConstruction_GetAttr##func##List( \ |
458 | TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \ |
459 | int max_vals, TF_Status* status) { \ |
460 | TF_SetStatus(status, TF_OK, ""); \ |
461 | const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \ |
462 | if (!status->status.ok()) return; \ |
463 | if (attr->value_case() != tensorflow::AttrValue::kList) { \ |
464 | status->status = \ |
465 | InvalidArgument("Value for '", attr_name, "' is not a list."); \ |
466 | return; \ |
467 | } \ |
468 | status->status = \ |
469 | tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \ |
470 | if (!status->status.ok()) return; \ |
471 | const auto len = std::min(max_vals, attr->list().list_field##_size()); \ |
472 | for (int i = 0; i < len; ++i) { \ |
473 | vals[i] = static_cast<c_type>(attr->list().list_field(i)); \ |
474 | } \ |
475 | } |
476 | |
477 | DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type" , type) |
478 | DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int" , i) |
479 | DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int" , i) |
480 | DEFINE_TF_GETATTR(Float, float, float, "float" , f) |
481 | DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool" , b) |
482 | |
483 | void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx, |
484 | const char* attr_name, char* value, |
485 | size_t max_length, |
486 | TF_Status* status) { |
487 | std::string v; |
488 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
489 | ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); |
490 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
491 | |
492 | if (!status->status.ok()) return; |
493 | |
494 | if (max_length <= 0) { |
495 | return; |
496 | } |
497 | std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length)); |
498 | } |
499 | |
500 | void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx, |
501 | const char* attr_name, |
502 | char** values, size_t* lengths, |
503 | int max_values, void* storage, |
504 | size_t storage_size, |
505 | TF_Status* status) { |
506 | std::vector<std::string> v; |
507 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
508 | ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); |
509 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
510 | |
511 | if (!status->status.ok()) return; |
512 | |
513 | const auto len = std::min(max_values, static_cast<int>(v.size())); |
514 | char* p = static_cast<char*>(storage); |
515 | for (int i = 0; i < len; ++i) { |
516 | const std::string& s = v[i]; |
517 | values[i] = p; |
518 | lengths[i] = s.size(); |
519 | if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) { |
520 | status->status = InvalidArgument( |
521 | "Not enough storage to hold the requested list of strings" ); |
522 | return; |
523 | } |
524 | memcpy(values[i], s.data(), s.size()); |
525 | p += s.size(); |
526 | } |
527 | } |
528 | |
529 | void TF_OpKernelConstruction_GetAttrTensor(TF_OpKernelConstruction* ctx, |
530 | const char* attr_name, |
531 | TF_Tensor** val, TF_Status* status) { |
532 | *val = nullptr; |
533 | ::tensorflow::Tensor t; |
534 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
535 | ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &t); |
536 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
537 | |
538 | if (!status->status.ok()) return; |
539 | |
540 | *val = TF_TensorFromTensor(t, &status->status); |
541 | } |
542 | |
543 | void TF_OpKernelConstruction_GetAttrTensorList(TF_OpKernelConstruction* ctx, |
544 | const char* attr_name, |
545 | TF_Tensor** vals, int max_values, |
546 | TF_Status* status) { |
547 | std::vector<::tensorflow::Tensor> v; |
548 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
549 | ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); |
550 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
551 | |
552 | if (!status->status.ok()) return; |
553 | |
554 | const auto len = std::min(max_values, static_cast<int>(v.size())); |
555 | for (int i = 0; i < len; ++i) { |
556 | vals[i] = TF_TensorFromTensor(v[i], &status->status); |
557 | if (!status->status.ok()) return; |
558 | } |
559 | } |
560 | |
561 | TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction* ctx, |
562 | const char* attr_name, |
563 | TF_Status* status) { |
564 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
565 | tensorflow::NameAttrList function; |
566 | auto cc_status = cc_ctx->GetAttr(attr_name, &function); |
567 | if (!cc_status.ok()) { |
568 | Set_TF_Status_from_Status(status, cc_status); |
569 | return nullptr; |
570 | } |
571 | TF_Buffer* buffer = TF_NewBuffer(); |
572 | cc_status = tensorflow::MessageToBuffer(function, buffer); |
573 | Set_TF_Status_from_Status(status, cc_status); |
574 | if (!cc_status.ok()) |
575 | return nullptr; |
576 | else |
577 | return buffer; |
578 | } |
579 | |
580 | bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx, |
581 | const char* attr_name, TF_Status* status) { |
582 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
583 | return cc_ctx->HasAttr(attr_name); |
584 | } |
585 | |
586 | TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) { |
587 | auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx); |
588 | TF_StringView string_view_of_name; |
589 | string_view_of_name.data = cc_ctx->def().name().data(); |
590 | string_view_of_name.len = cc_ctx->def().name().length(); |
591 | return string_view_of_name; |
592 | } |
593 | |
594 | TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { |
595 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
596 | CHECK_GE(i, 0); |
597 | CHECK_LT(i, cc_ctx->num_outputs()); |
598 | return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i)); |
599 | } |
600 | |
601 | bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, TF_Status* status) { |
602 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
603 | if (i < 0 || i >= cc_ctx->num_inputs()) { |
604 | TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range" ); |
605 | return false; |
606 | } |
607 | TF_SetStatus(status, TF_OK, "" ); |
608 | return cc_ctx->input_memory_type(i) == tensorflow::HOST_MEMORY; |
609 | } |
610 | |
611 | bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, TF_Status* status) { |
612 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
613 | if (i < 0 || i >= cc_ctx->num_outputs()) { |
614 | TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range" ); |
615 | return false; |
616 | } |
617 | TF_SetStatus(status, TF_OK, "" ); |
618 | return cc_ctx->output_memory_type(i) == tensorflow::HOST_MEMORY; |
619 | } |
620 | |
621 | int64_t TF_StepId(TF_OpKernelContext* ctx) { |
622 | return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); |
623 | } |
624 | |
625 | TF_Buffer* TF_OpKernelConstruction_GetNodeDef(TF_OpKernelConstruction* ctx, |
626 | TF_Status* status) { |
627 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
628 | TF_Buffer* ret = TF_NewBuffer(); |
629 | status->status = MessageToBuffer(cc_ctx->def(), ret); |
630 | if (!status->status.ok()) { |
631 | TF_DeleteBuffer(ret); |
632 | return nullptr; |
633 | } |
634 | return ret; |
635 | } |
636 | |
637 | uint64_t TF_GetFrameId(TF_OpKernelContext* ctx) { |
638 | return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx) |
639 | ->frame_iter() |
640 | .frame_id; |
641 | } |
642 | |
643 | int TF_GetGraphDefVersion(TF_OpKernelContext* ctx) { |
644 | return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx) |
645 | ->function_library() |
646 | ->graph_def_version(); |
647 | } |
648 | |
649 | int64_t TF_GetIterId(TF_OpKernelContext* ctx) { |
650 | return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx) |
651 | ->frame_iter() |
652 | .iter_id; |
653 | } |
654 | |
655 | TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) { |
656 | auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
657 | TF_StringView opkernel_name_sv; |
658 | opkernel_name_sv.data = cc_ctx->op_kernel().name().data(); |
659 | opkernel_name_sv.len = cc_ctx->op_kernel().name().length(); |
660 | return opkernel_name_sv; |
661 | } |
662 | |
663 | TF_StringView TF_GetResourceMgrDefaultContainerName(TF_OpKernelContext* ctx) { |
664 | auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
665 | TF_StringView default_container_name_sv; |
666 | default_container_name_sv.data = |
667 | cc_ctx->resource_manager()->default_container().data(); |
668 | default_container_name_sv.len = |
669 | cc_ctx->resource_manager()->default_container().length(); |
670 | return default_container_name_sv; |
671 | } |
672 | |
673 | TF_StringView TF_GetOpKernelRequestedInput(TF_OpKernelContext* ctx, |
674 | size_t index) { |
675 | auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
676 | TF_StringView requested_input_sv; |
677 | requested_input_sv.data = cc_ctx->op_kernel().requested_input(index).data(); |
678 | requested_input_sv.len = cc_ctx->op_kernel().requested_input(index).length(); |
679 | return requested_input_sv; |
680 | } |
681 | |
682 | TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, |
683 | TF_DataType dtype, const int64_t* dims, |
684 | int num_dims, size_t len, TF_Status* status) { |
685 | TF_SetStatus(status, TF_OK, "" ); |
686 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); |
687 | static_assert(sizeof(int64_t) == sizeof(int64_t), |
688 | "64-bit int types should match in size" ); |
689 | tensorflow::gtl::ArraySlice<const int64_t> dimarray( |
690 | reinterpret_cast<const int64_t*>(dims), num_dims); |
691 | tensorflow::Tensor* tensor; |
692 | tensorflow::Status s = cc_ctx->allocate_output( |
693 | index, tensorflow::TensorShape(dimarray), &tensor); |
694 | if (!s.ok()) { |
695 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
696 | return nullptr; |
697 | } |
698 | TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s); |
699 | if (!s.ok()) { |
700 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
701 | return nullptr; |
702 | } |
703 | return tf_tensor; |
704 | } |
705 | |
706 | TF_Tensor* TF_ForwardInputOrAllocateOutput( |
707 | TF_OpKernelContext* context, const int* candidate_input_indices, |
708 | int num_candidate_input_indices, int output_index, |
709 | const int64_t* output_dims, int output_num_dims, int* forwarded_input, |
710 | TF_Status* status) { |
711 | TF_SetStatus(status, TF_OK, "" ); |
712 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); |
713 | |
714 | static_assert(sizeof(int64_t) == sizeof(int64_t), |
715 | "64-bit int types should match in size" ); |
716 | tensorflow::gtl::ArraySlice<int> input_indices_array( |
717 | candidate_input_indices, num_candidate_input_indices); |
718 | tensorflow::gtl::ArraySlice<const int64_t> output_dimarray( |
719 | reinterpret_cast<const int64_t*>(output_dims), output_num_dims); |
720 | tensorflow::Tensor* output_tensor_pointer; |
721 | tensorflow::Status s = cc_ctx->forward_input_or_allocate_output( |
722 | input_indices_array, output_index, |
723 | tensorflow::TensorShape(output_dimarray), &output_tensor_pointer, |
724 | forwarded_input); |
725 | if (!s.ok()) { |
726 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
727 | return nullptr; |
728 | } |
729 | TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s); |
730 | if (!s.ok()) { |
731 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
732 | return nullptr; |
733 | } |
734 | return tf_tensor_output; |
735 | } |
736 | |
737 | TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype, |
738 | const int64_t* dims, int num_dims, |
739 | TF_AllocatorAttributes* attributes, |
740 | TF_Status* status) { |
741 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); |
742 | TF_SetStatus(status, TF_OK, "" ); |
743 | static_assert(sizeof(int64_t) == sizeof(int64_t), |
744 | "64-bit int types should match in size" ); |
745 | tensorflow::gtl::ArraySlice<const int64_t> dimarray( |
746 | reinterpret_cast<const int64_t*>(dims), num_dims); |
747 | if (attributes && !attributes->struct_size) { |
748 | TF_SetStatus( |
749 | status, TF_INVALID_ARGUMENT, |
750 | "TF_AllocatorAttributes struct " |
751 | "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE" ); |
752 | return nullptr; |
753 | } |
754 | tensorflow::AllocatorAttributes allocator_attr; |
755 | if (attributes && attributes->on_host) { |
756 | allocator_attr.set_on_host(true); |
757 | } |
758 | tensorflow::Status s; |
759 | tensorflow::Tensor tensor; |
760 | s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype), |
761 | tensorflow::TensorShape(dimarray), &tensor, |
762 | allocator_attr); |
763 | if (!s.ok()) { |
764 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
765 | return nullptr; |
766 | } |
767 | TF_Tensor* tf_tensor; |
768 | tf_tensor = TF_TensorFromTensor(tensor, &s); |
769 | if (!s.ok()) { |
770 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
771 | return nullptr; |
772 | } |
773 | return tf_tensor; |
774 | } |
775 | |