1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/kernels.h"
17
18#include <memory>
19
20#include "tensorflow/c/c_api_internal.h"
21#include "tensorflow/c/c_api_macros.h"
22#include "tensorflow/c/tf_buffer_internal.h"
23#include "tensorflow/c/tf_status_helper.h"
24#include "tensorflow/c/tf_tensor_internal.h"
25#include "tensorflow/core/framework/attr_value.pb.h"
26#include "tensorflow/core/framework/kernel_def_builder.h"
27#include "tensorflow/core/framework/op_kernel.h"
28#include "tensorflow/core/framework/register_types.h"
29#include "tensorflow/core/framework/resource_mgr.h"
30#include "tensorflow/core/framework/types.h"
31// Required for IS_MOBILE_PLATFORM definition
32#include "tensorflow/core/platform/platform.h"
33#include "tensorflow/core/platform/types.h"
34#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
35#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
36#include "tensorflow/compiler/xla/stream_executor/stream.h"
37#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38
39using tensorflow::errors::InvalidArgument;
40// This file forms the basis of a stable ABI for third-party kernel
41// implementations. It is crucial that changes to this file are made cautiously
42// and with a focus on maintaining both source and binary compatibility.
43
// Internal representation of a TF_KernelBuilder: pairs the C++
// KernelDefBuilder with the C callbacks that implement the kernel lifecycle.
struct TF_KernelBuilder {
  // Owned; freed by TF_DeleteKernelBuilder.
  ::tensorflow::KernelDefBuilder* cc_builder;

  // Optional construction callback; its return value is the opaque per-kernel
  // state later passed to compute_function and delete_function.
  void* (*create_function)(TF_OpKernelConstruction*);
  // Invoked unconditionally by COpKernel::Compute for each kernel execution.
  void (*compute_function)(void*, TF_OpKernelContext*);
  // Optional destruction callback for the opaque state.
  void (*delete_function)(void*);
};
51
52TF_KernelBuilder* TF_NewKernelBuilder(
53 const char* op_name, const char* device_name,
54 void* (*create_func)(TF_OpKernelConstruction*),
55 void (*compute_func)(void*, TF_OpKernelContext*),
56 void (*delete_func)(void*)) {
57 TF_KernelBuilder* result = new TF_KernelBuilder;
58 result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
59 result->cc_builder->Device(device_name);
60 result->create_function = create_func;
61 result->compute_function = compute_func;
62 result->delete_function = delete_func;
63 return result;
64}
65
66void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
67 if (builder != nullptr) {
68 delete builder->cc_builder;
69 delete builder;
70 }
71}
72
namespace tensorflow {
namespace {

// Expands to one switch case per dtype: registers a TypeConstraint<T> for
// `attr_name` on the wrapped KernelDefBuilder.
#define CASE(type)                                               \
  case DataTypeToEnum<type>::value: {                            \
    kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
    break;                                                       \
  }

// Adds a type constraint for attribute `attr_name`; sets Unimplemented for
// dtypes not covered by the TF_CALL_* macros below, TF_OK otherwise.
void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
                       const DataType dtype, TF_Status* status) {
  // This needs to be under tensorflow:: namespace so that
  // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
  switch (dtype) {
    TF_CALL_ALL_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    TF_CALL_quint16(CASE);
    TF_CALL_qint16(CASE);
    default:
      status->status = errors::Unimplemented("Unexpected type ", dtype);
      return;
  }
  TF_SetStatus(status, TF_OK, "");
}
#undef CASE

}  // namespace
}  // namespace tensorflow
101
namespace {
// Looks up attribute `attr_name` on the construction context's NodeDef.
// Returns nullptr (and sets an InvalidArgument status) when absent. The
// returned pointer is owned by the NodeDef, not the caller. On success
// `status` is left untouched — callers are expected to pre-set it to TF_OK.
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
                                          const char* attr_name,
                                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  const tensorflow::AttrValue* attr =
      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}
}  // namespace
116
117void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
118 const char* attr_name,
119 const TF_DataType type,
120 TF_Status* status) {
121 tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
122 tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
123}
124
// Forwards a HostMemory placement request for `arg_name` to the underlying
// KernelDefBuilder.
void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
                                 const char* arg_name) {
  kernel_builder->cc_builder->HostMemory(arg_name);
}

