1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
17#define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
18
19#include "tensorflow/c/kernels.h"
20
21// --------------------------------------------------------------------------
22// Experimental kernel C API for TensorFlow.
23//
24// The API here is subject to change in the future.
25// --------------------------------------------------------------------------
26
27// Macro to control visibility of exported symbols in the shared library (.so,
28// .dylib, .dll).
29// This duplicates the TF_EXPORT macro definition in
30// tensorflow/core/platform/macros.h in order to keep this .h file independent
31// of any other includes.
//
// Expansion by build configuration:
//   SWIG defined                      -> (empty)
//   _WIN32, TF_COMPILE_LIBRARY set    -> __declspec(dllexport)
//   _WIN32, TF_COMPILE_LIBRARY unset  -> __declspec(dllimport)
//   anything else                     -> default-visibility attribute
//
// NOTE(review): the macro is defined here without an #ifndef guard even though
// this header includes "tensorflow/c/kernels.h" above. Confirm that header does
// not define TF_CAPI_EXPORT with a different replacement list.
32#ifdef SWIG
33#define TF_CAPI_EXPORT
34#else
35#if defined(_WIN32)
36#ifdef TF_COMPILE_LIBRARY
37#define TF_CAPI_EXPORT __declspec(dllexport)
38#else
39#define TF_CAPI_EXPORT __declspec(dllimport)
40#endif // TF_COMPILE_LIBRARY
41#else
42#define TF_CAPI_EXPORT __attribute__((visibility("default")))
43#endif // _WIN32
44#endif // SWIG
45
46#ifdef __cplusplus
47extern "C" {
48#endif
49
// Opaque handle type: only pointers to it appear in this header. A handle is
// handed out through the `lockHolder` argument of
// TF_MaybeLockVariableInputMutexesInOrder and is passed back to
// TF_ReleaseVariableInputLockHolder.
50typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder;
51
52// Expose higher level Assignment operation for Pluggable vendors to implement
53// in the plugin for Training. The API takes in the context with indices for
54// the input and value tensors. It also accepts the copy callback provided by
55// pluggable vendor to do the copying of the tensors. The caller takes ownership
56// of the `source` and `dest` tensors and is responsible for freeing them with
57// TF_DeleteTensor. This function will report an error when the following
58// conditions are met:
59// 1. `validate_shape` is set to `true`
60// 2. The variable is initialized
61// 3. The shape of the value tensor doesn't match the shape of the variable
62// tensor.
//
// Parameters:
//   ctx            - kernel context that the indices below refer to.
//   input_index    - index of the input tensor.
//   value_index    - index of the value tensor.
//   validate_shape - enables the shape check listed in the conditions above.
//   copyFunc       - vendor-supplied copy callback; `source` and `dest` are
//                    the tensor arguments it receives.
//   status         - the function returns void, so the error described above
//                    is presumably delivered here. TODO confirm.
63TF_CAPI_EXPORT extern void TF_AssignVariable(
64 TF_OpKernelContext* ctx, int input_index, int value_index,
65 bool validate_shape,
66 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
67 TF_Tensor* dest),
68 TF_Status* status);
69
70// Expose higher level Assignment operation for Pluggable vendors to implement
71// in the plugin for Training on ref variables. The API takes in the context
72// with indices for the input and value tensors. It also accepts the copy
73// callback provided by pluggable vendor to do the copying of the tensors. The
74// caller takes ownership of the `source` and `dest` tensors and is responsible
75// for freeing them with TF_DeleteTensor.
//
// Parameters:
//   ctx              - kernel context that the indices below refer to.
//   input_ref_index  - index of the ref input.
//   output_ref_index - index of the ref output. NOTE(review): its role is not
//                      described in this header; verify against the
//                      implementation.
//   value_index      - index of the value tensor.
//   use_locking      - presumably selects whether the variable's lock is held
//                      during the assignment. TODO confirm.
//   validate_shape   - presumably has the same meaning as in
//                      TF_AssignVariable. TODO confirm.
//   copyFunc         - vendor-supplied copy callback; `source` and `dest` are
//                      the tensor arguments it receives.
//   status           - the function returns void, so errors are presumably
//                      delivered here. TODO confirm.
76TF_CAPI_EXPORT extern void TF_AssignRefVariable(
77 TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index,
78 int value_index, bool use_locking, bool validate_shape,
79 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
80 TF_Tensor* dest),
81 TF_Status* status);
82
83// Expose higher level AssignUpdate operation for Pluggable vendors to implement
84// in the plugin for Training. The API takes in the context with indices for the
85// input and value tensors. It also accepts the copy callback provided by
86// pluggable vendor to do the copying of the tensors and the update callback to
87// apply the arithmetic operation. The caller takes ownership of the `source`,
88// `dest`, `tensor` and `value` tensors and is responsible for freeing them with
89// TF_DeleteTensor.
//
// `source`/`dest` are the tensor arguments of `copyFunc`; `tensor`/`value` are
// the tensor arguments of `updateFunc`.
//
// NOTE(review): `Op` and `isVariantType` are plain ints whose accepted values
// are not documented in this header. `Op` has the same name and type as the
// last parameter of `updateFunc`, so it is presumably forwarded to that
// callback. `isVariantType` is presumably a boolean flag, since
// TF_GetInputTensorFromVariable declares a parameter of the same name as bool.
// TODO confirm both.
90TF_CAPI_EXPORT extern void TF_AssignUpdateVariable(
91 TF_OpKernelContext* ctx, int input_index, int value_index, int Op,
92 int isVariantType,
93 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
94 TF_Tensor* dest),
95 void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor,
96 TF_Tensor* value, int Op),
97 TF_Status* status);
98
99// This is a helper function which acquires mutexes in-order to provide
100// thread-safe way of performing weights update during the optimizer op. It
101// returns an opaque LockHolder handle back to plugin. This handle is passed to
102// the Release API for releasing the locks when the weight update is done. The
103// caller takes ownership of the `source` and `dest` tensors and is responsible
104// for freeing them with TF_DeleteTensor.
//
// The function itself returns void; the handle is handed back through the
// `lockHolder` out-parameter and is released with
// TF_ReleaseVariableInputLockHolder.
//
// Parameters:
//   ctx        - kernel context.
//   do_lock    - presumably selects whether any mutex is actually acquired
//                (hence "Maybe" in the name). TODO confirm.
//   sparse     - NOTE(review): meaning not described in this header.
//   inputs/len - pointer to `len` ints, presumably the indices of the variable
//                inputs whose mutexes are to be acquired. TODO confirm.
//   copyFunc   - vendor-supplied copy callback; `source` and `dest` are the
//                tensor arguments it receives.
//   lockHolder - out-parameter receiving the opaque handle.
//   status     - errors are presumably delivered here. TODO confirm.
105TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder(
106 TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs,
107 size_t len,
108 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
109 TF_Tensor* dest),
110 TF_VariableInputLockHolder** lockHolder, TF_Status* status);
111
112// This interface returns `out` tensor which is updated corresponding to the
113// variable passed with input index. The caller takes ownership of the `source`
114// and `dest` tensors and is responsible for freeing them with TF_DeleteTensor.
//
// Parameters:
//   ctx           - kernel context.
//   input         - index of the variable input.
//   lock_held     - presumably tells the function that the caller already
//                   holds the variable's lock. TODO confirm.
//   isVariantType - NOTE(review): meaning not described in this header.
//   sparse        - NOTE(review): meaning not described in this header.
//   copyFunc      - vendor-supplied copy callback; `source` and `dest` are the
//                   tensor arguments it receives.
//   out           - out-parameter receiving the resulting tensor.
//   status        - the function returns void, so errors are presumably
//                   delivered here. TODO confirm.
115TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable(
116 TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType,
117 bool sparse,
118 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
119 TF_Tensor* dest),
120 TF_Tensor** out, TF_Status* status);
121
122// This interface forwards the reference from input to the output tensors
123// corresponding to the indices provided with `input_index` and `output_index`.
// Note that it returns void and takes no TF_Status* argument, so this
// declaration offers no channel for reporting a failure to the caller.
124TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput(
125 TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index);
126
127// The API releases the opaque lock handle returned with
128// `TF_MaybeLockVariableInputMutexesInOrder` API.
// `lockHolder` is the handle itself (a single pointer), not the
// pointer-to-pointer that the acquiring call fills in. The function returns
// void and takes no TF_Status* argument.
129TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder(
130 TF_VariableInputLockHolder* lockHolder);
131
132// Allows plugin to get the TF_Tensor for the input named `inputName`.
// The tensor is handed back through the `tensor` out-parameter; the function
// itself returns void, so errors are presumably delivered through `status`.
// NOTE(review): ownership of `*tensor` is not stated in this header; confirm
// whether the caller must release it with TF_DeleteTensor.
133TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx,
134 const char* inputName,
135 TF_Tensor** tensor,
136 TF_Status* status);
137
138// Interprets the named kernel construction attribute as a shape attribute and
139// fills in `dims` with the size of each dimension. `dims` must point to an
140// array of length at least `num_dims` (ideally set to total_size from
141// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,
142// &total_size)).
// The function returns void; errors are presumably delivered through
// `status`. TODO confirm.
143TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
144 TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims,
145 size_t num_dims, TF_Status* status);
146
// Returns a bool for input `i` of `ctx`. Judging by the name, the result is
// presumably true when that input is a ref input. TODO confirm.
// NOTE(review): this declaration had no documentation; the return value on
// error and the use of `status` for error reporting are assumptions to verify
// against the implementation.
147TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i,
148 TF_Status* status);
149
150#ifndef IS_MOBILE_PLATFORM
151// Expose higher level AddN operation for Pluggable vendors to implement
152// in the plugin for Variant data types. The API takes in the context and a
153// callback provided by pluggable vendor to do a Binary Add operation on the
154// tensors unwrapped from the Variant tensors. The caller takes ownership of the
155// `a`, `b` and `out` tensors and is responsible for freeing them with
156// TF_DeleteTensor.
//
// `a`, `b` and `out` are the tensor arguments of `binary_add_func`. Going by
// the names, the callback presumably stores the sum of `a` and `b` in `out`.
// TODO confirm. The function returns void; errors are presumably delivered
// through `status`.
// This declaration is only visible when IS_MOBILE_PLATFORM is not defined.
157TF_CAPI_EXPORT extern void TF_AddNVariant(
158 TF_OpKernelContext* ctx,
159 void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
160 TF_Tensor* out),
161 TF_Status* status);
162
163// Expose higher level ZerosLike operation for Pluggable vendors to implement
164// in the plugin for Variant data types. The API takes in the context and a
165// callback provided by pluggable vendor to do a ZerosLike operation on the
166// tensors unwrapped from the Variant tensors. The caller takes ownership of the
167// `input` and `out` tensors and is responsible for freeing them with
168// TF_DeleteTensor.
//
// `input` and `out` are the tensor arguments of `zeros_like_func`. The
// function returns void; errors are presumably delivered through `status`.
// TODO confirm.
// This declaration is only visible when IS_MOBILE_PLATFORM is not defined.
169TF_CAPI_EXPORT extern void TF_ZerosLikeVariant(
170 TF_OpKernelContext* ctx,
171 void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input,
172 TF_Tensor* out),
173 TF_Status* status);
174
// Opaque forward declaration, visible only when IS_MOBILE_PLATFORM is not
// defined. No function declared in this header takes or returns this type.
// NOTE(review): presumably kept for declarations that live elsewhere; verify
// before removing.
175typedef struct TF_CoordinationServiceAgent TF_CoordinationServiceAgent;
176
177#endif // IS_MOBILE_PLATFORM
178
179#ifdef __cplusplus
180} /* end extern "C" */
181#endif
182
183#endif // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
184