/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"

// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
//
// Expansion by configuration:
//   SWIG defined                     -> (empty)
//   _WIN32 and TF_COMPILE_LIBRARY    -> __declspec(dllexport)
//   _WIN32 without TF_COMPILE_LIBRARY -> __declspec(dllimport)
//   anything else                    -> __attribute__((visibility("default")))
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32
#endif  // SWIG

#ifdef __cplusplus
extern "C" {
#endif

// Opaque tensor handle. Declared here so that the declarations below can refer
// to it by pointer.
typedef struct TF_Tensor TF_Tensor;

// --------------------------------------------------------------------------
// C API for TensorFlow Kernels.
//
// This API allows developers to register custom kernel implementations for
// TensorFlow.
//
// See c_api.h header comments for a discussion about API conventions.
//
// Users wishing to extend TensorFlow with new kernels will call
// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
// kernels when necessary.

// Opaque handles used throughout this API. They are only ever handled through
// pointers.
//
// TF_KernelBuilder        - returned by TF_NewKernelBuilder.
// TF_OpKernelConstruction - passed to a kernel's create_func.
// TF_OpKernelContext      - passed to a kernel's compute_func.
typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;

// Entry point for op/kernel registration.
// A plugin should implement TF_InitKernel to register its kernels. This
// function should register all kernels in a plugin.
//
// Declared with an explicit `(void)` parameter list: in C, an empty `()` leaves
// the parameters unspecified instead of declaring "takes no arguments".
void TF_InitKernel(void);

74// Allocates a new kernel builder and returns a pointer to it.
75//
76// If non-null, TensorFlow will call create_func when it needs to instantiate
77// the kernel. The pointer returned by create_func will be passed to
78// compute_func and delete_func, thereby functioning as a "this" pointer for
79// referring to kernel instances.
80//
81// The TF_OpKernelConstruction pointer passed to create_func is owned by
82// TensorFlow and will be deleted once create_func returns. It must not be used
83// after this.
84//
85// When TensorFlow needs to perform a computation with this kernel, it will
86// call compute_func. This function will receive the pointer returned by
87// create_func (or null if no create_func was provided), along with the inputs
88// to the computation.
89//
90// The TF_OpKernelContext pointer received by compute_func is owned by
91// TensorFlow and will be deleted once compute_func returns. It must not be used
92// after this.
93//
94// Finally, when TensorFlow no longer needs the kernel, it will call
95// delete_func if one is provided. This function will receive the pointer
96// returned in `create_func` or nullptr if no `create_func` was provided.
97//
98// The caller should pass the result of this function to
99// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
100// some reason, the kernel builder will not be registered, the caller should
101// delete it with TF_DeleteKernelBuilder.
102TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
103 const char* op_name, const char* device_name,
104 void* (*create_func)(TF_OpKernelConstruction*),
105 void (*compute_func)(void*, TF_OpKernelContext*),
106 void (*delete_func)(void*));
107
108// Specifies that this kernel's attribute only supports the given type.
109TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint(
110 TF_KernelBuilder* kernel_builder, const char* attr_name,
111 const TF_DataType type, TF_Status* status);
112
113// Specify that this kernel requires/provides an input/output arg
114// in host memory (instead of the default, device memory).
115TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
116 TF_KernelBuilder* kernel_builder, const char* arg_name);
117
118// Specify a priority number for this kernel.
119TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
120 TF_KernelBuilder* kernel_builder, int32_t priority_number);
121
122// Specify a label for this kernel.
123TF_CAPI_EXPORT extern void TF_KernelBuilder_Label(
124 TF_KernelBuilder* kernel_builder, const char* label);
125
126// Register the given kernel builder with the TensorFlow runtime. If
127// registration fails, the given status will be populated.
128//
129// This call takes ownership of the `builder` pointer.
130TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
131 TF_KernelBuilder* builder,
132 TF_Status* status);
133
134// Register the given kernel builder with the TensorFlow runtime. If
135// registration fails, the given status will be populated.
136//
137// This method is the same as TF_RegisterKernelBuilder except it takes in a
138// serialized KernelDef, and uses it for registration, instead of building a new
139// one. Users can choose to not provide a serialized KernelDef and in that case
140// it's identical to TF_RegisterKernelBuilder.
141TF_CAPI_EXPORT extern void TF_RegisterKernelBuilderWithKernelDef(
142 const char* serialized_kernel_def, const char* name,
143 TF_KernelBuilder* builder, TF_Status* status);
144
145// Deletes the given TF_KernelBuilder. This should be called only if the kernel
146// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
147TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
148
// --------------------------------------------------------------------------
// OpKernelContext routines