// Forwards the registration priority to the underlying KernelDefBuilder.
void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
                               int32_t priority_number) {
  kernel_builder->cc_builder->Priority(priority_number);
}

// Forwards the kernel label to the underlying KernelDefBuilder.
void TF_KernelBuilder_Label(TF_KernelBuilder* kernel_builder,
                            const char* label) {
  kernel_builder->cc_builder->Label(label);
}
139
namespace tensorflow {
namespace {

// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
 public:
  // Runs `create_func` (if non-null) to build the plugin's opaque per-kernel
  // state; that state is released via `delete_func` in the destructor.
  explicit COpKernel(OpKernelConstruction* ctx,
                     void* (*create_func)(TF_OpKernelConstruction*),
                     void (*compute_func)(void*, TF_OpKernelContext*),
                     void (*delete_func)(void*))
      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
    if (create_func != nullptr) {
      c_kernel_ =
          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
    } else {
      c_kernel_ = nullptr;
    }
  }

  // Delegates every execution to the plugin's compute callback, passing the
  // opaque state created at construction.
  void Compute(OpKernelContext* ctx) override {
    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
  }

  ~COpKernel() override {
    if (delete_func_ != nullptr) {
      (*delete_func_)(c_kernel_);
    }
  }

 private:
  void (*compute_func_)(void*, TF_OpKernelContext* context);
  void (*delete_func_)(void*);
  void* c_kernel_;  // plugin-owned opaque state; may be null
};

// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}
  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }
  // The factory owns the builder handed over at registration time.
  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

 private:
  TF_KernelBuilder* builder_;
};
}  // namespace
}  // namespace tensorflow
194
195void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
196 TF_Status* status) {
197 using tensorflow::register_kernel::Name;
198
199 TF_RegisterKernelBuilderWithKernelDef(
200 /*serialized_kernel_def=*/nullptr, name, builder, status);
201}
202
203void TF_RegisterKernelBuilderWithKernelDef(const char* serialized_kernel_def,
204 const char* name,
205 TF_KernelBuilder* builder,
206 TF_Status* status) {
207 using tensorflow::register_kernel::Name;
208 if (serialized_kernel_def == nullptr) {
209 // If user doesn't provide a serialized KernelDef, use the kernel builder
210 // to build a new one.
211 tensorflow::kernel_factory::OpKernelRegistrar(
212 builder->cc_builder->Build(), name,
213 std::make_unique<tensorflow::KernelBuilderFactory>(builder));
214
215 TF_SetStatus(status, TF_OK, "");
216 return;
217 }
218
219 tensorflow::KernelDef* kernel_def = new tensorflow::KernelDef();
220 bool success = kernel_def->ParsePartialFromString(serialized_kernel_def);
221 if (!success) {
222 TF_SetStatus(status, TF_INVALID_ARGUMENT,
223 "Error parsing serialized KernelDef.");
224 return;
225 }
226
227 tensorflow::kernel_factory::OpKernelRegistrar(
228 kernel_def, name,
229 std::make_unique<tensorflow::KernelBuilderFactory>(builder));
230
231 TF_SetStatus(status, TF_OK, "");
232}
233
// This function is only for pluggable device.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  // Stream access is compiled out on mobile/slim builds.
  status->status = tensorflow::errors::Unimplemented(
      "Accessing device stream is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is not supported for a CPU device.");
    return nullptr;
  } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is only supported for pluggable devices.");
    return nullptr;
  } else {  // Is a PluggableDevice
    TF_SetStatus(status, TF_OK, "");
    // For pluggable devices the stream implementation is a CStream, whose
    // handle is the SP_Stream the plugin originally provided.
    auto c_stream = static_cast<stream_executor::CStream*>(
        cc_ctx->op_device_context()->stream()->implementation());
    return c_stream->Handle();
  }
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
262
263int TF_NumInputs(TF_OpKernelContext* ctx) {
264 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
265 return cc_ctx->num_inputs();
266}
267
268int TF_NumOutputs(TF_OpKernelContext* ctx) {
269 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
270 return cc_ctx->num_outputs();
271}
272
273void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
274 TF_Status* status) {
275 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
276 if (i < 0 || i >= cc_ctx->num_inputs()) {
277 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
278 return;
279 }
280 const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
281 TF_Tensor* result =
282 ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
283 if (TF_GetCode(status) == TF_OK) {
284 *tensor = result;
285 }
286}
287
288void TF_InputRange(TF_OpKernelContext* ctx, const char* name,
289 TF_InputRange_Args* args) {
290 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
291 int start = -1, stop = -1;
292 auto status = cc_ctx->op_kernel().InputRange(name, &start, &stop);
293 args->start = start;
294 args->stop = stop;
295 tensorflow::Set_TF_Status_from_Status(args->status, status);
296}
297
298void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
299 TF_Status* status) {
300 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
301 if (i < 0 || i >= cc_ctx->num_outputs()) {
302 TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
303 return;
304 }
305 ::tensorflow::Tensor cc_tensor;
306 ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
307 TF_SetStatus(status, TF_OK, "");
308 ::tensorflow::Set_TF_Status_from_Status(status, s);
309 if (s.ok()) {
310 cc_ctx->set_output(i, cc_tensor);
311 }
312}
313
314TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, int i,
315 TF_Status* status) {
316 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
317 if (i < 0 || i >= cc_ctx->num_outputs()) {
318 TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
319 return nullptr;
320 }
321 const ::tensorflow::Tensor& cc_tensor = *(cc_ctx->mutable_output(i));
322 TF_Tensor* result =
323 ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
324 if (TF_GetCode(status) == TF_OK) {
325 return result;
326 } else {
327 return nullptr;
328 }
329}
330
331void TF_GetSerializedFunctionDefLibrary(
332 TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library,
333 TF_Status* status) {
334 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
335 auto fdef_lib =
336 cc_ctx->function_library()->GetFunctionLibraryDefinition()->ToProto();
337 auto cc_status =
338 tensorflow::MessageToBuffer(fdef_lib, serialized_function_def_library);
339 tensorflow::Set_TF_Status_from_Status(status, cc_status);
340}
341
342void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx,
343 TF_Buffer* serialized_config_proto,
344 TF_Status* status) {
345 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
346 const tensorflow::ConfigProto* config_proto_ptr =
347 cc_ctx->function_library()->config_proto();
348 tensorflow::ConfigProto config_proto;
349 if (config_proto_ptr != nullptr) {
350 config_proto = *config_proto_ptr;
351 }
352 auto cc_status =
353 tensorflow::MessageToBuffer(config_proto, serialized_config_proto);
354 tensorflow::Set_TF_Status_from_Status(status, cc_status);
355}
356
357void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
358 TF_Status* status) {
359 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
360 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
361 cc_ctx->CtxFailure(s);
362}
363
364void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
365 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
366 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
367 cc_ctx->CtxFailure(s);
368}
369
// Reports size information for attribute `attr_name`: `*list_size` is -1 for
// non-list attrs (or the element count for lists) and `*total_size` carries a
// type-dependent payload size (string length, shape rank, summed list sizes),
// or -1 where no meaningful size exists.
// NOTE(review): on success this function never sets `status` to TF_OK; it
// assumes the caller passed in a fresh/OK status — confirm callers do.
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
                                         const char* attr_name,
                                         int32_t* list_size,
                                         int32_t* total_size,
                                         TF_Status* status) {
  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
  if (!status->status.ok()) {
    *list_size = -1;
    *total_size = -1;
    return;
  }
  switch (attr->value_case()) {
// Single-valued attrs: list_size is -1; total_size comes from size_expr.
#define SINGLE_CASE(kK, attr_type, size_expr) \
  case tensorflow::AttrValue::kK:             \
    *list_size = -1;                          \
    *total_size = size_expr;                  \
    break;

    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
    SINGLE_CASE(kI, TF_ATTR_INT, -1);
    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE

    case tensorflow::AttrValue::kList:
      *list_size = 0;
      *total_size = -1;
// Probe each list field in turn; the first non-empty one determines the list
// size (and may compute a total size via the trailing statements).
#define LIST_CASE(field, attr_type, ...)      \
  if (attr->list().field##_size() > 0) {      \
    *list_size = attr->list().field##_size(); \
    __VA_ARGS__;                              \
    break;                                    \
  }

      LIST_CASE(
          s, TF_ATTR_STRING, *total_size = 0;
          for (int i = 0; i < attr->list().s_size();
               ++i) { *total_size += attr->list().s(i).size(); });
      LIST_CASE(i, TF_ATTR_INT);
      LIST_CASE(f, TF_ATTR_FLOAT);
      LIST_CASE(b, TF_ATTR_BOOL);
      LIST_CASE(type, TF_ATTR_TYPE);
      LIST_CASE(
          shape, TF_ATTR_SHAPE, *total_size = 0;
          for (int i = 0; i < attr->list().shape_size(); ++i) {
            const auto& s = attr->list().shape(i);
            *total_size += s.unknown_rank() ? 0 : s.dim_size();
          });
      LIST_CASE(tensor, TF_ATTR_TENSOR);
      // NOTE(review): this probes the `tensor` field again — the previous
      // case already breaks whenever it is non-empty, so this expansion is
      // effectively dead. It looks like it was meant to read the `func`
      // list field for TF_ATTR_FUNC; confirm before changing.
      LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
      break;

    case tensorflow::AttrValue::kPlaceholder:
      *list_size = -1;
      *total_size = -1;
      break;

    case tensorflow::AttrValue::kFunc:
      *list_size = -1;
      *total_size = -1;
      break;

    case tensorflow::AttrValue::VALUE_NOT_SET:
      status->status =
          InvalidArgument("Attribute '", attr_name, "' has no value set");
      break;
  }
}
443
// Generates a pair of attribute accessors for a scalar attr type:
//   TF_OpKernelConstruction_GetAttr<func>     — fetch one value into *val
//   TF_OpKernelConstruction_GetAttr<func>List — copy up to max_vals values
// `c_type` is the C API type exposed to plugins, `cc_type` the C++ type used
// with GetAttr, `attr_type` the proto type-check string (e.g. "int"), and
// `list_field` the AttrValue.list() field to read. (Comments cannot go inside
// the macro body: `//` would swallow the backslash continuations.)
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)        \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                             const char* attr_name,            \
                                             c_type* val, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, "");                                           \
    cc_type v;                                                                 \
    auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
    ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
    ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
    if (s.ok()) {                                                              \
      *val = static_cast<c_type>(v);                                           \
    }                                                                          \
  }                                                                            \
  void TF_OpKernelConstruction_GetAttr##func##List(                            \
      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,       \
      int max_vals, TF_Status* status) {                                       \
    TF_SetStatus(status, TF_OK, "");                                           \
    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);  \
    if (!status->status.ok()) return;                                          \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                  \
      status->status =                                                         \
          InvalidArgument("Value for '", attr_name, "' is not a list.");       \
      return;                                                                  \
    }                                                                          \
    status->status =                                                           \
        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");            \
    if (!status->status.ok()) return;                                          \
    const auto len = std::min(max_vals, attr->list().list_field##_size());     \
    for (int i = 0; i < len; ++i) {                                            \
      vals[i] = static_cast<c_type>(attr->list().list_field(i));               \
    }                                                                          \
  }

