1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ |
17 | #define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ |
18 | |
19 | // Place `<locale>` before <Python.h> to avoid build failure in macOS. |
20 | #include <locale> |
21 | |
22 | // The empty line above is on purpose as otherwise clang-format will |
23 | // automatically move <Python.h> before <locale>. |
24 | #include <Python.h> |
25 | |
26 | #include "tensorflow/c/eager/c_api.h" |
27 | #include "tensorflow/core/framework/types.pb.h" |
28 | #include "tensorflow/core/lib/core/status.h" |
29 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
30 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
31 | |
32 | typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> |
33 | TFE_InputTensorHandles; |
34 | typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> |
35 | TFE_OutputTensorHandles; |
36 | |
37 | // Execute a TensorFlow operation. |
38 | // |
39 | // 'device_name': Name of the device on which to execute the operation, or NULL |
40 | // for automatic selection. |
41 | // 'op_name': Name of the TensorFlow op to execute. |
42 | // 'inputs': An array of TFE_TensorHandle*'s of size 'num_inputs'. These tensors |
43 | // will be provided as input to the operation. |
44 | // 'attrs': A Python tuple alternating names and attr values. |
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
//            placed. On success, its elements will be filled in and the
47 | // caller takes ownership of each returned TFE_TensorHandle. |
48 | // 'outputs' MUST be sized to be at least as large as the number |
49 | // of tensors produced by the operation and will be resized to |
50 | // the actual number of tensors produced. |
51 | void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, |
52 | const char* op_name, TFE_InputTensorHandles* inputs, |
53 | PyObject* attrs, TFE_OutputTensorHandles* outputs, |
54 | TF_Status* out_status); |
55 | |
56 | // Execute a cancelable TensorFlow operation. |
57 | // |
58 | // Arguments as above (for TFE_Py_Execute), with the addition of: |
59 | // 'cancellation_manager': A pointer to a TFE_CancellationManager that can be |
60 | // used to cancel execution of the given operation. |
61 | typedef struct TFE_CancellationManager TFE_CancellationManager; |
62 | void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, |
63 | const char* op_name, |
64 | TFE_InputTensorHandles* inputs, PyObject* attrs, |
65 | TFE_CancellationManager* cancellation_manager, |
66 | TFE_OutputTensorHandles* outputs, |
67 | TF_Status* out_status); |
68 | |
69 | // Registers e as the Exception class for handling not ok Status. Returns |
70 | // Py_None if registration succeeds, else throws a TypeError and returns NULL. |
71 | // |
72 | // This function is not thread-safe. |
73 | PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); |
74 | |
75 | // Registers e as the VSpace to use. |
// `vspace` must be an imperative_grad.py:VSpace named tuple.
77 | PyObject* TFE_Py_RegisterVSpace(PyObject* e); |
78 | |
79 | // Registers e as the Exception to be raised when the conditions of |
80 | // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it |
81 | // is a signal to the calling code that it should fall back to the safer (and |
82 | // more complete) code path. |
83 | // |
84 | // This function is not thread-safe. |
85 | PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e); |
86 | |
87 | // Registers e as the gradient_function. |
88 | // The registered function takes |
89 | // (op_name, attrs, num_inputs, inputs, outputs, output_gradients) and returns |
90 | // the input gradients. This function will not correctly be able to generate |
91 | // gradients for functional ops - the gradients for those ops are calculated |
92 | // through a different codepath (see function.py for additional information). |
93 | // |
94 | // This function is not thread-safe. |
95 | PyObject* TFE_Py_RegisterGradientFunction(PyObject* e); |
96 | |
97 | // Registers e as the forward_gradient_function. The registered function takes |
98 | // (op_name, attrs, inputs, outputs, tangents) and returns the output |
99 | // tangents. This function is used only for operations, not for custom gradients |
100 | // or functional ops. |
101 | // |
102 | // This function is not thread-safe. |
103 | PyObject* TFE_Py_RegisterJVPFunction(PyObject* e); |
104 | |
105 | namespace tensorflow { |
106 | |
107 | // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using |
108 | // `exception` if not nullptr, else using the class registered via |
109 | // TFE_Py_RegisterExceptionClass), and returns -1. |
110 | int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception); |
111 | |
112 | } // namespace tensorflow |
113 | |
114 | // Returns 0 if 'status' is ok. Otherwise, raises an exception (using |
115 | // `exception` if not nullptr, else using the class registered via |
116 | // TFE_Py_RegisterExceptionClass), and returns -1. |
117 | int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, |
118 | PyObject* exception); |
119 | |
120 | // Returns the string associated with the passed-in python object. |
121 | const char* TFE_GetPythonString(PyObject* o); |
122 | |
123 | // Returns a unique id on each call. |
124 | int64_t get_uid(); |
125 | |
126 | // Wraps the output of get_uid as a Python Long object. Ownership is passed to |
127 | // the caller. |
128 | PyObject* TFE_Py_UID(); |
129 | |
130 | // Deleter for Context objects, called from the Capsule that owns it. |
131 | void TFE_DeleteContextCapsule(PyObject* context); |
132 | |
133 | // Returns true if o is an instance of EagerTensor, but not a subclass. Else |
134 | // returns false. |
135 | bool EagerTensor_CheckExact(const PyObject* o); |
136 | |
137 | // Helper function to construct a new EagerTensor from a TFE_TensorHandle. |
138 | PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle, |
139 | const bool is_packed = false); |
140 | |
141 | // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error. |
142 | TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); |
143 | |
144 | // Creates the `EagerTensor` class by subclassing `base_class` and returns the |
145 | // newly created type, or nullptr on error. |
146 | PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); |
147 | |
148 | // Sets `profiler` as the current profiler to receive callbacks about events |
149 | // on eager tensors. Currently, the only reported event is creation. |
150 | // `profiler` is expected to have a `created(self, eager_tensor)` method that |
151 | // takes the created tensor as its single argument. |
152 | // Previous profiler, if any, is unset and will not receive any more |
153 | // callbacks. |
154 | // To unset the profiler, pass Py_None as the value of `profiler`. |
155 | PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler); |
156 | |
157 | // Creates a new tape and adds it to the active set. `persistent` and |
158 | // `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`). |
159 | PyObject* TFE_Py_TapeSetNew(PyObject* persistent, |
160 | PyObject* watch_accessed_variables); |
161 | |
162 | // Removes the passed tape from the set of active tapes. |
163 | void TFE_Py_TapeSetRemove(PyObject* tape); |
164 | |
165 | // Adds the passed tape to the set of active tapes. |
166 | void TFE_Py_TapeSetAdd(PyObject* tape); |
167 | |
168 | // Returns true if the tape stack is empty. |
169 | PyObject* TFE_Py_TapeSetIsEmpty(); |
170 | |
171 | // Check if any backward tape should record an operation given inputs. |
172 | // |
173 | // Does not take forward accumulators into account. |
174 | PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors); |
175 | |
176 | // Determine possible gradient types, taking forward accumulators into account. |
177 | // - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecordBackprop |
178 | // is false and no forward accumulator is watching) |
179 | // - 1 if first-order gradients may be requested |
180 | // - 2 if higher-order gradients may be requested |
181 | PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors); |
182 | |
// Watches `tensor` on the given `tape`, so that operations involving it are
// recorded. NOTE(review): semantics inferred from the signature and the
// sibling TFE_Py_TapeWatchVariable; confirm against the implementation.
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
// Deletes the trace recorded for the tensor with id `tensor_id` from the
// active tapes. NOTE(review): presumably invoked when a tensor is destroyed;
// confirm against the implementation.
void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id);
185 | |
186 | // Stops any gradient recording on the current thread. |
187 | // |
188 | // Includes forward accumulators. |
189 | void TFE_Py_TapeSetStopOnThread(); |
190 | |
191 | // Restarts gradient recording on the current thread. |
192 | void TFE_Py_TapeSetRestartOnThread(); |
193 | |
194 | // Checks whether gradient recording is stopped on the current thread. |
195 | PyObject* TFE_Py_TapeSetIsStopped(); |
196 | |
197 | // Records an operation for the purpose of gradient computation. |
198 | // |
199 | // Arguments: |
200 | // - op_type is a string for the operation type, used in the backprop code |
201 | // - output_tensors are a list of Python Tensor objects output by the operation |
202 | // - input_tensors are a list of input Tensors to the recorded operation |
203 | // - backward_function is the function to be called during backprop or |
204 | // forwardprop to, given the gradients of the output tensors, produce the |
205 | // gradients of the input tensors. This function is automatically transposed |
206 | // during forwardprop. |
207 | // - forward_function is an optional special-case for forwardprop, taking input |
208 | // jvps and returning output jvps. |
209 | // |
210 | // Records an operation both for backprop (gradient tape) and forwardprop |
211 | // (forward accumulator). Equivalent to calling both |
212 | // TFE_Py_TapeSetRecordOperationBackprop and |
213 | // TFE_Py_TapeSetRecordOperationForwardprop. |
214 | PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type, |
215 | PyObject* output_tensors, |
216 | PyObject* input_tensors, |
217 | PyObject* backward_function, |
218 | PyObject* forward_function); |
219 | |
220 | // Records an operation only for backprop (gradient tapes). |
221 | // |
222 | // Same arguments as TFE_Py_TapeSetRecordOperation. |
223 | PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type, |
224 | PyObject* output_tensors, |
225 | PyObject* input_tensors, |
226 | PyObject* backward_function); |
227 | |
228 | // Records an operation only for forwardprop (forward accumulators). |
229 | // |
230 | // Arguments: |
231 | // - op_type is a string for the operation type, used in the backprop code |
232 | // - output_tensors are a list of Python Tensor objects output by the operation |
233 | // - input_tensors are a list of input Tensors to the recorded operation |
234 | // - backward_function is the function to be called to, given the gradients of |
235 | // the output tensors, produce the gradients of the input tensors. This |
236 | // function is automatically transposed to produce output gradients given |
237 | // input gradients. |
238 | // - forwardprop_output_indices indicates any output_tensors which contain |
239 | // JVPs. Typically these will have come from TFE_Py_PackJVPs. May |
240 | // be None or an empty sequence if there are no JVP outputs from the |
241 | // operation. |
242 | PyObject* TFE_Py_TapeSetRecordOperationForwardprop( |
243 | PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, |
244 | PyObject* backward_function, PyObject* forwardprop_output_indices); |
245 | |
246 | // Notifies all tapes that a variable has been accessed. |
247 | void TFE_Py_TapeVariableAccessed(PyObject* variable); |
248 | |
249 | // Watches the given variable object on the given tape. |
250 | void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); |
251 | |
// Computes a gradient based on information recorded on the tape. `tape` must
253 | // have been produced by TFE_Py_NewTape. `target` and `sources` must be python |
254 | // lists of Tensor objects. `output_gradients` is either None or a python list |
255 | // of either Tensor or None, and if not None should have the same length as |
256 | // target. |
257 | PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, |
258 | PyObject* sources, PyObject* output_gradients, |
259 | PyObject* sources_raw, |
260 | PyObject* unconnected_gradients, |
261 | TF_Status* status); |
262 | |
263 | // Execute a tensorflow operation assuming that all provided inputs are |
264 | // correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors, |
265 | // it will simply fail with a NotImplementedError. |
266 | // |
267 | // The "args" PyObject* is meant to be a tuple with the following structure: |
268 | // Item 1: The Python eager Context object |
269 | // Item 2: op_name: Name of the TensorFlow op to execute. |
270 | // Item 3: name: An optional name for the operation. |
271 | // Item 4 onwards: inputs - This is a list of inputs followed by a list of |
272 | // attrs. It is not necessary for type attrs to be present. |
273 | // |
274 | // Note: the device_name and op_callbacks, which were previously passed |
275 | // as arguments, are now read via GetEagerContextThreadLocalData(). |
276 | // |
277 | // This is named _C since there doesn't seem to be any way to make it visible |
278 | // in the SWIG interface without renaming due to the use of the %native |
279 | // directive. |
280 | PyObject* TFE_Py_FastPathExecute_C(PyObject* args); |
281 | |
282 | // Record the gradient for a given op. |
283 | PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, |
284 | PyObject* attrs, PyObject* results, |
285 | PyObject* forward_pass_name_scope); |
286 | |
287 | // Returns all variables watched by the given tape in the order those variables |
288 | // were created. |
289 | PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); |
290 | |
291 | // Creates a new forward accumulator. Does not add it to the active set. |
292 | PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch); |
293 | |
294 | // Adds a ForwardAccumulator to the active set, meaning it will watch executed |
295 | // operations. It must not already be in the active set. |
296 | PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator); |
297 | // Removes a forward accumulator from the active set, meaning it will no longer |
298 | // be watching operations. |
299 | void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator); |
300 | |
301 | // Tell the forward accumulator `accumulator` to watch `tensor`, with a Tensor |
302 | // tangent vector `tangent` of matching shape and dtype. |
303 | void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor, |
304 | PyObject* tangent); |
305 | |
306 | // Looks up the Jacobian-vector product of `tensor` in the forward accumulator |
307 | // `accumulator`. Returns None if no JVP is available. |
308 | PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor); |
309 | |
310 | // Temporarily push or pop transient state for accumulators in the active set. |
311 | // |
312 | // Allows an accumulator which is currently processing an operation to |
313 | // temporarily reset its state. This is useful when building forwardprop |
314 | // versions of functions, where an accumulator will trigger function building |
315 | // and then must process captured symbolic tensors while building it. Without |
316 | // pushing and popping, accumulators ignore operations executed as a direct |
317 | // result of their own jvp computations. |
318 | PyObject* TFE_Py_ForwardAccumulatorPushState(); |
319 | PyObject* TFE_Py_ForwardAccumulatorPopState(); |
320 | |
321 | // Collects state from all current forward accumulators related to `tensors`. |
322 | // |
323 | // This is useful for packing JVPs as function inputs before executing a |
324 | // function which computes primals and JVPs at the same time. |
325 | // |
326 | // Does not include accumulators which are currently in the process of computing |
327 | // a jvp (and so appear somewhere on the current execution stack) or any |
328 | // accumulators more deeply nested. |
329 | // |
330 | // Includes JVPs for `tensors` and any higher-order JVPs for those |
331 | // (recursively). Returns a two-element tuple (indices, jvps): |
332 | // indices: A sequence of sequences of two-element tuples. Each forward |
333 | // accumulator is represented as a sequence of tuples with (primal_index, |
334 | // jvp_index). Both integers index into the concatenated `tensors + jvps` |
335 | // array. |
336 | // jvps: A flat list of Tensors. Best interpreted as a sequence to be |
337 | // appended to `tensors`. |
338 | PyObject* TFE_Py_PackJVPs(PyObject* tensors); |
339 | |
340 | // Variable Watcher methods. |
341 | |
342 | // Creates a new variable watcher and adds it to the set of active variable |
343 | // watchers. |
344 | PyObject* TFE_Py_VariableWatcherNew(); |
345 | |
346 | // Removes the passed variable watcher from the set of active variable watchers. |
347 | void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher); |
348 | |
349 | // Notifies all variable watchers that a variable has been accessed. |
350 | void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable); |
351 | |
352 | // Returns all variables watched by the given variable_watcher in the order |
353 | // those variables were created. |
354 | PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher); |
355 | |
356 | // Returns an EagerTensor of dimension [len(`tensors`)] containing |
357 | // the `slice_dim`'th dimension of each tensor in `tensors`. In other words, |
358 | // TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in |
// `tensors`. For example, if `tensors` contains tensors with shapes
360 | // [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with |
361 | // `slice_dim` equal to 1 will return [2, 5, 7]. |
362 | // On error, returns nullptr and sets python exception. |
363 | // REQUIRES: `tensors` is a python list/tuple of EagerTensors |
364 | // REQUIRES: `slice_dim` is non-negative and smaller than the rank of all |
365 | // tensors in `tensors`. |
366 | PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim); |
367 | |
368 | // Returns the shape of this tensor's on-device representation. |
369 | // The shape is represented as a Python tuple of integers. |
370 | PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor); |
371 | |
// Enables log messages to be surfaced in interactive Python sessions (e.g.
// notebooks). NOTE(review): behavior inferred from the name only; confirm
// against the implementation in pywrap_tfe.cc.
void TFE_Py_EnableInteractivePythonLogging();
373 | |
374 | // Sets the current Python eager Context object (defined |
375 | // in eager/context.py). This function must be called at least once before |
376 | // eager tensors are created. |
377 | // If an error is encountered, sets python error and returns NULL. Else, returns |
378 | // Py_None. |
379 | // |
380 | // Not thread-safe. |
381 | // TODO(mdan): Retire this - non-Python users should only need the EagerContext. |
382 | PyObject* TFE_Py_SetEagerContext(PyObject* py_context); |
383 | |
384 | // Returns the current eager Context object (defined in eager/context.py) |
385 | // that was last set using TFE_Py_SetEagerContext. |
386 | // If an error is encountered, sets python error and returns NULL. |
387 | // The returned PyObject is "new", i.e. the caller must call Py_DECREF on it at |
388 | // some point. |
389 | PyObject* GetPyEagerContext(); |
390 | |
391 | // These are exposed since there is SWIG code that calls these. |
392 | // Returns a pre-allocated status if it exists. |
393 | TF_Status* GetStatus(); |
394 | // Returns the pre-allocated status to the code. |
395 | void ReturnStatus(TF_Status* status); |
396 | |
397 | namespace tensorflow { |
398 | |
399 | // Returns the DataType for the specified tensor. Returns DT_INVALID if |
400 | // PyObject is not a tensor. |
401 | DataType PyTensor_DataType(PyObject* tensor); |
402 | |
403 | // Thread-local data associated with a Python eager Context object. |
404 | // |
405 | // TODO(edloper): Consider changing device_name and scope_name to a const char* |
406 | // (with nullptr used for None). However, note that existing code (e.g. |
407 | // TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings |
408 | // extends beyond the point where their value is changed; so we'd need to make |
409 | // sure that the strings stay alive (maybe using PyUnicode_InternInPlace?) |
struct EagerContextThreadLocalData {
  // Whether eager execution is enabled on this thread. The initial value for
  // new threads comes from the `is_eager` argument passed to
  // MakeEagerContextThreadLocalData.
  bool is_eager = false;
  // True while op callbacks are being invoked on this thread.
  // NOTE(review): inferred from the name (likely a reentrancy guard); confirm.
  bool invoking_op_callbacks = false;
  // Owned references to Python objects holding per-thread context state.
  // Member roles below are inferred from their names; verify against
  // eager/context.py. Note that TFE_TensorHandleCache::Lookup assumes the
  // string contents outlive value changes (see comment above this struct).
  tensorflow::Safe_PyObjectPtr device_name;           // device to place ops on
  tensorflow::Safe_PyObjectPtr scope_name;            // current name scope
  tensorflow::Safe_PyObjectPtr device_spec;           // parsed device spec;
                                                      // default supplied via
                                                      // MakeEagerContextThreadLocalData
  tensorflow::Safe_PyObjectPtr function_call_options; // per-thread call options
  tensorflow::Safe_PyObjectPtr executor;              // eager executor
  tensorflow::Safe_PyObjectPtr op_callbacks;          // registered op callbacks
};
420 | |
421 | // Create a thread-local-data structure associated with py_eager_context. |
422 | // `is_eager` and `device_spec` are used to supply default values for those |
423 | // fields whenever a new thread-local instance is created for py_eager_tensor. |
424 | // |
425 | // This function assumes that the Python GIL is held (and does not perform its |
426 | // own locking). |
427 | void MakeEagerContextThreadLocalData(PyObject* py_eager_context, |
428 | PyObject* is_eager, |
429 | PyObject* device_spec); |
430 | |
431 | // Returns the thread-local instance of EagerContextThreadLocalData that is |
432 | // associated with the given Python Context object. If an instance has not |
433 | // yet been created for `py_eager_context` in this thread, then a new one is |
434 | // created, and initialized with the default values specified in |
435 | // MakeEagerContextThreadLocalData. |
436 | EagerContextThreadLocalData* GetEagerContextThreadLocalData( |
437 | PyObject* py_eager_context); |
438 | |
439 | // Free data structures used to track py_eager_context. |
440 | // |
441 | // This frees global state associated with py_eager_context, as well as thread- |
442 | // local state associated with py_eager_context and the current thread. If you |
443 | // wish to destroy thread-local state associated with a single py_eager_context |
444 | // for multiple threads, then you must call this method from each thread. |
445 | // |
// Thread-local state associated with eager contexts is also automatically
447 | // cleaned up when the thread is destroyed. |
448 | // |
449 | // This function assumes that the Python GIL is held (and does not perform its |
450 | // own locking). |
451 | void DestroyEagerContextThreadLocalData(PyObject* py_eager_context); |
452 | |
453 | } // namespace tensorflow |
454 | |
455 | #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ |
456 | |