152// TF_GetStream returns the SP_Stream available in ctx.
153// This function returns a stream only for devices registered using the
154// StreamExecutor C API
155// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
156// nullptr and set error status in all other cases.
157// Experimental: this function doesn't have compatibility guarantees and subject
158// to change at any time.
159TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
160 TF_Status* status);
161
162// TF_NumInputs returns the number of inputs available in ctx.
163TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
164
165// TF_NumOutputs returns the number of outputs to be placed in *ctx by the
166// kernel.
167TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
168
169// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
170// populated and its ownership is passed to the caller. In any other case,
171// *tensor is not modified.
172//
173// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
174TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
175 TF_Tensor** tensor, TF_Status* status);
176
177typedef struct {
178 size_t struct_size;
179 void* priv; // Not used, for possible extension.
180 int start; // output
181 int stop; // output
182 TF_Status* status; // output
183} TF_InputRange_Args;
184const size_t TF_InputRange_Args_STRUCT_SIZE =
185 TF_OFFSET_OF_END(TF_InputRange_Args, status);
186
187// Retrieves the start and stop indices, given the input name. Equivalent to
188// OpKernel::InputRange(). `args` will contain the result indices and status.
189TF_CAPI_EXPORT extern void TF_InputRange(TF_OpKernelContext* ctx,
190 const char* name,
191 TF_InputRange_Args* args);
192
193// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
194// TF_OK, ctx is left unmodified.
195//
196// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
197TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i,
198 const TF_Tensor* tensor,
199 TF_Status* status);
200
201// Retrieves the ith output from ctx. If TF_GetCode(status) is TF_OK, *tensor is
202// populated and its ownership is passed to the caller. In any other case,
203// *tensor is not modified.
204//
205// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
206TF_CAPI_EXPORT extern TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx,
207 int i, TF_Status* status);
208
209// Retrieves a serialized FunctionDefLibrary. Status will be set.
210TF_CAPI_EXPORT extern void TF_GetSerializedFunctionDefLibrary(
211 TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library,
212 TF_Status* status);
213
214// Retrieves a serialized ConfigProto. Status will be set.
215TF_CAPI_EXPORT extern void TF_GetSerializedConfigProto(
216 TF_OpKernelContext* ctx, TF_Buffer* serialized_config_proto,
217 TF_Status* status);
218
219// Notifies the given OpKernelConstruction that kernel construction has failed.
220TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure(
221 TF_OpKernelConstruction* ctx, TF_Status* status);
222
223// Notifies the given OpKernelContext that the kernel's compute function has
224// failed.
225TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx,
226 TF_Status* status);
227
228// Returns the expected output data type of the ith output. If i < 0 or
229// i >= TF_NumOutputs(ctx), the program aborts.
230TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
231 TF_OpKernelContext* ctx, int i);
232
233// Returns true if the ith input is allocated in host memory. If i < 0 or i >=
234// TF_NumInputs(ctx), the program aborts.
235TF_CAPI_EXPORT extern bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i,
236 TF_Status* status);
237
238// Returns true if the ith output is allocated in host memory. If i < 0 or i >=
239// TF_NumOutputs(ctx), the program aborts.
240TF_CAPI_EXPORT extern bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i,
241 TF_Status* status);
242
243// Returns the step ID of the given context.
244TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
245
246// Returns the serialized NodeDef protocol buffer for the kernel
247TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetNodeDef(
248 TF_OpKernelConstruction* ctx, TF_Status* status);
249
250// Returns the frame ID of the given context.
251TF_CAPI_EXPORT extern uint64_t TF_GetFrameId(TF_OpKernelContext* ctx);
252
253// Returns the Iter ID of the given context.
254TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx);
255
256// Returns the graph def version of the given context.
257TF_CAPI_EXPORT extern int TF_GetGraphDefVersion(TF_OpKernelContext* ctx);
258
259// Returns the name of the OpKernel.
260//
261// The returned TF_StringView's underlying string is owned by the OpKernel and
262// has the same lifetime as the OpKernel.
263TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx);
264
265// Returns the default container of the resource manager in OpKernelContext.
266//
267// The returned TF_StringView's underlying string is owned by the OpKernel and
268// has the same lifetime as the OpKernel.
269TF_CAPI_EXPORT extern TF_StringView TF_GetResourceMgrDefaultContainerName(
270 TF_OpKernelContext* ctx);
271
272// Returns the name of the requested input at `index` from the OpKernel.
273//
274// The returned TF_StringView's underlying string is owned by the OpKernel and
275// has the same lifetime as the OpKernel.
276TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelRequestedInput(
277 TF_OpKernelContext* ctx, size_t index);
278
279// Get the list_size and total_size of the attribute `attr_name` of `oper`.
280// list_size - the length of the list.
281// total_size - total size of the list.
282// (1) If attr_type == TF_ATTR_STRING
283// then total_size is the cumulative byte size
284// of all the strings in the list.
285// (3) If attr_type == TF_ATTR_SHAPE
286// then total_size is the number of dimensions
287// of the shape valued attribute, or -1
288// if its rank is unknown.
289// (4) If attr_type == TF_ATTR_SHAPE
290// then total_size is the cumulative number
291// of dimensions of all shapes in the list.
292// (5) Otherwise, total_size is undefined.
293TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
294 TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
295 int32_t* total_size, TF_Status* status);
296
297// Interprets the named kernel construction attribute as a TF_DataType and
298// places it into *val. *status is set to TF_OK.
299//
300// If the attribute could not be found or could not be interpreted as
301// TF_DataType, *status is populated with an error.
302TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType(
303 TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val,
304 TF_Status* status);
305
306// Interprets the named kernel construction attribute as int32_t and
307// places it into *val. *status is set to TF_OK.
308//
309// If the attribute could not be found or could not be interpreted as
310// int32, *status is populated with an error.
311TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
312 TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
313 TF_Status* status);
314
315// Interprets the named kernel construction attribute as int64_t and
316// places it into *val. *status is set to TF_OK.
317//
318// If the attribute could not be found or could not be interpreted as
319// int64, *status is populated with an error.
320TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
321 TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
322 TF_Status* status);
323
324// Interprets the named kernel construction attribute as float and
325// places it into *val. *status is set to TF_OK.
326//
327// If the attribute could not be found or could not be interpreted as
328// float, *status is populated with an error.
329TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
330 TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
331 TF_Status* status);
332
333// Interprets the named kernel construction attribute as bool and
334// places it into *val. *status is set to TF_OK.
335//
336// If the attribute could not be found or could not be interpreted as
337// bool, *status is populated with an error.
338TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
339 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
340 TF_Status* status);
341
342// Interprets the named kernel construction attribute as string and
343// places it into *val. `val` must
344// point to an array of length at least `max_length` (ideally set to
345// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
346// attr_name, list_size, total_size)). *status is set to TF_OK.
347//
348// If the attribute could not be found or could not be interpreted as
349// string, *status is populated with an error.
350TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
351 TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
352 size_t max_length, TF_Status* status);
353
354// Interprets the named kernel construction attribute as tensor and places it
355// into *val. Allocates a new TF_Tensor which the caller is expected to take
356// ownership of (and can deallocate using TF_DeleteTensor). *status is set to
357// TF_OK.
358//
359// If the attribute could not be found or could not be interpreted as
360// tensor, *status is populated with an error.
361TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensor(
362 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** val,
363 TF_Status* status);
364
365// Interprets the named kernel construction attribute as a TF_DataType array and
366// places it into *vals. *status is set to TF_OK.
367// `vals` must point to an array of length at least `max_values` (ideally set
368// to list_size from
369// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
370// total_size)).
371TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
372 TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
373 int max_vals, TF_Status* status);
374
375// Interprets the named kernel construction attribute as int32_t array and
376// places it into *vals. *status is set to TF_OK.
377// `vals` must point to an array of length at least `max_values` (ideally set
378// to list_size from
379// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
380// total_size)).
381TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
382 TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
383 int max_vals, TF_Status* status);
384
385// Interprets the named kernel construction attribute as int64_t array and
386// places it into *vals. *status is set to TF_OK.
387// `vals` must point to an array of length at least `max_values` (ideally set
388// to list_size from
389// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
390// total_size)).
391TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
392 TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
393 int max_vals, TF_Status* status);
394
395// Interprets the named kernel construction attribute as float array and
396// places it into *vals. *status is set to TF_OK.
397// `vals` must point to an array of length at least `max_values` (ideally set
398// to list_size from
399// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
400// total_size)).
401TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
402 TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
403 int max_vals, TF_Status* status);
404
405// Interprets the named kernel construction attribute as bool array and
406// places it into *vals. *status is set to TF_OK.
407// `vals` must point to an array of length at least `max_values` (ideally set
408// to list_size from
409// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
410// total_size)).
411TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
412 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
413 int max_vals, TF_Status* status);
414
415// Interprets the named kernel construction attribute as string array and fills
416// in `vals` and `lengths`, each of which must point to an array of length at
417// least `max_values`. *status is set to TF_OK. The elements of values will
418// point to addresses in `storage` which must be at least `storage_size` bytes
419// in length. Ideally, max_values would be set to list_size and `storage` would
420// be at least total_size, obtained from
421// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
422// total_size).
423TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
424 TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
425 size_t* lengths, int max_values, void* storage, size_t storage_size,
426 TF_Status* status);
427
428// Interprets the named kernel construction attribute as tensor array and places
429// it into *vals. *status is set to TF_OK.
430// `vals` must point to an array of length at least `max_values`
431// (ideally set to list_size from TF_OpKernelConstruction_GetAttrSize(ctx,
432// attr_name, list_size, total_size)).
433//
434// The caller takes ownership of all the non-null TF_Tensor* entries in `vals`
435// (which can be deleted using TF_DeleteTensor(vals[i])).
436TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorList(
437 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** vals,
438 int max_values, TF_Status* status);
439
440// Interprets the named kernel construction attribute as a
441// tensorflow::NameAttrList and returns the serialized proto as TF_Buffer.
442// `status` will be set. The caller takes ownership of the returned TF_Buffer
443// (if not null) and is responsible for managing its lifetime.
444TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(
445 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
446
447// Return true if the kernel construction has the attr_name
448TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
449 TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
450
451// Returns the unique operation name for this OpKernel.
452TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
453 TF_OpKernelConstruction* ctx);
454
455// Allocates Tensor for output at given index. Caller takes ownership of
456// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
457//
458// This function should be used to allocate outputs inside kernel
459// compute function.
460TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
461 int index, TF_DataType dtype,
462 const int64_t* dims, int num_dims,
463 size_t len, TF_Status* status);
464
465// Tries to forward one of the inputs given in input_indices to
466// output[output_index]. If none of the given inputs can be forwarded, calls
467// allocate_output() to allocate a new output buffer. The index of the
468// forwarded input will be assign to output argument forwarded_input (if it's
469// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
470// -1.
471TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
472 TF_OpKernelContext* context, const int* candidate_input_indices,
473 int num_candidate_input_indices, int output_index,
474 const int64_t* output_dims, int output_num_dims, int* forwarded_input,
475 TF_Status* status);
476
477// Allocates a temporary Tensor of the specified type and shape. The
478// Tensor must not be used after kernel construction is
479// complete.
480//
481// num_dims must equal the size of array dims
482TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp(
483 TF_OpKernelContext* context, TF_DataType dtype, const int64_t* dims,
484 int num_dims, TF_AllocatorAttributes* alloc_attrs, TF_Status* status);
485
486#ifdef __cplusplus
487} /* end extern "C" */
488#endif
489
490#endif // TENSORFLOW_C_KERNELS_H_
491