1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ |
17 | #define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ |
18 | |
19 | #include "tensorflow/c/kernels.h" |
20 | |
21 | // -------------------------------------------------------------------------- |
22 | // Experimental kernel C API for TensorFlow. |
23 | // |
// The API here is experimental and subject to change.
25 | // -------------------------------------------------------------------------- |
26 | |
// TF_CAPI_EXPORT marks declarations whose symbols must be visible outside the
// TensorFlow shared library (.so, .dylib, .dll). It duplicates the TF_EXPORT
// macro from tensorflow/core/platform/macros.h so this header does not depend
// on that platform header. Redefining a macro with an identical replacement
// list is permitted by the preprocessor, so this is harmless if an earlier
// include already defined it the same way.
//
//   SWIG                              -> expands to nothing
//   _WIN32 and TF_COMPILE_LIBRARY     -> __declspec(dllexport)
//   _WIN32 without TF_COMPILE_LIBRARY -> __declspec(dllimport)
//   anything else                     -> default symbol visibility attribute
#if defined(SWIG)
#define TF_CAPI_EXPORT
#elif defined(_WIN32) && defined(TF_COMPILE_LIBRARY)
#define TF_CAPI_EXPORT __declspec(dllexport)
#elif defined(_WIN32)
#define TF_CAPI_EXPORT __declspec(dllimport)
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // SWIG / _WIN32 / TF_COMPILE_LIBRARY
45 | |
46 | #ifdef __cplusplus |
47 | extern "C" { |
48 | #endif |
49 | |
// Opaque handle for the variable-input mutexes acquired by
// TF_MaybeLockVariableInputMutexesInOrder. Pass it to
// TF_ReleaseVariableInputLockHolder once the weight update is finished.
struct TF_VariableInputLockHolder;
typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder;
51 | |
52 | // Expose higher level Assignment operation for Pluggable vendors to implement |
53 | // in the plugin for Training. The API takes in the context with indices for |
54 | // the input and value tensors. It also accepts the copy callback provided by |
55 | // pluggable vendor to do the copying of the tensors. The caller takes ownership |
56 | // of the `source` and `dest` tensors and is responsible for freeing them with |
57 | // TF_DeleteTensor. This function will return an error when the following |
58 | // conditions are met: |
59 | // 1. `validate_shape` is set to `true` |
60 | // 2. The variable is initialized |
61 | // 3. The shape of the value tensor doesn't match the shape of the variable |
62 | // tensor. |
63 | TF_CAPI_EXPORT extern void TF_AssignVariable( |
64 | TF_OpKernelContext* ctx, int input_index, int value_index, |
65 | bool validate_shape, |
66 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
67 | TF_Tensor* dest), |
68 | TF_Status* status); |
69 | |
70 | // Expose higher level Assignment operation for Pluggable vendors to implement |
71 | // in the plugin for Training on ref variables. The API takes in the context |
72 | // with indices for the input and value tensors. It also accepts the copy |
73 | // callback provided by pluggable vendor to do the copying of the tensors. The |
74 | // caller takes ownership of the `source` and `dest` tensors and is responsible |
75 | // for freeing them with TF_DeleteTensor. |
76 | TF_CAPI_EXPORT extern void TF_AssignRefVariable( |
77 | TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index, |
78 | int value_index, bool use_locking, bool validate_shape, |
79 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
80 | TF_Tensor* dest), |
81 | TF_Status* status); |
82 | |
83 | // Expose higher level AssignUpdate operation for Pluggable vendors to implement |
84 | // in the plugin for Training. The API takes in the context with indices for the |
85 | // input and value tensors. It also accepts the copy callback provided by |
86 | // pluggable vendor to do the copying of the tensors and the update callback to |
87 | // apply the arithmetic operation. The caller takes ownership of the `source`, |
88 | // `dest`, `tensor` and `value` tensors and is responsible for freeing them with |
89 | // TF_DeleteTensor. |
90 | TF_CAPI_EXPORT extern void TF_AssignUpdateVariable( |
91 | TF_OpKernelContext* ctx, int input_index, int value_index, int Op, |
92 | int isVariantType, |
93 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
94 | TF_Tensor* dest), |
95 | void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor, |
96 | TF_Tensor* value, int Op), |
97 | TF_Status* status); |
98 | |
99 | // This is a helper function which acquires mutexes in-order to provide |
100 | // thread-safe way of performing weights update during the optimizer op. It |
101 | // returns an opaque LockHolder handle back to plugin. This handle is passed to |
102 | // the Release API for releasing the locks when the weight update is done. The |
103 | // caller takes ownership of the `source` and `dest` tensors and is responsible |
104 | // for freeing them with TF_DeleteTensor. |
105 | TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder( |
106 | TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs, |
107 | size_t len, |
108 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
109 | TF_Tensor* dest), |
110 | TF_VariableInputLockHolder** lockHolder, TF_Status* status); |
111 | |
112 | // This interface returns `out` tensor which is updated corresponding to the |
113 | // variable passed with input index. The caller takes ownership of the `source` |
114 | // and `dest` tensors and is responsible for freeing them with TF_DeleteTensor. |
115 | TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable( |
116 | TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType, |
117 | bool sparse, |
118 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
119 | TF_Tensor* dest), |
120 | TF_Tensor** out, TF_Status* status); |
121 | |
122 | // This interface forwards the reference from input to the output tensors |
123 | // corresponding to the indices provided with `input_index` and `output_index` |
124 | TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput( |
125 | TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index); |
126 | |
127 | // The API releases the opaque lock handle returned with |
128 | // `TF_MaybeLockVariableInputMutexesInOrder` API |
129 | TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder( |
130 | TF_VariableInputLockHolder* lockHolder); |
131 | |
132 | // Allows plugin to get TF_Tensor when passed its input_name |
133 | TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx, |
134 | const char* inputName, |
135 | TF_Tensor** tensor, |
136 | TF_Status* status); |
137 | |
138 | // Interprets the named kernel construction attribute as a shape attribute and |
139 | // fills in `vals` with the size of each dimension. `vals` must point to an |
140 | // array of length at least `max_values` (ideally set to total_size from |
141 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, |
142 | // &total_size)). |
143 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape( |
144 | TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims, |
145 | size_t num_dims, TF_Status* status); |
146 | |
147 | TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, |
148 | TF_Status* status); |
149 | |
150 | #ifndef IS_MOBILE_PLATFORM |
151 | // Expose higher level AddN operation for Pluggable vendors to implement |
152 | // in the plugin for Variant data types. The API takes in the context and a |
153 | // callback provided by pluggable vendor to do a Binary Add operation on the |
154 | // tensors unwrapped from the Variant tensors. The caller takes ownership of the |
155 | // `a`, `b` and `out` tensors and is responsible for freeing them with |
156 | // TF_DeleteTensor. |
157 | TF_CAPI_EXPORT extern void TF_AddNVariant( |
158 | TF_OpKernelContext* ctx, |
159 | void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, |
160 | TF_Tensor* out), |
161 | TF_Status* status); |
162 | |
163 | // Expose higher level ZerosLike operation for Pluggable vendors to implement |
164 | // in the plugin for Variant data types. The API takes in the context and a |
165 | // callback provided by pluggable vendor to do a ZerosLike operation on the |
166 | // tensors unwrapped from the Variant tensors. The caller takes ownership of the |
167 | // `input` and `out` tensors and is responsible for freeing them with |
168 | // TF_DeleteTensor. |
169 | TF_CAPI_EXPORT extern void TF_ZerosLikeVariant( |
170 | TF_OpKernelContext* ctx, |
171 | void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input, |
172 | TF_Tensor* out), |
173 | TF_Status* status); |
174 | |
175 | typedef struct TF_CoordinationServiceAgent TF_CoordinationServiceAgent; |
176 | |
177 | #endif // IS_MOBILE_PLATFORM |
178 | |
179 | #ifdef __cplusplus |
180 | } /* end extern "C" */ |
181 | #endif |
182 | |
183 | #endif // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ |
184 | |