1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Routines for registering new ops and for implementing op shape inference
17// functions.
18//
19// This API is alpha software and is subject to change.
20//
21// REGISTRATION
22// ------------
23//
24// In order to register a new op, create a new TF_OpDefinitionBuilder:
25//
26// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName");
27//
28// Inputs, outputs and attributes can be added to the builder with the
29// corresponding functions, e.g.
30//
31// TF_OpDefinitionBuilderAddInput(builder, "input1: int32");
32// TF_OpDefinitionBuilderAddOutput(builder, "output1: int64");
33// TF_OpDefinitionBuilderAddAttr(builder, "attr: int32");
34//
35// The builder may then be registered with TensorFlow using the
36// TF_RegisterOpDefinition function. E.g.
37//
38// TF_Status* status = TF_NewStatus();
//   TF_RegisterOpDefinition(builder, status);
40// if (TF_GetCode(status) != TF_OK) {
41// // handle error
42// }
43//
44// SHAPE INFERENCE
45// ---------------
46//
47// You can provide a shape inference function that TensorFlow will call when it
48// wants to understand the shape of outputs that the op will produce. Use the
49// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape
50// inference function pointer with TensorFlow. The following is an example of a
51// very simple shape inference function:
52//
53// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
54// TF_ShapeHandle* input = TF_NewShapeHandle();
55// TF_ShapeInferenceContextGetInput(ctx, 0, input, status);
56// if (TF_GetCode(status) == TF_OK) {
57// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status);
58// }
59// TF_DeleteShapeHandle(input);
60// }
61//
62// The following code registers the inference function with TensorFlow:
63//
64// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
65//
66// For more details about shape inference, see the documentation for
67// TF_OpDefinitionBuilderSetShapeInferenceFunction.
68
69#ifndef TENSORFLOW_C_OPS_H_
70#define TENSORFLOW_C_OPS_H_
71
72#include <stdbool.h>
73#include <stdint.h>
74#include <stdlib.h>
75
76#include "tensorflow/c/tf_datatype.h"
77#include "tensorflow/c/tf_status.h"
78
// Symbol annotation placed in front of every function declared in this header:
//   * SWIG wrapper generation:                      expands to nothing;
//   * Windows while building the library itself:    __declspec(dllexport);
//   * Windows while consuming the library:          __declspec(dllimport);
//   * every other platform (GCC/Clang):             default symbol visibility.
#if defined(SWIG)
#define TF_CAPI_EXPORT
#elif defined(_WIN32) && defined(TF_COMPILE_LIBRARY)
#define TF_CAPI_EXPORT __declspec(dllexport)
#elif defined(_WIN32)
#define TF_CAPI_EXPORT __declspec(dllimport)
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // defined(SWIG)
92
93#ifdef __cplusplus
94extern "C" {
95#endif
96
// Opaque handle types used throughout this API. Their layouts are not exposed
// by this header; callers only ever hold pointers to them.
//
// The typedefs let C code name the types without the `struct` keyword, which
// every declaration below relies on. A bare `struct X;` forward declaration
// only introduces the unadorned name `X` in C++, so plain C consumers would
// otherwise fail to compile this header.
typedef struct TF_DimensionHandle TF_DimensionHandle;
typedef struct TF_OpDefinitionBuilder TF_OpDefinitionBuilder;
typedef struct TF_ShapeHandle TF_ShapeHandle;
typedef struct TF_ShapeInferenceContext TF_ShapeInferenceContext;
101
102// Returns a newly allocated op definition builder for the given op name. The
103// returned builder may be customized with the `TF_OpDefinitionBuilder...`
104// functions and then registered with TensorFlow with TF_RegisterOpDefinition.
105//
106// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or
107// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never
108// registered.
109TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(
110 const char* op_name);
111
112// Registers the given op builder with TensorFlow. Indicates success or
113// otherwise in the given status.
114//
115// `builder` is freed whether the op was successfully registered or not. You
116// must call either this function or TF_DeleteOpDefinitionBuilder to free the
117// builder, but never both.
118TF_CAPI_EXPORT extern void TF_RegisterOpDefinition(
119 TF_OpDefinitionBuilder* builder, TF_Status* status);
120
121// Frees the given op definition builder. You must call either this function or
122// TF_RegisterOpDefinition to free the builder, but never both.
123TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
124 TF_OpDefinitionBuilder* builder);
125
126//----------------------------------------------------
// Functions for defining an op's attrs, inputs, outputs and properties.
128
129// Adds an attr to the given TF_OpDefinitionBuilder. The spec has
130// format "<name>:<type>" or "<name>:<type>=<default>"
131// where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*.
132// By convention, names containing only capital letters are reserved for
133// attributes whose values can be inferred by the operator implementation if not
134// supplied by the user. If the attribute name contains characters other than
135// capital letters, the operator expects the user to provide the attribute value
136// at operation runtime.
137//
138// <type> can be:
139// "string", "int", "float", "bool", "type", "shape", or "tensor"
140// "numbertype", "realnumbertype", "quantizedtype"
141// (meaning "type" with a restriction on valid values)
142// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
143// (meaning "type" with a restriction containing unions of value types)
144// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
145// (meaning "string" with a restriction on valid values)
146// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
147// (meaning lists of the above types)
148// "int >= 2" (meaning "int" with a restriction on valid values)
149// "list(string) >= 2", "list(int) >= 2"
150// (meaning "list(string)" / "list(int)" with length at least 2)
151// <default>, if included, should use the Proto text format
152// of <type>. For lists use [a, b, c] format.
153//
154// Note that any attr specifying the length of an input or output will
155// get a default minimum of 1 unless the >= # syntax is used.
156TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddAttr(
157 TF_OpDefinitionBuilder* builder, const char* attr_spec);
158
159// Adds an input to this TF_OpDefinitionBuilder.
160// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
161// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
162// * For a single tensor: <type>
163// * For a sequence of tensors with the same type: <number>*<type>
164// * For a sequence of tensors with different types: <type-list>
165// Where:
166// <type> is either one of "float", "int32", "string", ...
167// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
168// with type "type".
169// <number> is the name of an attr with type "int".
170// <type-list> is the name of an attr with type "list(type)".
171TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
172 TF_OpDefinitionBuilder* builder, const char* input_spec);
173
174// Adds an output to this TF_OpDefinitionBuilder.
175// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
176// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
177// * For a single tensor: <type>
178// * For a sequence of tensors with the same type: <number>*<type>
179// * For a sequence of tensors with different types: <type-list>
180// Where:
181// <type> is either one of "float", "int32", "string", ...
182// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
183// with type "type".
184// <number> is the name of an attr with type "int".
185// <type-list> is the name of an attr with type "list(type)".
186TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
187 TF_OpDefinitionBuilder* builder, const char* output_spec);
188
189// Sets the commutative property for the op built by the given builder.
190TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
191 TF_OpDefinitionBuilder* builder, bool is_commutative);
192
193// Sets the is_aggregate property of the builder to the given value.
194//
195// If is_aggregate is true, then the operation produced by this builder accepts
196// N >= 2 inputs and produces 1 output all of the same type. Should be
197// associative and commutative, and produce output with the same shape as the
198// input. The optimizer may replace an aggregate op taking input from multiple
199// devices with a tree of aggregate ops that aggregate locally within each
200// device (and possibly within groups of nearby devices) before communicating.
201TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate(
202 TF_OpDefinitionBuilder* builder, bool is_aggregate);
203
204// Sets the is_stateful property of the builder to the given value.
205//
206// The op built by this builder is stateful if its behavior depends on some
207// state beyond its input tensors (e.g. variable reading op) or if it has a
208// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
209// must always produce the same output for the same input and have no
210// side-effects.
211//
212// By default Ops may be moved between devices. Stateful ops should either not
213// be moved, or should only be moved if that state can also be moved (e.g. via
214// some sort of save / restore). Stateful ops are guaranteed to never be
215// optimized away by Common Subexpression Elimination (CSE).
216TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful(
217 TF_OpDefinitionBuilder* builder, bool is_stateful);
218
219// Sets the allows_uninitialized_input property of the operation built by this
220// builder.
221//
222// By default, all inputs to an Op must be initialized Tensors. Ops that may
223// initialize tensors for the first time should set this field to true, to allow
224// the Op to take an uninitialized Tensor as input.
225TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput(
226 TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input);
227
228// Adds a deprecation warning for the given op. This indicates to the user that
229// `version` is the first TensorFlow GraphDef version for which the operation is
230// deprecated. `explanation` should contain the reason for the deprecation and
231// what to use instead.
232//
233// This function is only an indicator that the operation may disappear in a
234// version of TensorFlow after `version`. It does not affect op registration.
235TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated(
236 TF_OpDefinitionBuilder* builder, int version, const char* explanation);
237
238// Sets the shape inference function for the op.
239TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction(
240 TF_OpDefinitionBuilder* builder,
241 void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
242 TF_Status* status));
243
244//----------------------------------------------------
245// Functions for TF_ShapeInferenceContext.
246//
247// Functions for implementing shape inference functions. TensorFlow uses these
248// functions to determine the shape of tensors produced by an operation without
249// having to actually run the operation. If an operation chooses to provide a
250// shape inference function, it will be invoked by TensorFlow as needed.
251//
252// When invoked by TensorFlow, the shape inference function is provided with a
253// TF_ShapeInferenceContext pointer. The function's implementation will use the
254// accessor and mutator functions with names beginning with
255// TF_ShapeInferenceContext to examine the input state and determine the output
256// shape.
257
258// Returns the number of inputs in the given shape inference context.
259TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs(
260 TF_ShapeInferenceContext* ctx);
261
262// Returns a newly allocated shape handle. The shapes represented by these
263// handles may be queried or mutated with the corresponding
264// TF_ShapeInferenceContext... functions.
265TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle();
266
267// Places the ith input of the given shape inference context into the given
268// shape handle, or returns a status other than TF_OK indicating why the input
269// could not be retrieved
270// (for example, if i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)).
271TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput(
272 TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle,
273 TF_Status* status);
274
275// Places the given shape handle into the `i`th output position of the given
276// context. Internally, the shape handle is copied; the caller may subsequently
277// delete `handle`.
278TF_CAPI_EXPORT
279extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
280 int i, TF_ShapeHandle* handle,
281 TF_Status* status);
282
283// Returns a newly-allocated scalar shape handle. The returned handle should
284// be freed with TF_DeleteShapeHandle.
285TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
286 TF_ShapeInferenceContext* ctx);
287
288// Returns a newly-allocate shape handle representing a vector of the given
289// size. The returned handle should be freed with TF_DeleteShapeHandle.
290TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
291 TF_ShapeInferenceContext* ctx, size_t size);
292
293// Returns a newly allocated dimension handle. It must be freed with
294// TF_DeleteDimensionHandle.
295TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle();
296
297// Interprets the named shape inference context attribute as a TF_DataType and
298// places it into *val. *status is set to TF_OK.
299//
300// If the attribute could not be found or could not be interpreted as
301// TF_DataType, *status is populated with an error.
302TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType(
303 TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val,
304 TF_Status* status);
305
306// Returns the rank of the shape represented by the given handle.
307TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank(
308 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
309
310// Returns 1 if `handle` has a known rank, 0 otherwise.
311TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown(
312 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
313
314// If <handle> has rank <rank>, or its rank is unknown, return OK and return the
315// shape with asserted rank in <*result>. Otherwise an error is placed into
316// `status`.
317TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank(
318 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
319 TF_ShapeHandle* result, TF_Status* status);
320
321// If <handle> has rank at least <rank>, or its rank is unknown, return OK and
322// return the shape with asserted rank in <*result>. Otherwise an error is
323// placed into `status`.
324TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast(
325 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
326 TF_ShapeHandle* result, TF_Status* status);
327
328// If <handle> has rank at most <rank>, or its rank is unknown, return OK and
329// return the shape with asserted rank in <*result>. Otherwise an error is
330// placed into `status`.
331TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost(
332 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
333 TF_ShapeHandle* result, TF_Status* status);
334
335// Places a handle to the ith dimension of the given shape into *result.
336TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
337 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
338 TF_DimensionHandle* result);
339
340// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
341// [start:end]. <start> and <end> can be negative, to index from the end of the
342// shape. <start> and <end> are set to the rank of <shape_handle> if > rank of
343// <shape_handle>.
344TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape(
345 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start,
346 int64_t end, TF_ShapeHandle* result, TF_Status* status);
347
348// Places an unknown shape in all outputs for the given inference context. Used
349// for shape inference functions with ops whose output shapes are unknown.
350TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape(
351 TF_ShapeInferenceContext* ctx, TF_Status* status);
352
353// Returns whether the given handle represents a known dimension.
354TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown(
355 TF_DimensionHandle* dim_handle);
356
357// Returns the value of the given dimension.
358TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue(
359 TF_DimensionHandle* dim_handle);
360
361// Returns in <*result> the result of appending the dimensions of <second> to
362// those of <first>.
363TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes(
364 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first,
365 TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status);
366
367// Frees the given shape handle.
368TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle);
369
370// Frees the given dimension handle.
371TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle);
372
373#ifdef __cplusplus
374} /* end extern "C" */
375#endif
376
377#endif // TENSORFLOW_C_OPS_H_
378