// Instantiate accessors for the supported scalar attribute types.
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
482
483void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
484 const char* attr_name, char* value,
485 size_t max_length,
486 TF_Status* status) {
487 std::string v;
488 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
489 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
490 ::tensorflow::Set_TF_Status_from_Status(status, s);
491
492 if (!status->status.ok()) return;
493
494 if (max_length <= 0) {
495 return;
496 }
497 std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
498}
499
// Copies up to `max_values` strings of list attribute `attr_name` into the
// caller-provided `storage`; values[i] points into storage and lengths[i] is
// each string's byte length (strings are not NUL-terminated). Fails with
// InvalidArgument when storage is too small.
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
                                               const char* attr_name,
                                               char** values, size_t* lengths,
                                               int max_values, void* storage,
                                               size_t storage_size,
                                               TF_Status* status) {
  std::vector<std::string> v;
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
  ::tensorflow::Set_TF_Status_from_Status(status, s);

  if (!status->status.ok()) return;

  const auto len = std::min(max_values, static_cast<int>(v.size()));
  char* p = static_cast<char*>(storage);
  for (int i = 0; i < len; ++i) {
    const std::string& s = v[i];
    // NOTE(review): values[i]/lengths[i] are written before the capacity
    // check below, so on overflow the failing entry points at storage whose
    // bytes were never copied.
    values[i] = p;
    lengths[i] = s.size();
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    memcpy(values[i], s.data(), s.size());
    p += s.size();
  }
}
528
529void TF_OpKernelConstruction_GetAttrTensor(TF_OpKernelConstruction* ctx,
530 const char* attr_name,
531 TF_Tensor** val, TF_Status* status) {
532 *val = nullptr;
533 ::tensorflow::Tensor t;
534 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
535 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &t);
536 ::tensorflow::Set_TF_Status_from_Status(status, s);
537
538 if (!status->status.ok()) return;
539
540 *val = TF_TensorFromTensor(t, &status->status);
541}
542
543void TF_OpKernelConstruction_GetAttrTensorList(TF_OpKernelConstruction* ctx,
544 const char* attr_name,
545 TF_Tensor** vals, int max_values,
546 TF_Status* status) {
547 std::vector<::tensorflow::Tensor> v;
548 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
549 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
550 ::tensorflow::Set_TF_Status_from_Status(status, s);
551
552 if (!status->status.ok()) return;
553
554 const auto len = std::min(max_values, static_cast<int>(v.size()));
555 for (int i = 0; i < len; ++i) {
556 vals[i] = TF_TensorFromTensor(v[i], &status->status);
557 if (!status->status.ok()) return;
558 }
559}
560
561TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction* ctx,
562 const char* attr_name,
563 TF_Status* status) {
564 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
565 tensorflow::NameAttrList function;
566 auto cc_status = cc_ctx->GetAttr(attr_name, &function);
567 if (!cc_status.ok()) {
568 Set_TF_Status_from_Status(status, cc_status);
569 return nullptr;
570 }
571 TF_Buffer* buffer = TF_NewBuffer();
572 cc_status = tensorflow::MessageToBuffer(function, buffer);
573 Set_TF_Status_from_Status(status, cc_status);
574 if (!cc_status.ok())
575 return nullptr;
576 else
577 return buffer;
578}
579
// Returns true iff attribute `attr_name` exists on the kernel's NodeDef.
// NOTE(review): `status` is accepted but never written by this function;
// callers must not rely on it being set here.
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
                                     const char* attr_name, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  return cc_ctx->HasAttr(attr_name);
}
585
586TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
587 auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
588 TF_StringView string_view_of_name;
589 string_view_of_name.data = cc_ctx->def().name().data();
590 string_view_of_name.len = cc_ctx->def().name().length();
591 return string_view_of_name;
592}
593
// Returns the dtype expected for output `i` of the running kernel.
// Note: an out-of-range index CHECK-fails (crashes) rather than reporting
// through a TF_Status.
TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  CHECK_GE(i, 0);
  CHECK_LT(i, cc_ctx->num_outputs());
  return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
}
600
601bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
602 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
603 if (i < 0 || i >= cc_ctx->num_inputs()) {
604 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
605 return false;
606 }
607 TF_SetStatus(status, TF_OK, "");
608 return cc_ctx->input_memory_type(i) == tensorflow::HOST_MEMORY;
609}
610
611bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
612 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
613 if (i < 0 || i >= cc_ctx->num_outputs()) {
614 TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
615 return false;
616 }
617 TF_SetStatus(status, TF_OK, "");
618 return cc_ctx->output_memory_type(i) == tensorflow::HOST_MEMORY;
619}
620
// Returns the step_id of the kernel invocation's context.
int64_t TF_StepId(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
}
624
625TF_Buffer* TF_OpKernelConstruction_GetNodeDef(TF_OpKernelConstruction* ctx,
626 TF_Status* status) {
627 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
628 TF_Buffer* ret = TF_NewBuffer();
629 status->status = MessageToBuffer(cc_ctx->def(), ret);
630 if (!status->status.ok()) {
631 TF_DeleteBuffer(ret);
632 return nullptr;
633 }
634 return ret;
635}
636
// Returns the frame_id of the context's current frame/iteration pair.
uint64_t TF_GetFrameId(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
      ->frame_iter()
      .frame_id;
}

