/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

// Place `<locale>` before <Python.h> to avoid a build failure on macOS.
#include <locale>

// The empty line above is deliberate: otherwise clang-format would
// automatically move <Python.h> before <locale>.
#include <Python.h>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"

typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
    TFE_InputTensorHandles;
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
    TFE_OutputTensorHandles;

// Executes a TensorFlow operation.
//
// 'device_name': Name of the device on which to execute the operation, or NULL
//                for automatic selection.
// 'op_name': Name of the TensorFlow op to execute.
// 'inputs': A vector of TFE_TensorHandle*'s; these tensors will be provided as
//           inputs to the operation.
// 'attrs': A Python tuple alternating attribute names and values.
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
//            placed. On success, its elements will be filled in and the caller
//            takes ownership of each returned TFE_TensorHandle. 'outputs' MUST
//            be sized at least as large as the number of tensors produced by
//            the operation, and will be resized to the actual number of
//            tensors produced.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status);
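//
// A minimal usage sketch for TFE_Py_Execute (illustrative only, not part of
// the original header): `ctx`, the two hypothetical input handles, and the
// Python attr tuple `attrs` are assumed to already exist, and the caller is
// assumed to hold the Python GIL.
//
//   TF_Status* status = TF_NewStatus();
//   TFE_InputTensorHandles inputs;
//   inputs.push_back(x_handle);  // hypothetical TFE_TensorHandle*
//   inputs.push_back(y_handle);  // hypothetical TFE_TensorHandle*
//   TFE_OutputTensorHandles outputs(1);  // "AddV2" produces a single tensor.
//   TFE_Py_Execute(ctx, /*device_name=*/nullptr, "AddV2", &inputs, attrs,
//                  &outputs, status);
//   if (TF_GetCode(status) != TF_OK) {
//     // TF_Message(status) describes the failure.
//   }
//   TF_DeleteStatus(status);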

// Executes a cancelable TensorFlow operation.
//
// Arguments are as for TFE_Py_Execute, with the addition of:
// 'cancellation_manager': A pointer to a TFE_CancellationManager that can be
//                         used to cancel execution of the given operation.
typedef struct TFE_CancellationManager TFE_CancellationManager;
void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                              const char* op_name,
                              TFE_InputTensorHandles* inputs, PyObject* attrs,
                              TFE_CancellationManager* cancellation_manager,
                              TFE_OutputTensorHandles* outputs,
                              TF_Status* out_status);

// Registers `e` as the Exception class for handling non-OK Status. Returns
// Py_None if registration succeeds; otherwise raises a TypeError and returns
// NULL.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);

// Registers `e` as the VSpace to use. `e` must be an
// imperative_grad.py:VSpace named tuple.
PyObject* TFE_Py_RegisterVSpace(PyObject* e);

// Registers `e` as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
// more complete) code path.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);

// Registers `e` as the gradient_function. The registered function takes
// (op_name, attrs, num_inputs, inputs, outputs, output_gradients) and returns
// the input gradients. This function will not be able to correctly generate
// gradients for functional ops; the gradients for those ops are calculated
// through a different code path (see function.py for additional information).
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);

// Registers `e` as the forward_gradient_function. The registered function
// takes (op_name, attrs, inputs, outputs, tangents) and returns the output
// tangents. This function is used only for operations, not for custom
// gradients or functional ops.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterJVPFunction(PyObject* e);

namespace tensorflow {

// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);

}  // namespace tensorflow

// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception);
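//
// Typical usage sketch (a hedged example; `SomeOperation` is a hypothetical
// call returning a tensorflow::Status, and the Python GIL is assumed held):
//
//   tensorflow::Status s = SomeOperation();
//   if (MaybeRaiseExceptionFromStatus(s, /*exception=*/nullptr) != 0) {
//     return nullptr;  // A Python exception has been set; propagate it.
//   }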

// Returns the string associated with the passed-in Python object.
const char* TFE_GetPythonString(PyObject* o);

// Returns a unique id on each call.
int64_t get_uid();

// Wraps the output of get_uid as a Python Long object. Ownership is passed to
// the caller.
PyObject* TFE_Py_UID();

// Deleter for Context objects, called from the Capsule that owns them.
void TFE_DeleteContextCapsule(PyObject* context);

// Returns true if `o` is an instance of EagerTensor, but not a subclass;
// otherwise returns false.
bool EagerTensor_CheckExact(const PyObject* o);

// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
                                const bool is_packed = false);

// Extracts the handle inside the EagerTensor object `o`. Returns nullptr on
// error.
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
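//
// A common check-then-extract pattern (illustrative sketch; `obj` is assumed
// to be a PyObject* received from Python with the GIL held):
//
//   if (EagerTensor_CheckExact(obj)) {
//     TFE_TensorHandle* handle = EagerTensor_Handle(obj);
//     if (handle == nullptr) return nullptr;  // Python error already set.
//     // ... use `handle` ...
//   }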

// Creates the `EagerTensor` class by subclassing `base_class`, and returns the
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);

// Sets `profiler` as the current profiler to receive callbacks about events
// on eager tensors. Currently, the only reported event is creation.
// `profiler` is expected to have a `created(self, eager_tensor)` method that
// takes the created tensor as its single argument.
// The previous profiler, if any, is unset and will not receive any more
// callbacks.
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);

// Creates a new tape and adds it to the active set. `persistent` and
// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
                            PyObject* watch_accessed_variables);
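//
// Creation sketch (illustrative only; error handling omitted). Both arguments
// are Python booleans, and the new tape starts out in the active set:
//
//   PyObject* tape = TFE_Py_TapeSetNew(/*persistent=*/Py_False,
//                                      /*watch_accessed_variables=*/Py_True);
//   // ... execute operations to be recorded ...
//   TFE_Py_TapeSetRemove(tape);  // Stop recording onto this tape.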

// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);

// Adds the passed tape to the set of active tapes.
void TFE_Py_TapeSetAdd(PyObject* tape);

// Returns true if the tape stack is empty.
PyObject* TFE_Py_TapeSetIsEmpty();

// Checks whether any backward tape should record an operation with the given
// inputs.
//
// Does not take forward accumulators into account.
PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors);

// Determines possible gradient types, taking forward accumulators into
// account:
// - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecordBackprop
//   is false and no forward accumulator is watching),
// - 1 if first-order gradients may be requested,
// - 2 if higher-order gradients may be requested.
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors);

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id);

// Stops any gradient recording on the current thread.
//
// Includes forward accumulators.
void TFE_Py_TapeSetStopOnThread();

// Restarts gradient recording on the current thread.
void TFE_Py_TapeSetRestartOnThread();

// Checks whether gradient recording is stopped on the current thread.
PyObject* TFE_Py_TapeSetIsStopped();
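//
// The stop/restart pair is naturally used as a scoped guard. A minimal,
// illustrative helper (not part of this header); note that real callers may
// want to consult TFE_Py_TapeSetIsStopped() first so that an outer stopped
// scope is not unintentionally restarted:
//
//   struct ScopedTapeStop {
//     ScopedTapeStop() { TFE_Py_TapeSetStopOnThread(); }
//     ~ScopedTapeStop() { TFE_Py_TapeSetRestartOnThread(); }
//   };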

// Records an operation for the purpose of gradient computation.
//
// Arguments:
// - op_type is a string for the operation type, used in the backprop code.
// - output_tensors is a list of Python Tensor objects output by the operation.
// - input_tensors is a list of input Tensors to the recorded operation.
// - backward_function is the function called during backprop or forwardprop
//   which, given the gradients of the output tensors, produces the gradients
//   of the input tensors. This function is automatically transposed during
//   forwardprop.
// - forward_function is an optional special case for forwardprop, taking input
//   JVPs and returning output JVPs.
//
// Records the operation both for backprop (gradient tapes) and forwardprop
// (forward accumulators). Equivalent to calling both
// TFE_Py_TapeSetRecordOperationBackprop and
// TFE_Py_TapeSetRecordOperationForwardprop.
PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
                                        PyObject* output_tensors,
                                        PyObject* input_tensors,
                                        PyObject* backward_function,
                                        PyObject* forward_function);

// Records an operation only for backprop (gradient tapes).
//
// Takes the same arguments as TFE_Py_TapeSetRecordOperation.
PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
                                                PyObject* output_tensors,
                                                PyObject* input_tensors,
                                                PyObject* backward_function);

// Records an operation only for forwardprop (forward accumulators).
//
// Arguments:
// - op_type is a string for the operation type, used in the backprop code.
// - output_tensors is a list of Python Tensor objects output by the operation.
// - input_tensors is a list of input Tensors to the recorded operation.
// - backward_function is the function which, given the gradients of the
//   output tensors, produces the gradients of the input tensors. This
//   function is automatically transposed to produce output gradients given
//   input gradients.
// - forwardprop_output_indices indicates any output_tensors which contain
//   JVPs. Typically these will have come from TFE_Py_PackJVPs. May be None or
//   an empty sequence if there are no JVP outputs from the operation.
PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
    PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
    PyObject* backward_function, PyObject* forwardprop_output_indices);

// Notifies all tapes that a variable has been accessed.
void TFE_Py_TapeVariableAccessed(PyObject* variable);

// Watches the given variable object on the given tape.
void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);

// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_TapeSetNew. `target` and `sources` must be
// Python lists of Tensor objects. `output_gradients` is either None or a
// Python list of Tensors or Nones; if not None, it should have the same
// length as `target`.
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
                              PyObject* sources, PyObject* output_gradients,
                              PyObject* sources_raw,
                              PyObject* unconnected_gradients,
                              TF_Status* status);
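//
// Invocation sketch (hedged; `tape`, `targets`, `sources`, `sources_raw`, and
// `unconnected_gradients` are assumed to be valid Python objects and the GIL
// to be held):
//
//   TF_Status* status = TF_NewStatus();
//   PyObject* grads = TFE_Py_TapeGradient(tape, targets, sources,
//                                         /*output_gradients=*/Py_None,
//                                         sources_raw, unconnected_gradients,
//                                         status);
//   if (grads == nullptr || TF_GetCode(status) != TF_OK) {
//     // Handle the error; a Python exception may already be set.
//   }
//   TF_DeleteStatus(status);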

// Executes a TensorFlow operation, assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
// it will simply fail with a NotImplementedError.
//
// The `args` PyObject* is expected to be a tuple with the following structure:
//  Item 1: The Python eager Context object.
//  Item 2: op_name: Name of the TensorFlow op to execute.
//  Item 3: name: An optional name for the operation.
//  Item 4 onwards: inputs - a list of inputs followed by a list of attrs. It
//                  is not necessary for type attrs to be present.
//
// Note: the device_name and op_callbacks, which were previously passed
// as arguments, are now read via GetEagerContextThreadLocalData().
//
// This is named _C since there doesn't seem to be any way to make it visible
// in the SWIG interface without renaming, due to the use of the %native
// directive.
PyObject* TFE_Py_FastPathExecute_C(PyObject* args);
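//
// Sketch of assembling `args` from C++ (hedged; this function is normally
// invoked from Python, and `py_ctx`, `x`, and `y` are assumed to be existing
// PyObject*s; "Mul" carries no explicit attrs here, since type attrs may be
// omitted):
//
//   PyObject* op_name = PyUnicode_FromString("Mul");
//   PyObject* args = PyTuple_Pack(5, py_ctx, op_name, /*name=*/Py_None, x, y);
//   PyObject* result = TFE_Py_FastPathExecute_C(args);
//   // result == nullptr with the registered fallback exception set signals
//   // that the caller should retry on the slower, more general path.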

// Records the gradient for a given op.
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                PyObject* attrs, PyObject* results,
                                PyObject* forward_pass_name_scope);

// Returns all variables watched by the given tape, in the order those
// variables were created.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);

// Creates a new forward accumulator. Does not add it to the active set.
PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);

// Adds a ForwardAccumulator to the active set, meaning it will watch executed
// operations. It must not already be in the active set.
PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator);
// Removes a forward accumulator from the active set, meaning it will no longer
// be watching operations.
void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator);

// Tells the forward accumulator `accumulator` to watch `tensor`, with a Tensor
// tangent vector `tangent` of matching shape and dtype.
void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
                                    PyObject* tangent);

// Looks up the Jacobian-vector product of `tensor` in the forward accumulator
// `accumulator`. Returns None if no JVP is available.
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
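//
// Lifecycle sketch (illustrative; `tensor`, `tangent`, and `result_tensor`
// are assumed to be Python Tensor objects, with `tangent` matching `tensor`
// in shape and dtype, and the GIL held):
//
//   PyObject* acc = TFE_Py_ForwardAccumulatorNew(/*use_batch=*/false);
//   TFE_Py_ForwardAccumulatorSetAdd(acc);  // Return value ignored in sketch.
//   TFE_Py_ForwardAccumulatorWatch(acc, tensor, tangent);
//   // ... execute operations involving `tensor` ...
//   PyObject* jvp = TFE_Py_ForwardAccumulatorJVP(acc, result_tensor);
//   TFE_Py_ForwardAccumulatorSetRemove(acc);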

// Temporarily pushes or pops transient state for accumulators in the active
// set.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. This is useful when building forwardprop
// versions of functions, where an accumulator will trigger function building
// and then must process captured symbolic tensors while building it. Without
// pushing and popping, accumulators ignore operations executed as a direct
// result of their own JVP computations.
PyObject* TFE_Py_ForwardAccumulatorPushState();
PyObject* TFE_Py_ForwardAccumulatorPopState();

// Collects state from all current forward accumulators related to `tensors`.
//
// This is useful for packing JVPs as function inputs before executing a
// function which computes primals and JVPs at the same time.
//
// Does not include accumulators which are currently in the process of
// computing a JVP (and so appear somewhere on the current execution stack) or
// any accumulators more deeply nested.
//
// Includes JVPs for `tensors` and any higher-order JVPs for those
// (recursively). Returns a two-element tuple (indices, jvps):
//   indices: A sequence of sequences of two-element tuples. Each forward
//       accumulator is represented as a sequence of (primal_index, jvp_index)
//       tuples. Both integers index into the concatenated `tensors + jvps`
//       list.
//   jvps: A flat list of Tensors. Best interpreted as a sequence to be
//       appended to `tensors`.
PyObject* TFE_Py_PackJVPs(PyObject* tensors);
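//
// A concrete (hypothetical) illustration of the return structure: if
// `tensors` is [t0, t1] and a single active accumulator holds JVPs for both,
// a possible result is
//   indices = (((0, 2), (1, 3)),)  (one inner sequence per accumulator)
//   jvps = [jvp_t0, jvp_t1]
// so that in the concatenated list `tensors + jvps`, entry 2 is the JVP of
// entry 0 and entry 3 is the JVP of entry 1.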

// Variable Watcher methods.

// Creates a new variable watcher and adds it to the set of active variable
// watchers.
PyObject* TFE_Py_VariableWatcherNew();

// Removes the passed variable watcher from the set of active variable
// watchers.
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);

// Notifies all variable watchers that a variable has been accessed.
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);

// Returns all variables watched by the given variable_watcher, in the order
// those variables were created.
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);

// Returns an EagerTensor of dimension [len(`tensors`)] containing
// the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
// `tensors`. For example, if `tensors` contains tensors with shapes
// [1, 2, 3], [4, 5], [6, 7, 8, 9], then TFE_Py_TensorShapeSlice called with
// `slice_dim` equal to 1 will return [2, 5, 7].
// On error, returns nullptr and sets a Python exception.
// REQUIRES: `tensors` is a Python list/tuple of EagerTensors.
// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
//   tensors in `tensors`.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);

// Returns the shape of the given tensor's on-device representation.
// The shape is represented as a Python tuple of integers.
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);

void TFE_Py_EnableInteractivePythonLogging();

// Sets the current Python eager Context object (defined in eager/context.py).
// This function must be called at least once before eager tensors are created.
// If an error is encountered, sets a Python error and returns NULL; else
// returns Py_None.
//
// Not thread-safe.
// TODO(mdan): Retire this - non-Python users should only need the EagerContext.
PyObject* TFE_Py_SetEagerContext(PyObject* py_context);

// Returns the current eager Context object (defined in eager/context.py)
// that was last set using TFE_Py_SetEagerContext.
// If an error is encountered, sets a Python error and returns NULL.
// The returned PyObject is "new", i.e. the caller must call Py_DECREF on it
// at some point.
PyObject* GetPyEagerContext();

// These are exposed since there is SWIG code that calls them.
// Returns a pre-allocated status if one exists.
TF_Status* GetStatus();
// Returns the pre-allocated status so it can be reused.
void ReturnStatus(TF_Status* status);

namespace tensorflow {

// Returns the DataType for the specified tensor. Returns DT_INVALID if the
// PyObject is not a tensor.
DataType PyTensor_DataType(PyObject* tensor);

// Thread-local data associated with a Python eager Context object.
//
// TODO(edloper): Consider changing device_name and scope_name to a const char*
// (with nullptr used for None). However, note that existing code (e.g.
// TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings
// extends beyond the point where their value is changed; so we'd need to make
// sure that the strings stay alive (maybe using PyUnicode_InternInPlace?).
struct EagerContextThreadLocalData {
  bool is_eager = false;
  bool invoking_op_callbacks = false;
  tensorflow::Safe_PyObjectPtr device_name;
  tensorflow::Safe_PyObjectPtr scope_name;
  tensorflow::Safe_PyObjectPtr device_spec;
  tensorflow::Safe_PyObjectPtr function_call_options;
  tensorflow::Safe_PyObjectPtr executor;
  tensorflow::Safe_PyObjectPtr op_callbacks;
};

// Creates a thread-local-data structure associated with py_eager_context.
// `is_eager` and `device_spec` are used to supply default values for those
// fields whenever a new thread-local instance is created for py_eager_context.
//
// This function assumes that the Python GIL is held (and does not perform its
// own locking).
void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec);

// Returns the thread-local instance of EagerContextThreadLocalData that is
// associated with the given Python Context object. If an instance has not
// yet been created for `py_eager_context` in this thread, then a new one is
// created and initialized with the default values specified in
// MakeEagerContextThreadLocalData.
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context);
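//
// Read-side sketch (hedged; `py_context` is assumed to be the current Python
// eager Context, the GIL to be held, and a null return to indicate an error,
// which is an assumption about the failure mode):
//
//   EagerContextThreadLocalData* tld =
//       GetEagerContextThreadLocalData(py_context);
//   if (tld == nullptr) return nullptr;
//   if (tld->is_eager) {
//     PyObject* device_name = tld->device_name.get();  // Borrowed reference.
//     // ...
//   }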

// Frees the data structures used to track py_eager_context.
//
// This frees global state associated with py_eager_context, as well as thread-
// local state associated with py_eager_context and the current thread. If you
// wish to destroy thread-local state associated with a single py_eager_context
// for multiple threads, then you must call this method from each thread.
//
// Thread-local state associated with eager contexts is also automatically
// cleaned up when the thread is destroyed.
//
// This function assumes that the Python GIL is held (and does not perform its
// own locking).
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context);

}  // namespace tensorflow

#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_