1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_C_KERNELS_H_ |
17 | #define TENSORFLOW_C_KERNELS_H_ |
18 | |
19 | #include <stdint.h> |
20 | |
21 | #include "tensorflow/c/c_api.h" |
22 | #include "tensorflow/c/experimental/stream_executor/stream_executor.h" |
23 | #include "tensorflow/c/tf_datatype.h" |
24 | #include "tensorflow/c/tf_status.h" |
25 | #include "tensorflow/c/tf_tensor.h" |
26 | |
27 | // TF_CAPI_EXPORT controls the visibility of exported symbols in the shared |
28 | // library (.so, .dylib, .dll): it is empty under SWIG, dllexport/dllimport on |
29 | // _WIN32 (chosen by TF_COMPILE_LIBRARY), and default visibility otherwise. |
30 | // It duplicates the TF_EXPORT macro definition in |
31 | // tensorflow/core/platform/macros.h so this .h file need not include it. |
32 | #ifdef SWIG |
33 | #define TF_CAPI_EXPORT |
34 | #else |
35 | #if defined(_WIN32) |
36 | #ifdef TF_COMPILE_LIBRARY |
37 | #define TF_CAPI_EXPORT __declspec(dllexport) |
38 | #else |
39 | #define TF_CAPI_EXPORT __declspec(dllimport) |
40 | #endif // TF_COMPILE_LIBRARY |
41 | #else |
42 | #define TF_CAPI_EXPORT __attribute__((visibility("default"))) |
43 | #endif // _WIN32 |
44 | #endif // SWIG |
45 | |
46 | #ifdef __cplusplus |
47 | extern "C" { |
48 | #endif |
49 | |
50 | typedef struct TF_Tensor TF_Tensor; |
51 | |
52 | // -------------------------------------------------------------------------- |
53 | // C API for TensorFlow Kernels. |
54 | // |
55 | // This API lets developers register custom kernel implementations with |
56 | // TensorFlow. See the c_api.h header comments for the API conventions. |
57 | // |
58 | // To add a kernel, create a builder with `TF_NewKernelBuilder` and pass it to |
59 | // `TF_RegisterKernelBuilder`; TF can then construct the user-provided kernel |
60 | // when necessary. |
61 | // |
62 | // The three struct types declared below are opaque in this header: they are |
63 | // only forward-declared here and are always handled through pointers. |
64 | |
65 | typedef struct TF_KernelBuilder TF_KernelBuilder; |
66 | typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; |
67 | typedef struct TF_OpKernelContext TF_OpKernelContext; |
68 | |
69 | // TF_InitKernel performs op/kernel registration for a plugin. |
70 | // A plugin should implement TF_InitKernel and register all of its kernels |
71 | // from it. It takes no arguments; `(void)` makes this a real prototype in C. |
72 | void TF_InitKernel(void); |
73 | |
74 | // Allocates a new kernel builder and returns a pointer to it. |
75 | // |
76 | // create_func: if non-null, TensorFlow calls it when it needs to instantiate |
77 | // the kernel. The pointer it returns is passed to compute_func and |
78 | // delete_func, so it acts as a "this" pointer for the kernel instance. |
79 | // |
80 | // The TF_OpKernelConstruction pointer passed to create_func is owned by |
81 | // TensorFlow and is deleted once create_func returns; it must not be used |
82 | // after that. |
83 | // |
84 | // compute_func: called when TensorFlow needs to perform a computation with |
85 | // this kernel. It receives the pointer returned by create_func (or null if no |
86 | // create_func was provided), along with the inputs to the computation, |
87 | // which are reached through the TF_OpKernelContext argument. |
88 | // |
89 | // The TF_OpKernelContext pointer received by compute_func is owned by |
90 | // TensorFlow and is deleted once compute_func returns; it must not be used |
91 | // after that. |
92 | // |
93 | // delete_func: if provided, called when TensorFlow no longer needs the |
94 | // kernel. It receives the pointer returned by create_func, or null if no |
95 | // create_func was provided. |
96 | // |
97 | // Ownership: pass the returned builder to TF_RegisterKernelBuilder, which |
98 | // takes ownership of it. If the builder will not be registered after all, |
99 | // delete it with TF_DeleteKernelBuilder instead. |
100 | // NOTE(review): op_name and device_name are not described here; presumably |
101 | // they name the op and the device type this kernel serves -- confirm. |
102 | TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( |
103 | const char* op_name, const char* device_name, |
104 | void* (*create_func)(TF_OpKernelConstruction*), |
105 | void (*compute_func)(void*, TF_OpKernelContext*), |
106 | void (*delete_func)(void*)); |
107 | |
108 | // Specifies that this kernel's attribute `attr_name` only supports `type`. |
109 | TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint( |
110 | TF_KernelBuilder* kernel_builder, const char* attr_name, |
111 | const TF_DataType type, TF_Status* status); |
112 | |
113 | // Specifies that the input/output arg `arg_name` of this kernel is required |
114 | // or provided in host memory (instead of the default, device memory). |
115 | TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory( |
116 | TF_KernelBuilder* kernel_builder, const char* arg_name); |
117 | |
118 | // Specifies `priority_number` as the priority number for this kernel. |
119 | TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority( |
120 | TF_KernelBuilder* kernel_builder, int32_t priority_number); |
121 | |
122 | // Specifies `label` as the label for this kernel. |
123 | TF_CAPI_EXPORT extern void TF_KernelBuilder_Label( |
124 | TF_KernelBuilder* kernel_builder, const char* label); |
125 | |
126 | // Registers the kernel builder `builder` with the TensorFlow runtime. If |
127 | // registration fails, `status` is populated with the error. |
128 | // |
129 | // This call takes ownership of the `builder` pointer. |
130 | TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, |
131 | TF_KernelBuilder* builder, |
132 | TF_Status* status); |
133 | |
134 | // Registers the kernel builder `builder` with the TensorFlow runtime. If |
135 | // registration fails, `status` is populated with the error. |
136 | // |
137 | // Same as TF_RegisterKernelBuilder, except that it takes a serialized |
138 | // KernelDef (`serialized_kernel_def`) and registers with that instead of |
139 | // building a new one. If no serialized KernelDef is provided, the call is |
140 | // identical to TF_RegisterKernelBuilder. |
141 | TF_CAPI_EXPORT extern void TF_RegisterKernelBuilderWithKernelDef( |
142 | const char* serialized_kernel_def, const char* name, |
143 | TF_KernelBuilder* builder, TF_Status* status); |
144 | |
145 | // Deletes `builder`. Call this only for a kernel builder that is NOT |
146 | // registered with TensorFlow via TF_RegisterKernelBuilder. |
147 | TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); |
148 | |
149 | // -------------------------------------------------------------------------- |
150 | // OpKernelContext routines |
151 | |
152 | // Returns the SP_Stream available in `ctx`. |
153 | // |
154 | // A stream is returned only for devices registered through the StreamExecutor |
155 | // C API (tensorflow/c/experimental/stream_executor/stream_executor.h). In all |
156 | // other cases the function returns nullptr and sets an error in `status`. |
157 | // Experimental: this function has no compatibility guarantees and is subject |
158 | // to change at any time. |
159 | TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx, |
160 | TF_Status* status); |
161 | |
162 | // Returns the number of inputs available in `ctx`. |
163 | TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); |
164 | |
165 | // Returns the number of outputs that are to be placed in `*ctx` by the |
166 | // kernel. |
167 | TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); |
168 | |
169 | // Retrieves the i-th input from `ctx`. If TF_GetCode(status) is TF_OK, |
170 | // `*tensor` is populated and its ownership passes to the caller. In any other |
171 | // case `*tensor` is not modified. |
172 | // |
173 | // If i < 0 or i >= TF_NumInputs(ctx), `*status` is set to TF_OUT_OF_RANGE. |
174 | TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, |
175 | TF_Tensor** tensor, TF_Status* status); |
176 | |
177 | typedef struct { |
178 | size_t struct_size; |
179 | void* priv; // Unused; reserved for possible extension. |
180 | int start; // output: start index for the requested input name |
181 | int stop; // output: stop index for the requested input name |
182 | TF_Status* status; // output: status of the TF_InputRange lookup |
183 | } TF_InputRange_Args; |
184 | // `static` gives each includer its own copy, avoiding duplicate C symbols. |
185 | static const size_t TF_InputRange_Args_STRUCT_SIZE = TF_OFFSET_OF_END(TF_InputRange_Args, status); |
186 | |
187 | // Retrieves the start and stop indices for the input named `name`; equivalent |
188 | // to OpKernel::InputRange(). The result indices and status are put in `args`. |
189 | TF_CAPI_EXPORT extern void TF_InputRange(TF_OpKernelContext* ctx, |
190 | const char* name, |
191 | TF_InputRange_Args* args); |
192 | |
193 | // Sets the i-th output of `ctx` to `tensor`. If TF_GetCode(status) is |
194 | // anything but TF_OK, `ctx` is left unmodified. |
195 | // |
196 | // If i < 0 or i >= TF_NumOutputs(ctx), `*status` is set to TF_OUT_OF_RANGE. |
197 | TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, |
198 | const TF_Tensor* tensor, |
199 | TF_Status* status); |
200 | |
201 | // Retrieves the i-th output from `ctx` and returns it. If TF_GetCode(status) |
202 | // is TF_OK, ownership of the returned tensor is passed to the caller. |
203 | // NOTE(review): the return value on failure is not documented -- confirm. |
204 | // |
205 | // If i < 0 or i >= TF_NumOutputs(ctx), `*status` is set to TF_OUT_OF_RANGE. |
206 | TF_CAPI_EXPORT extern TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, |
207 | int i, TF_Status* status); |
208 | |
209 | // Retrieves a serialized FunctionDefLibrary into the given TF_Buffer. `status` will be set. |
210 | TF_CAPI_EXPORT extern void TF_GetSerializedFunctionDefLibrary( |
211 | TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library, |
212 | TF_Status* status); |
213 | |
214 | // Retrieves a serialized ConfigProto into the given TF_Buffer. `status` will be set. |
215 | TF_CAPI_EXPORT extern void TF_GetSerializedConfigProto( |
216 | TF_OpKernelContext* ctx, TF_Buffer* serialized_config_proto, |
217 | TF_Status* status); |
218 | |
219 | // Notifies the OpKernelConstruction `ctx` that kernel construction has failed. |
220 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( |
221 | TF_OpKernelConstruction* ctx, TF_Status* status); |
222 | |
223 | // Notifies the OpKernelContext `ctx` that the kernel's compute function has |
224 | // failed. |
225 | TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, |
226 | TF_Status* status); |
227 | |
228 | // Returns the data type expected for the i-th output of `ctx`. The program |
229 | // aborts if i < 0 or i >= TF_NumOutputs(ctx). |
230 | TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( |
231 | TF_OpKernelContext* ctx, int i); |
232 | |
233 | // Returns true if the i-th input is in host memory; aborts if i < 0 or i >= |
234 | // TF_NumInputs(ctx). NOTE(review): how `status` is used is not documented. |
235 | TF_CAPI_EXPORT extern bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, |
236 | TF_Status* status); |
237 | |
238 | // Returns true if the i-th output is in host memory; aborts if i < 0 or i >= |
239 | // TF_NumOutputs(ctx). NOTE(review): how `status` is used is not documented. |
240 | TF_CAPI_EXPORT extern bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, |
241 | TF_Status* status); |
242 | |
243 | // Returns the step ID of the context `ctx`. |
244 | TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); |
245 | |
246 | // Returns the serialized NodeDef protocol buffer for the kernel as a TF_Buffer. |
247 | TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetNodeDef( |
248 | TF_OpKernelConstruction* ctx, TF_Status* status); |
249 | |
250 | // Returns the frame ID of the context `ctx`. |
251 | TF_CAPI_EXPORT extern uint64_t TF_GetFrameId(TF_OpKernelContext* ctx); |
252 | |
253 | // Returns the iter ID of the context `ctx`. |
254 | TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx); |
255 | |
256 | // Returns the graph def version of the context `ctx`. |
257 | TF_CAPI_EXPORT extern int TF_GetGraphDefVersion(TF_OpKernelContext* ctx); |
258 | |
259 | // Returns the name of the OpKernel, looked up through `ctx`. |
260 | // |
261 | // The string underlying the returned TF_StringView is owned by the OpKernel |
262 | // and has the same lifetime as the OpKernel. |
263 | TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx); |
264 | |
265 | // Returns the default container of the resource manager in the OpKernelContext. |
266 | // |
267 | // The string underlying the returned TF_StringView is owned by the OpKernel |
268 | // and has the same lifetime as the OpKernel. |
269 | TF_CAPI_EXPORT extern TF_StringView TF_GetResourceMgrDefaultContainerName( |
270 | TF_OpKernelContext* ctx); |
271 | |
272 | // Returns the name of the OpKernel's requested input at position `index`. |
273 | // |
274 | // The string underlying the returned TF_StringView is owned by the OpKernel |
275 | // and has the same lifetime as the OpKernel. |
276 | TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelRequestedInput( |
277 | TF_OpKernelContext* ctx, size_t index); |
278 | |
279 | // Gets the list_size and total_size of the attribute `attr_name` of `ctx`. |
280 | // list_size - the length of the list. |
281 | // total_size - total size of the list. |
282 | // (1) If attr_type == TF_ATTR_STRING, then total_size is the cumulative |
283 | // byte size of all the strings in the list. |
284 | // (2) NOTE(review): this case is missing from the enumeration -- confirm. |
285 | // (3) If attr_type == TF_ATTR_SHAPE and the attribute is a single shape, |
286 | // then total_size is the number of dimensions |
287 | // of the shape valued attribute, or -1 |
288 | // if its rank is unknown. |
289 | // (4) If attr_type == TF_ATTR_SHAPE and the attribute is a list of shapes, |
290 | // then total_size is the cumulative number |
291 | // of dimensions of all shapes in the list. |
292 | // (5) Otherwise, total_size is undefined. |
293 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize( |
294 | TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size, |
295 | int32_t* total_size, TF_Status* status); |
296 | |
297 | // Interprets the kernel construction attribute `attr_name` as a TF_DataType |
298 | // and places it into `*val`; on success `*status` is set to TF_OK. |
299 | // |
300 | // If the attribute could not be found or could not be interpreted as a |
301 | // TF_DataType, `*status` is populated with an error. |
302 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( |
303 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, |
304 | TF_Status* status); |
305 | |
306 | // Interprets the kernel construction attribute `attr_name` as an int32_t |
307 | // and places it into `*val`; on success `*status` is set to TF_OK. |
308 | // |
309 | // If the attribute could not be found or could not be interpreted as an |
310 | // int32, `*status` is populated with an error. |
311 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( |
312 | TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, |
313 | TF_Status* status); |
314 | |
315 | // Interprets the kernel construction attribute `attr_name` as an int64_t |
316 | // and places it into `*val`; on success `*status` is set to TF_OK. |
317 | // |
318 | // If the attribute could not be found or could not be interpreted as an |
319 | // int64, `*status` is populated with an error. |
320 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64( |
321 | TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val, |
322 | TF_Status* status); |
323 | |
324 | // Interprets the kernel construction attribute `attr_name` as a float |
325 | // and places it into `*val`; on success `*status` is set to TF_OK. |
326 | // |
327 | // If the attribute could not be found or could not be interpreted as a |
328 | // float, `*status` is populated with an error. |
329 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat( |
330 | TF_OpKernelConstruction* ctx, const char* attr_name, float* val, |
331 | TF_Status* status); |
332 | |
333 | // Interprets the kernel construction attribute `attr_name` as a bool |
334 | // and places it into `*val`; on success `*status` is set to TF_OK. |
335 | // |
336 | // If the attribute could not be found or could not be interpreted as a |
337 | // bool, `*status` is populated with an error. |
338 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool( |
339 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val, |
340 | TF_Status* status); |
341 | |
342 | // Interprets the kernel construction attribute `attr_name` as a string and |
343 | // places it into `*val`. `val` must point to an array of length at least |
344 | // `max_length` (ideally `max_length` is set to total_size from |
345 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
346 | // total_size)). On success `*status` is set to TF_OK. |
347 | // |
348 | // If the attribute could not be found or could not be interpreted as a |
349 | // string, `*status` is populated with an error. |
350 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString( |
351 | TF_OpKernelConstruction* ctx, const char* attr_name, char* val, |
352 | size_t max_length, TF_Status* status); |
353 | |
354 | // Interprets the kernel construction attribute `attr_name` as a tensor and |
355 | // places it into `*val`. A new TF_Tensor is allocated; the caller is expected |
356 | // to take ownership of it (and can deallocate it using TF_DeleteTensor). On |
357 | // success `*status` is set to TF_OK. |
358 | // |
359 | // If the attribute could not be found or could not be interpreted as a |
360 | // tensor, `*status` is populated with an error. |
361 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensor( |
362 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** val, |
363 | TF_Status* status); |
364 | |
365 | // Interprets the kernel construction attribute `attr_name` as a TF_DataType |
366 | // array and places it into `*vals`; on success `*status` is set to TF_OK. |
367 | // `vals` must point to an array of length at least `max_vals` (ideally set |
368 | // to list_size from |
369 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
370 | // total_size)). |
371 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList( |
372 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals, |
373 | int max_vals, TF_Status* status); |
374 | |
375 | // Interprets the kernel construction attribute `attr_name` as an int32_t |
376 | // array and places it into `*vals`; on success `*status` is set to TF_OK. |
377 | // `vals` must point to an array of length at least `max_vals` (ideally set |
378 | // to list_size from |
379 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
380 | // total_size)). |
381 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List( |
382 | TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals, |
383 | int max_vals, TF_Status* status); |
384 | |
385 | // Interprets the kernel construction attribute `attr_name` as an int64_t |
386 | // array and places it into `*vals`; on success `*status` is set to TF_OK. |
387 | // `vals` must point to an array of length at least `max_vals` (ideally set |
388 | // to list_size from |
389 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
390 | // total_size)). |
391 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List( |
392 | TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals, |
393 | int max_vals, TF_Status* status); |
394 | |
395 | // Interprets the kernel construction attribute `attr_name` as a float |
396 | // array and places it into `*vals`; on success `*status` is set to TF_OK. |
397 | // `vals` must point to an array of length at least `max_vals` (ideally set |
398 | // to list_size from |
399 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
400 | // total_size)). |
401 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList( |
402 | TF_OpKernelConstruction* ctx, const char* attr_name, float* vals, |
403 | int max_vals, TF_Status* status); |
404 | |
405 | // Interprets the kernel construction attribute `attr_name` as a bool |
406 | // array and places it into `*vals`; on success `*status` is set to TF_OK. |
407 | // `vals` must point to an array of length at least `max_vals` (ideally set |
408 | // to list_size from |
409 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
410 | // total_size)). |
411 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList( |
412 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals, |
413 | int max_vals, TF_Status* status); |
414 | |
415 | // Interprets the kernel construction attribute `attr_name` as a string array |
416 | // and fills in `vals` and `lengths`, each of which must point to an array of |
417 | // length at least `max_values`. On success `*status` is set to TF_OK. The |
418 | // elements of `vals` will point to addresses in `storage`, which must be at |
419 | // least `storage_size` bytes in length. Ideally, `max_values` would be set to |
420 | // list_size and `storage` would be at least total_size, obtained from |
421 | // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, |
422 | // total_size). |
423 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList( |
424 | TF_OpKernelConstruction* ctx, const char* attr_name, char** vals, |
425 | size_t* lengths, int max_values, void* storage, size_t storage_size, |
426 | TF_Status* status); |
427 | |
428 | // Interprets the kernel construction attribute `attr_name` as a tensor array |
429 | // and places it into `*vals`; on success `*status` is set to TF_OK. |
430 | // `vals` must point to an array of length at least `max_values` |
431 | // (ideally set to list_size from TF_OpKernelConstruction_GetAttrSize(ctx, |
432 | // attr_name, list_size, total_size)). |
433 | // |
434 | // The caller takes ownership of all the non-null TF_Tensor* entries in `vals` |
435 | // (each of which can be deleted using TF_DeleteTensor(vals[i])). |
436 | TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorList( |
437 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** vals, |
438 | int max_values, TF_Status* status); |
439 | |
440 | // Interprets the kernel construction attribute `attr_name` as a |
441 | // tensorflow::NameAttrList and returns the serialized proto as a TF_Buffer. |
442 | // `status` will be set. The caller takes ownership of the returned TF_Buffer |
443 | // (if it is not null) and is responsible for managing its lifetime. |
444 | TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetAttrFunction( |
445 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); |
446 | |
447 | // Returns true if the kernel construction has an attribute named `attr_name`. |
448 | TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr( |
449 | TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); |
450 | |
451 | // Returns the unique operation name for this OpKernel, read from `ctx`. |
452 | TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( |
453 | TF_OpKernelConstruction* ctx); |
454 | |
455 | // Allocates the Tensor for the output at `index`. The caller takes ownership |
456 | // of the returned TF_Tensor and should deallocate it using |
457 | // TF_DeleteTensor(tensor). NOTE(review): `len` is not documented -- confirm. |
458 | // |
459 | // Intended for allocating outputs inside the kernel compute function. |
460 | TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, |
461 | int index, TF_DataType dtype, |
462 | const int64_t* dims, int num_dims, |
463 | size_t len, TF_Status* status); |
464 | |
465 | // Tries to forward one of the inputs given in `candidate_input_indices` to |
466 | // output[output_index]. If none of the given inputs can be forwarded, calls |
467 | // allocate_output() to allocate a new output buffer. The index of the |
468 | // forwarded input will be assigned to the output argument `forwarded_input` |
469 | // (if it is not nullptr). If no inputs are forwarded, `forwarded_input` will |
470 | // be assigned -1. |
471 | TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput( |
472 | TF_OpKernelContext* context, const int* candidate_input_indices, |
473 | int num_candidate_input_indices, int output_index, |
474 | const int64_t* output_dims, int output_num_dims, int* forwarded_input, |
475 | TF_Status* status); |
476 | |
477 | // Allocates a temporary Tensor of the specified type and shape. The |
478 | // Tensor must not be used after kernel construction is complete. |
479 | // NOTE(review): this takes a TF_OpKernelContext; confirm "construction". |
480 | // |
481 | // `num_dims` must equal the size of the array `dims`. |
482 | TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp( |
483 | TF_OpKernelContext* context, TF_DataType dtype, const int64_t* dims, |
484 | int num_dims, TF_AllocatorAttributes* alloc_attrs, TF_Status* status); |
485 | |
486 | #ifdef __cplusplus |
487 | } /* end extern "C" */ |
488 | #endif |
489 | |
490 | #endif // TENSORFLOW_C_KERNELS_H_ |
491 | |