// Returns the graph_def_version recorded in the function library runtime.
int TF_GetGraphDefVersion(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
      ->function_library()
      ->graph_def_version();
}

// Returns the iter_id of the context's current frame/iteration pair.
int64_t TF_GetIterId(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
      ->frame_iter()
      .iter_id;
}
654
655TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) {
656 auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
657 TF_StringView opkernel_name_sv;
658 opkernel_name_sv.data = cc_ctx->op_kernel().name().data();
659 opkernel_name_sv.len = cc_ctx->op_kernel().name().length();
660 return opkernel_name_sv;
661}
662
663TF_StringView TF_GetResourceMgrDefaultContainerName(TF_OpKernelContext* ctx) {
664 auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
665 TF_StringView default_container_name_sv;
666 default_container_name_sv.data =
667 cc_ctx->resource_manager()->default_container().data();
668 default_container_name_sv.len =
669 cc_ctx->resource_manager()->default_container().length();
670 return default_container_name_sv;
671}
672
673TF_StringView TF_GetOpKernelRequestedInput(TF_OpKernelContext* ctx,
674 size_t index) {
675 auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
676 TF_StringView requested_input_sv;
677 requested_input_sv.data = cc_ctx->op_kernel().requested_input(index).data();
678 requested_input_sv.len = cc_ctx->op_kernel().requested_input(index).length();
679 return requested_input_sv;
680}
681
682TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
683 TF_DataType dtype, const int64_t* dims,
684 int num_dims, size_t len, TF_Status* status) {
685 TF_SetStatus(status, TF_OK, "");
686 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
687 static_assert(sizeof(int64_t) == sizeof(int64_t),
688 "64-bit int types should match in size");
689 tensorflow::gtl::ArraySlice<const int64_t> dimarray(
690 reinterpret_cast<const int64_t*>(dims), num_dims);
691 tensorflow::Tensor* tensor;
692 tensorflow::Status s = cc_ctx->allocate_output(
693 index, tensorflow::TensorShape(dimarray), &tensor);
694 if (!s.ok()) {
695 ::tensorflow::Set_TF_Status_from_Status(status, s);
696 return nullptr;
697 }
698 TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
699 if (!s.ok()) {
700 ::tensorflow::Set_TF_Status_from_Status(status, s);
701 return nullptr;
702 }
703 return tf_tensor;
704}
705
706TF_Tensor* TF_ForwardInputOrAllocateOutput(
707 TF_OpKernelContext* context, const int* candidate_input_indices,
708 int num_candidate_input_indices, int output_index,
709 const int64_t* output_dims, int output_num_dims, int* forwarded_input,
710 TF_Status* status) {
711 TF_SetStatus(status, TF_OK, "");
712 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
713
714 static_assert(sizeof(int64_t) == sizeof(int64_t),
715 "64-bit int types should match in size");
716 tensorflow::gtl::ArraySlice<int> input_indices_array(
717 candidate_input_indices, num_candidate_input_indices);
718 tensorflow::gtl::ArraySlice<const int64_t> output_dimarray(
719 reinterpret_cast<const int64_t*>(output_dims), output_num_dims);
720 tensorflow::Tensor* output_tensor_pointer;
721 tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
722 input_indices_array, output_index,
723 tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
724 forwarded_input);
725 if (!s.ok()) {
726 ::tensorflow::Set_TF_Status_from_Status(status, s);
727 return nullptr;
728 }
729 TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
730 if (!s.ok()) {
731 ::tensorflow::Set_TF_Status_from_Status(status, s);
732 return nullptr;
733 }
734 return tf_tensor_output;
735}
736
737TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
738 const int64_t* dims, int num_dims,
739 TF_AllocatorAttributes* attributes,
740 TF_Status* status) {
741 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
742 TF_SetStatus(status, TF_OK, "");
743 static_assert(sizeof(int64_t) == sizeof(int64_t),
744 "64-bit int types should match in size");
745 tensorflow::gtl::ArraySlice<const int64_t> dimarray(
746 reinterpret_cast<const int64_t*>(dims), num_dims);
747 if (attributes && !attributes->struct_size) {
748 TF_SetStatus(
749 status, TF_INVALID_ARGUMENT,
750 "TF_AllocatorAttributes struct "
751 "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
752 return nullptr;
753 }
754 tensorflow::AllocatorAttributes allocator_attr;
755 if (attributes && attributes->on_host) {
756 allocator_attr.set_on_host(true);
757 }
758 tensorflow::Status s;
759 tensorflow::Tensor tensor;
760 s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
761 tensorflow::TensorShape(dimarray), &tensor,
762 allocator_attr);
763 if (!s.ok()) {
764 ::tensorflow::Set_TF_Status_from_Status(status, s);
765 return nullptr;
766 }
767 TF_Tensor* tf_tensor;
768 tf_tensor = TF_TensorFromTensor(tensor, &s);
769 if (!s.ok()) {
770 ::tensorflow::Set_TF_Status_from_Status(status, s);
771 return nullptr;
772 }
773 return tf_tensor;
774}
775