1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <atomic>
17#include <cstring>
18#include <unordered_map>
19
20#include "absl/debugging/leak_check.h"
21#include "absl/strings/str_cat.h"
22#include "absl/strings/str_replace.h"
23#include "absl/types/variant.h"
24#include "tensorflow/c/c_api.h"
25#include "tensorflow/c/c_api_internal.h"
26#include "tensorflow/c/eager/c_api.h"
27#include "tensorflow/c/eager/c_api_internal.h"
28#include "tensorflow/c/eager/tape.h"
29#include "tensorflow/c/eager/tfe_context_internal.h"
30#include "tensorflow/c/eager/tfe_op_internal.h"
31#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
32#include "tensorflow/c/tf_status.h"
33#include "tensorflow/core/framework/types.pb.h"
34#include "tensorflow/core/lib/core/errors.h"
35#include "tensorflow/core/lib/gtl/cleanup.h"
36#include "tensorflow/core/lib/gtl/compactptrset.h"
37#include "tensorflow/core/lib/gtl/flatmap.h"
38#include "tensorflow/core/lib/gtl/flatset.h"
39#include "tensorflow/core/lib/strings/strcat.h"
40#include "tensorflow/core/lib/strings/stringprintf.h"
41#include "tensorflow/core/platform/casts.h"
42#include "tensorflow/core/platform/errors.h"
43#include "tensorflow/core/platform/mutex.h"
44#include "tensorflow/core/platform/protobuf.h"
45#include "tensorflow/core/platform/status.h"
46#include "tensorflow/core/platform/statusor.h"
47#include "tensorflow/core/platform/types.h"
48#include "tensorflow/core/profiler/lib/traceme.h"
49#include "tensorflow/core/util/managed_stack_trace.h"
50#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
51#include "tensorflow/python/eager/pywrap_tensor.h"
52#include "tensorflow/python/eager/pywrap_tfe.h"
53#include "tensorflow/python/lib/core/py_util.h"
54#include "tensorflow/python/lib/core/safe_ptr.h"
55#include "tensorflow/python/util/stack_trace.h"
56#include "tensorflow/python/util/util.h"
57
58using tensorflow::Status;
59using tensorflow::string;
60using tensorflow::strings::Printf;
61
62namespace {
// NOTE: Items are retrieved from and returned to these unique_ptrs, which
// act as arenas. This matters when the same thread requests two items
// without releasing the first one.
// The following sequence of events on the same thread still succeeds:
// - GetOp <- Returns the existing item.
// - GetOp <- Allocates and returns a new pointer.
// - ReturnOp <- Stores the item in the unique_ptr.
// - ReturnOp <- Stores the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. It makes the pattern safe in that
// case, as well as when Python reuses the same underlying C++ thread for two
// Python threads.
74struct OpDeleter {
75 void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
76};
77thread_local std::unordered_map<TFE_Context*,
78 std::unique_ptr<TFE_Op, OpDeleter>>
79 thread_local_eager_operation_map; // NOLINT
80thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
81 nullptr;
82
83std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
84 auto it = thread_local_eager_operation_map.find(ctx);
85 if (it == thread_local_eager_operation_map.end()) {
86 return nullptr;
87 }
88 return std::move(it->second);
89}
90
91TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
92 const char* raw_device_name, TF_Status* status) {
93 auto op = ReleaseThreadLocalOp(ctx);
94 if (!op) {
95 op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
96 }
97 status->status =
98 tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
99 if (!status->status.ok()) {
100 op.reset();
101 }
102 return op.release();
103}
104
105void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
106 if (op) {
107 tensorflow::unwrap(op)->Clear();
108 thread_local_eager_operation_map[ctx].reset(op);
109 }
110}
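
// A minimal sketch of the intended GetOp/ReturnOp pairing (this is how
// TFE_Py_ExecuteCancelable below uses it): borrow the cached op, then return
// it on every exit path via a cleanup object.
//
//   TFE_Op* op = GetOp(ctx, op_name, device_name, status);
//   auto cleaner =
//       tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
//   // ... add inputs and attrs, then execute ...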
111
112TF_Status* ReleaseThreadLocalStatus() {
113 if (thread_local_tf_status == nullptr) {
114 return nullptr;
115 }
116 return thread_local_tf_status.release();
117}
118
119struct InputInfo {
120 InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
121
122 int i;
123 bool is_list = false;
124};
125
126// Takes in output gradients, returns input gradients.
127typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
128 PyBackwardFunction;
129
130using AttrToInputsMap =
131 tensorflow::gtl::FlatMap<string,
132 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
133
134tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
135 static auto* all_attr_to_input_maps =
136 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
137 return all_attr_to_input_maps;
138}
139
140// This function doesn't use a lock, since we depend on the GIL directly.
141AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
142#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
143 DCHECK(PyGILState_Check())
144 << "This function needs to hold the GIL when called.";
145#endif
146 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
147 auto* output =
148 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
149 if (output != nullptr) {
150 return output;
151 }
152
153 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
154
  // For each type attr, record the inputs whose dtype is derived from it
  // (InputInfo stores the input index and whether that input is a list).
156 for (int i = 0; i < op_def.input_arg_size(); i++) {
157 if (!op_def.input_arg(i).type_attr().empty()) {
158 auto it = m->find(op_def.input_arg(i).type_attr());
159 if (it == m->end()) {
160 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
161 }
162 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
163 }
164 }
165
166 auto* retval = m.get();
167 (*all_attr_to_input_maps)[op_def.name()] = m.release();
168
169 return retval;
170}
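
// Illustrative example (hypothetical op, not taken from an actual OpDef): for
// inputs `x: T` (index 0) and `ys: N * T` (index 1, with number_attr "N"),
// the returned map would be
//   {"T" -> [InputInfo(0, /*is_list=*/false), InputInfo(1, /*is_list=*/true)]}
// i.e. each type attr maps to the inputs whose dtype is derived from it.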
171
172// This function doesn't use a lock, since we depend on the GIL directly.
173tensorflow::gtl::FlatMap<
174 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
175GetAllAttrToDefaultsMaps() {
176 static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
177 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
178 return all_attr_to_defaults_maps;
179}
180
181tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
182GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
183#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
184 DCHECK(PyGILState_Check())
185 << "This function needs to hold the GIL when called.";
186#endif
187 auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
188 auto* output =
189 tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
190 if (output != nullptr) {
191 return output;
192 }
193
194 auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
195
196 for (const auto& attr : op_def.attr()) {
197 if (attr.type() == "type" && attr.has_default_value()) {
198 new_map->insert({attr.name(), attr.default_value().type()});
199 }
200 }
201
202 (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
203
204 return new_map;
205}
206
207struct FastPathOpExecInfo {
208 TFE_Context* ctx;
209 const char* device_name;
210
211 bool run_callbacks;
212 bool run_post_exec_callbacks;
213 bool run_gradient_callback;
214
215 // The op name of the main op being executed.
216 PyObject* name;
217 // The op type name of the main op being executed.
218 PyObject* op_name;
219 PyObject* callbacks;
220
221 // All the args passed into the FastPathOpExecInfo.
222 PyObject* args;
223
224 // DTypes can come from another input that has the same attr. So build that
225 // map.
226 const AttrToInputsMap* attr_to_inputs_map;
227 const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
228 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
229};
230
231#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
232 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
233 type* value) { \
234 if (check_fn(py_value)) { \
235 *value = static_cast<type>(parse_fn(py_value)); \
236 return true; \
237 } else { \
238 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
239 tensorflow::strings::StrCat( \
240 "Expecting " #type " value for attr ", key, ", got ", \
241 py_value->ob_type->tp_name) \
242 .c_str()); \
243 return false; \
244 } \
245 }
246
247#if PY_MAJOR_VERSION >= 3
248PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
249PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
250#else
251PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
252#endif
253PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
254#undef PARSE_VALUE
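
// For reference, PARSE_VALUE(ParseFloatValue, float, PyFloat_Check,
// PyFloat_AsDouble) above expands to (roughly):
//
//   bool ParseFloatValue(const string& key, PyObject* py_value,
//                        TF_Status* status, float* value) {
//     if (PyFloat_Check(py_value)) {
//       *value = static_cast<float>(PyFloat_AsDouble(py_value));
//       return true;
//     }
//     TF_SetStatus(status, TF_INVALID_ARGUMENT, ...);  // "Expecting float..."
//     return false;
//   }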
255
256#if PY_MAJOR_VERSION < 3
257bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
258 int64_t* value) {
259 if (PyInt_Check(py_value)) {
260 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
261 return true;
262 } else if (PyLong_Check(py_value)) {
263 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
264 return true;
265 }
266 TF_SetStatus(
267 status, TF_INVALID_ARGUMENT,
268 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
269 ", got ", py_value->ob_type->tp_name)
270 .c_str());
271 return false;
272}
273#endif
274
275Py_ssize_t TensorShapeNumDims(PyObject* value) {
276 const auto size = PySequence_Size(value);
277 if (size == -1) {
    // TensorShape.__len__ raises an error when the shape's rank is unknown;
    // that error needs to be cleared here.
280 // TODO(nareshmodi): ensure that this is actually a TensorShape.
281 PyErr_Clear();
282 }
283 return size;
284}
285
286bool IsInteger(PyObject* py_value) {
287#if PY_MAJOR_VERSION >= 3
288 return PyLong_Check(py_value);
289#else
290 return PyInt_Check(py_value) || PyLong_Check(py_value);
291#endif
292}
293
294// This function considers a Dimension._value of None to be valid, and sets the
295// value to be -1 in that case.
296bool ParseDimensionValue(const string& key, PyObject* py_value,
297 TF_Status* status, int64_t* value) {
298 if (IsInteger(py_value)) {
299 return ParseInt64Value(key, py_value, status, value);
300 }
301
302 tensorflow::Safe_PyObjectPtr dimension_value(
303 PyObject_GetAttrString(py_value, "_value"));
304 if (dimension_value == nullptr) {
305 PyErr_Clear();
306 TF_SetStatus(
307 status, TF_INVALID_ARGUMENT,
308 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
309 ", got ", py_value->ob_type->tp_name)
310 .c_str());
311 return false;
312 }
313
314 if (dimension_value.get() == Py_None) {
315 *value = -1;
316 return true;
317 }
318
319 return ParseInt64Value(key, dimension_value.get(), status, value);
320}
321
322bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
323 tensorflow::StringPiece* value) {
324 if (PyBytes_Check(py_value)) {
325 Py_ssize_t size = 0;
326 char* buf = nullptr;
327 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
328 *value = tensorflow::StringPiece(buf, size);
329 return true;
330 }
331#if PY_MAJOR_VERSION >= 3
332 if (PyUnicode_Check(py_value)) {
333 Py_ssize_t size = 0;
334 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
335 if (buf == nullptr) return false;
336 *value = tensorflow::StringPiece(buf, size);
337 return true;
338 }
339#endif
340 TF_SetStatus(
341 status, TF_INVALID_ARGUMENT,
342 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
343 ", got ", py_value->ob_type->tp_name)
344 .c_str());
345 return false;
346}
347
348bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
349 unsigned char* value) {
350 if (PyBool_Check(py_value)) {
351 *value = PyObject_IsTrue(py_value);
352 return true;
353 }
354 TF_SetStatus(
355 status, TF_INVALID_ARGUMENT,
356 tensorflow::strings::StrCat("Expecting bool value for attr ", key,
357 ", got ", py_value->ob_type->tp_name)
358 .c_str());
359 return false;
360}
361
362// The passed in py_value is expected to be an object of the python type
363// dtypes.DType or an int.
364bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
365 int* value) {
366 if (IsInteger(py_value)) {
367 return ParseIntValue(key, py_value, status, value);
368 }
369
370 tensorflow::Safe_PyObjectPtr py_type_enum(
371 PyObject_GetAttrString(py_value, "_type_enum"));
372 if (py_type_enum == nullptr) {
373 PyErr_Clear();
374 TF_SetStatus(
375 status, TF_INVALID_ARGUMENT,
376 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
377 ", got ", py_value->ob_type->tp_name)
378 .c_str());
379 return false;
380 }
381
382 return ParseIntValue(key, py_type_enum.get(), status, value);
383}
384
385bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
386 PyObject* py_list, TF_AttrType type,
387 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
388 TF_Status* status) {
389 if (!PySequence_Check(py_list)) {
390 TF_SetStatus(
391 status, TF_INVALID_ARGUMENT,
392 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
393 ", got ", py_list->ob_type->tp_name)
394 .c_str());
395 return false;
396 }
397 const int num_values = PySequence_Size(py_list);
398 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
399
400#define PARSE_LIST(c_type, parse_fn) \
401 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
402 for (int i = 0; i < num_values; ++i) { \
403 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
404 if (py_value == nullptr) { \
405 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
406 tensorflow::strings::StrCat( \
407 "Expecting sequence of " #c_type " for attr ", key, \
408 ", got ", py_list->ob_type->tp_name) \
409 .c_str()); \
410 return false; \
411 } else if (!parse_fn(key, py_value.get(), status, &values[i])) { \
412 return false; \
413 } \
414 }
415
416 if (type == TF_ATTR_STRING) {
417 std::unique_ptr<const void*[]> values(new const void*[num_values]);
418 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
419 for (int i = 0; i < num_values; ++i) {
420 tensorflow::StringPiece value;
421 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
422 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
423 values[i] = value.data();
424 lengths[i] = value.size();
425 }
426 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
427 } else if (type == TF_ATTR_INT) {
428 PARSE_LIST(int64_t, ParseInt64Value);
429 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
430 } else if (type == TF_ATTR_FLOAT) {
431 PARSE_LIST(float, ParseFloatValue);
432 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
433 } else if (type == TF_ATTR_BOOL) {
434 PARSE_LIST(unsigned char, ParseBoolValue);
435 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
436 } else if (type == TF_ATTR_TYPE) {
437 PARSE_LIST(int, ParseTypeValue);
438 TFE_OpSetAttrTypeList(op, key,
439 reinterpret_cast<const TF_DataType*>(values.get()),
440 num_values);
441 } else if (type == TF_ATTR_SHAPE) {
442 // Make one pass through the input counting the total number of
443 // dims across all the input lists.
444 int total_dims = 0;
445 for (int i = 0; i < num_values; ++i) {
446 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
447 if (py_value.get() != Py_None) {
448 if (!PySequence_Check(py_value.get())) {
449 TF_SetStatus(
450 status, TF_INVALID_ARGUMENT,
451 tensorflow::strings::StrCat(
452 "Expecting None or sequence value for element", i,
453 " of attr ", key, ", got ", py_value->ob_type->tp_name)
454 .c_str());
455 return false;
456 }
457 const auto size = TensorShapeNumDims(py_value.get());
458 if (size >= 0) {
459 total_dims += size;
460 }
461 }
462 }
463 // Allocate a buffer that can fit all of the dims together.
464 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
465 // Copy the input dims into the buffer and set dims to point to
466 // the start of each list's dims.
467 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
468 std::unique_ptr<int[]> num_dims(new int[num_values]);
469 int64_t* offset = buffer.get();
470 for (int i = 0; i < num_values; ++i) {
471 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
472 if (py_value.get() == Py_None) {
473 dims[i] = nullptr;
474 num_dims[i] = -1;
475 } else {
476 const auto size = TensorShapeNumDims(py_value.get());
477 if (size == -1) {
478 dims[i] = nullptr;
479 num_dims[i] = -1;
480 continue;
481 }
482 dims[i] = offset;
483 num_dims[i] = size;
484 for (int j = 0; j < size; ++j) {
485 tensorflow::Safe_PyObjectPtr inner_py_value(
486 PySequence_ITEM(py_value.get(), j));
487 if (inner_py_value.get() == Py_None) {
488 *offset = -1;
489 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
490 offset)) {
491 return false;
492 }
493 ++offset;
494 }
495 }
496 }
497 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
498 status);
499 if (!status->status.ok()) return false;
500 } else if (type == TF_ATTR_FUNC) {
501 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
502 for (int i = 0; i < num_values; ++i) {
503 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
504 // Allow:
505 // (1) String function name, OR
506 // (2) A Python object with a .name attribute
507 // (A crude test for being a
508 // tensorflow.python.framework.function._DefinedFunction)
509 // (which is what the various "defun" or "Defun" decorators do).
510 // And in the future also allow an object that can encapsulate
511 // the function name and its attribute values.
512 tensorflow::StringPiece func_name;
513 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
514 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
515 if (name_attr == nullptr ||
516 !ParseStringValue(key, name_attr, status, &func_name)) {
517 TF_SetStatus(
518 status, TF_INVALID_ARGUMENT,
519 tensorflow::strings::StrCat(
520 "unable to set function value attribute from a ",
521 py_value.get()->ob_type->tp_name,
522 " object. If you think this is an error, please file an "
523 "issue at "
524 "https://github.com/tensorflow/tensorflow/issues/new")
525 .c_str());
526 return false;
527 }
528 }
529 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
530 if (!status->status.ok()) return false;
531 }
532 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
533 if (!status->status.ok()) return false;
534 } else {
535 TF_SetStatus(status, TF_UNIMPLEMENTED,
536 tensorflow::strings::StrCat("Attr ", key,
537 " has unhandled list type ", type)
538 .c_str());
539 return false;
540 }
541#undef PARSE_LIST
542 return true;
543}
544
545TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
546 TF_Status* status) {
547 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
548 for (const auto& attr : func.attr()) {
549 if (!status->status.ok()) return nullptr;
550 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
551 if (!status->status.ok()) return nullptr;
552 }
553 return func_op;
554}
555
556void SetOpAttrListDefault(
557 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
558 const char* key, TF_AttrType type,
559 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
560 TF_Status* status) {
561 if (type == TF_ATTR_STRING) {
562 int num_values = attr.default_value().list().s_size();
563 std::unique_ptr<const void*[]> values(new const void*[num_values]);
564 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
565 (*attr_list_sizes)[key] = num_values;
566 for (int i = 0; i < num_values; i++) {
567 const string& v = attr.default_value().list().s(i);
568 values[i] = v.data();
569 lengths[i] = v.size();
570 }
571 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
572 } else if (type == TF_ATTR_INT) {
573 int num_values = attr.default_value().list().i_size();
574 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
575 (*attr_list_sizes)[key] = num_values;
576 for (int i = 0; i < num_values; i++) {
577 values[i] = attr.default_value().list().i(i);
578 }
579 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
580 } else if (type == TF_ATTR_FLOAT) {
581 int num_values = attr.default_value().list().f_size();
582 std::unique_ptr<float[]> values(new float[num_values]);
583 (*attr_list_sizes)[key] = num_values;
584 for (int i = 0; i < num_values; i++) {
585 values[i] = attr.default_value().list().f(i);
586 }
587 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
588 } else if (type == TF_ATTR_BOOL) {
589 int num_values = attr.default_value().list().b_size();
590 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
591 (*attr_list_sizes)[key] = num_values;
592 for (int i = 0; i < num_values; i++) {
593 values[i] = attr.default_value().list().b(i);
594 }
595 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
596 } else if (type == TF_ATTR_TYPE) {
597 int num_values = attr.default_value().list().type_size();
598 std::unique_ptr<int[]> values(new int[num_values]);
599 (*attr_list_sizes)[key] = num_values;
600 for (int i = 0; i < num_values; i++) {
601 values[i] = attr.default_value().list().type(i);
602 }
603 TFE_OpSetAttrTypeList(op, key,
604 reinterpret_cast<const TF_DataType*>(values.get()),
605 attr.default_value().list().type_size());
606 } else if (type == TF_ATTR_SHAPE) {
607 int num_values = attr.default_value().list().shape_size();
608 (*attr_list_sizes)[key] = num_values;
609 int total_dims = 0;
610 for (int i = 0; i < num_values; ++i) {
611 if (!attr.default_value().list().shape(i).unknown_rank()) {
612 total_dims += attr.default_value().list().shape(i).dim_size();
613 }
614 }
615 // Allocate a buffer that can fit all of the dims together.
616 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
617 // Copy the input dims into the buffer and set dims to point to
618 // the start of each list's dims.
619 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
620 std::unique_ptr<int[]> num_dims(new int[num_values]);
621 int64_t* offset = buffer.get();
622 for (int i = 0; i < num_values; ++i) {
623 const auto& shape = attr.default_value().list().shape(i);
624 if (shape.unknown_rank()) {
625 dims[i] = nullptr;
626 num_dims[i] = -1;
627 } else {
628 for (int j = 0; j < shape.dim_size(); j++) {
629 *offset = shape.dim(j).size();
630 ++offset;
631 }
632 }
633 }
634 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
635 status);
636 } else if (type == TF_ATTR_FUNC) {
637 int num_values = attr.default_value().list().func_size();
638 (*attr_list_sizes)[key] = num_values;
639 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
640 for (int i = 0; i < num_values; i++) {
641 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
642 }
643 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
644 } else {
645 TF_SetStatus(status, TF_UNIMPLEMENTED,
646 "Lists of tensors are not yet implemented for default valued "
647 "attributes for an operation.");
648 }
649}
650
651bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
652 PyObject* py_value, TF_AttrType type,
653 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
654 TF_Status* status) {
655 if (type == TF_ATTR_STRING) {
656 tensorflow::StringPiece value;
657 if (!ParseStringValue(key, py_value, status, &value)) return false;
658 TFE_OpSetAttrString(op, key, value.data(), value.size());
659 } else if (type == TF_ATTR_INT) {
660 int64_t value;
661 if (!ParseInt64Value(key, py_value, status, &value)) return false;
662 TFE_OpSetAttrInt(op, key, value);
663 // attr_list_sizes is set for all int attributes (since at this point we are
664 // not aware if that attribute might be used to calculate the size of an
665 // output list or not).
666 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
667 } else if (type == TF_ATTR_FLOAT) {
668 float value;
669 if (!ParseFloatValue(key, py_value, status, &value)) return false;
670 TFE_OpSetAttrFloat(op, key, value);
671 } else if (type == TF_ATTR_BOOL) {
672 unsigned char value;
673 if (!ParseBoolValue(key, py_value, status, &value)) return false;
674 TFE_OpSetAttrBool(op, key, value);
675 } else if (type == TF_ATTR_TYPE) {
676 int value;
677 if (!ParseTypeValue(key, py_value, status, &value)) return false;
678 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
679 } else if (type == TF_ATTR_SHAPE) {
680 if (py_value == Py_None) {
681 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
682 } else {
683 if (!PySequence_Check(py_value)) {
684 TF_SetStatus(status, TF_INVALID_ARGUMENT,
685 tensorflow::strings::StrCat(
686 "Expecting None or sequence value for attr", key,
687 ", got ", py_value->ob_type->tp_name)
688 .c_str());
689 return false;
690 }
691 const auto num_dims = TensorShapeNumDims(py_value);
692 if (num_dims == -1) {
693 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
694 return true;
695 }
696 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
697 for (int i = 0; i < num_dims; ++i) {
698 tensorflow::Safe_PyObjectPtr inner_py_value(
699 PySequence_ITEM(py_value, i));
        // If an error is raised while iterating through the object, we can
        // sometimes get a nullptr.
702 if (inner_py_value.get() == Py_None) {
703 dims[i] = -1;
704 } else if (inner_py_value.get() == nullptr ||
705 !ParseDimensionValue(key, inner_py_value.get(), status,
706 &dims[i])) {
707 return false;
708 }
709 }
710 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
711 }
712 if (!status->status.ok()) return false;
713 } else if (type == TF_ATTR_FUNC) {
714 // Allow:
715 // (1) String function name, OR
716 // (2) A Python object with a .name attribute
717 // (A crude test for being a
718 // tensorflow.python.framework.function._DefinedFunction)
719 // (which is what the various "defun" or "Defun" decorators do).
720 // And in the future also allow an object that can encapsulate
721 // the function name and its attribute values.
722 tensorflow::StringPiece func_name;
723 if (!ParseStringValue(key, py_value, status, &func_name)) {
724 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
725 if (name_attr == nullptr ||
726 !ParseStringValue(key, name_attr, status, &func_name)) {
727 TF_SetStatus(
728 status, TF_INVALID_ARGUMENT,
729 tensorflow::strings::StrCat(
730 "unable to set function value attribute from a ",
731 py_value->ob_type->tp_name,
732 " object. If you think this is an error, please file an issue "
733 "at https://github.com/tensorflow/tensorflow/issues/new")
734 .c_str());
735 return false;
736 }
737 }
738 TF_SetStatus(status, TF_OK, "");
739 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
740 } else {
741 TF_SetStatus(
742 status, TF_UNIMPLEMENTED,
743 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
744 .c_str());
745 return false;
746 }
747 return true;
748}
749
750void SetOpAttrScalarDefault(
751 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
752 const char* attr_name,
753 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
754 TF_Status* status) {
755 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
756 if (default_value.value_case() == tensorflow::AttrValue::kI) {
757 (*attr_list_sizes)[attr_name] = default_value.i();
758 }
759}
760
761// start_index is the index at which the Tuple/List attrs will start getting
762// processed.
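// `attrs` is a flat tuple of alternating (name, value) pairs; for example, a
// hypothetical call might pass attrs = ("adj_x", False, "adj_y", True) with
// start_index pointing at "adj_x". Each pair is dispatched to SetOpAttrScalar
// or SetOpAttrList depending on whether TFE_OpGetAttrType reports the attr as
// a list.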
763void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
764 TF_Status* out_status) {
765 if (attrs == Py_None) return;
766 Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
767 if ((len & 1) != 0) {
768 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
769 "Expecting attrs tuple to have even length.");
770 return;
771 }
772 // Parse attrs
773 for (Py_ssize_t i = 0; i < len; i += 2) {
774 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
775 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
776#if PY_MAJOR_VERSION >= 3
777 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
778 : PyUnicode_AsUTF8(py_key);
779#else
780 const char* key = PyBytes_AsString(py_key);
781#endif
782 unsigned char is_list = 0;
783 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
784 if (!out_status->status.ok()) return;
785 if (is_list != 0) {
786 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
787 return;
788 } else {
789 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
790 return;
791 }
792 }
793}
794
795// This function will set the op attrs required. If an attr has the value of
796// None, then it will read the AttrDef to get the default value and set that
797// instead. Any failure in this function will simply fall back to the slow
798// path.
799void SetOpAttrWithDefaults(
800 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
801 const char* attr_name, PyObject* attr_value,
802 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
803 TF_Status* status) {
804 unsigned char is_list = 0;
805 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
806 if (!status->status.ok()) return;
807 if (attr_value == Py_None) {
808 if (is_list != 0) {
809 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
810 status);
811 } else {
812 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
813 attr_list_sizes, status);
814 }
815 } else {
816 if (is_list != 0) {
817 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
818 status);
819 } else {
820 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
821 status);
822 }
823 }
824}
825
826PyObject* GetPythonObjectFromInt(int num) {
827#if PY_MAJOR_VERSION >= 3
828 return PyLong_FromLong(num);
829#else
830 return PyInt_FromLong(num);
831#endif
832}
833
// Python subclass of Exception that is raised when a Status is not OK.
835tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
836PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
837
// Python subclass of Exception that is raised to signal a fallback to the
// slow path.
839PyObject* fallback_exception_class = nullptr;
840
841// Python function that returns input gradients given output gradients.
842PyObject* gradient_function = nullptr;
843
844// Python function that returns output gradients given input gradients.
845PyObject* forward_gradient_function = nullptr;
846
847static std::atomic<int64_t> _uid;
848
849// This struct is responsible for marking thread_local storage as destroyed.
// Accessing the `alive` field of an already-destroyed
// ThreadLocalDestructionMarker is safe because it is a trivial type, as long
// as nobody creates a new thread_local in the space where the now-destroyed
// marker used to be. Hopefully, creating new thread_locals while a thread is
// being destroyed is rare.
854struct ThreadLocalDestructionMarker {
855 ~ThreadLocalDestructionMarker() { alive = false; }
856 bool alive = true;
857};
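
// A sketch of how such a marker might be consulted (assumed usage, not the
// actual call sites): declare it next to other thread_local state and check
// `alive` before touching that state during thread shutdown.
//
//   thread_local ThreadLocalDestructionMarker marker;
//   if (!marker.alive) {
//     // This thread's thread_locals may already be destroyed; bail out.
//   }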
858
859} // namespace
860
861TF_Status* GetStatus() {
862 TF_Status* maybe_status = ReleaseThreadLocalStatus();
863 if (maybe_status) {
864 TF_SetStatus(maybe_status, TF_OK, "");
865 return maybe_status;
866 } else {
867 return TF_NewStatus();
868 }
869}
870
871void ReturnStatus(TF_Status* status) {
872 TF_SetStatus(status, TF_OK, "");
873 thread_local_tf_status.reset(status);
874}
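
// GetStatus/ReturnStatus mirror GetOp/ReturnOp: callers borrow a cleared
// TF_Status from the thread-local cache and hand it back when done. A minimal
// usage sketch (hypothetical caller):
//
//   TF_Status* status = GetStatus();
//   // ... pass `status` to C API calls ...
//   ReturnStatus(status);  // resets it to TF_OK and re-caches it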
875
876void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
877 const char* op_name, TFE_InputTensorHandles* inputs,
878 PyObject* attrs, TFE_OutputTensorHandles* outputs,
879 TF_Status* out_status) {
880 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
881 /*cancellation_manager=*/nullptr, outputs,
882 out_status);
883}
884
885void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
886 const char* op_name,
887 TFE_InputTensorHandles* inputs, PyObject* attrs,
888 TFE_CancellationManager* cancellation_manager,
889 TFE_OutputTensorHandles* outputs,
890 TF_Status* out_status) {
891 tensorflow::profiler::TraceMe activity(
892 "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
893
894 TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
895
896 auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
897 if (!out_status->status.ok()) return;
898
899 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
900 tensorflow::StackTrace::kStackTraceInitialSize));
901
902 for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
903 TFE_OpAddInput(op, inputs->at(i), out_status);
904 }
905 if (cancellation_manager && out_status->status.ok()) {
906 TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
907 }
908 if (out_status->status.ok()) {
909 SetOpAttrs(ctx, op, attrs, 0, out_status);
910 }
911 Py_BEGIN_ALLOW_THREADS;
912
913 int num_outputs = outputs->size();
914
915 if (out_status->status.ok()) {
916 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
917 }
918
919 if (out_status->status.ok()) {
920 outputs->resize(num_outputs);
921 } else {
922 TF_SetStatus(out_status, TF_GetCode(out_status),
923 tensorflow::strings::StrCat(TF_Message(out_status),
924 " [Op:", op_name, "]")
925 .c_str());
926 }
927
928 Py_END_ALLOW_THREADS;
929}
930
931PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
932 tensorflow::mutex_lock l(exception_class_mutex);
933 if (exception_class != nullptr) {
934 Py_DECREF(exception_class);
935 }
936 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
937 exception_class = nullptr;
938 PyErr_SetString(PyExc_TypeError,
939 "TFE_Py_RegisterExceptionClass: "
940 "Registered class should be subclass of Exception.");
941 return nullptr;
942 }
943
944 Py_INCREF(e);
945 exception_class = e;
946 Py_RETURN_NONE;
947}
948
949PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
950 if (fallback_exception_class != nullptr) {
951 Py_DECREF(fallback_exception_class);
952 }
953 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
954 fallback_exception_class = nullptr;
955 PyErr_SetString(PyExc_TypeError,
956 "TFE_Py_RegisterFallbackExceptionClass: "
957 "Registered class should be subclass of Exception.");
958 return nullptr;
959 } else {
960 Py_INCREF(e);
961 fallback_exception_class = e;
962 Py_RETURN_NONE;
963 }
964}
965
966PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
967 if (gradient_function != nullptr) {
968 Py_DECREF(gradient_function);
969 }
970 if (!PyCallable_Check(e)) {
971 gradient_function = nullptr;
972 PyErr_SetString(PyExc_TypeError,
973 "TFE_Py_RegisterGradientFunction: "
974 "Registered object should be function.");
975 return nullptr;
976 } else {
977 Py_INCREF(e);
978 gradient_function = e;
979 Py_RETURN_NONE;
980 }
981}
982
983PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
984 if (forward_gradient_function != nullptr) {
985 Py_DECREF(forward_gradient_function);
986 }
987 if (!PyCallable_Check(e)) {
988 forward_gradient_function = nullptr;
989 PyErr_SetString(PyExc_TypeError,
990 "TFE_Py_RegisterJVPFunction: "
991 "Registered object should be function.");
992 return nullptr;
993 } else {
994 Py_INCREF(e);
995 forward_gradient_function = e;
996 Py_RETURN_NONE;
997 }
998}
999
1000void RaiseFallbackException(const char* message) {
1001 if (fallback_exception_class != nullptr) {
1002 PyErr_SetString(fallback_exception_class, message);
1003 return;
1004 }
1005
1006 PyErr_SetString(
1007 PyExc_RuntimeError,
1008 tensorflow::strings::StrCat(
1009 "Fallback exception type not set, attempting to fallback due to ",
1010 message)
1011 .data());
1012}
1013
1014// Format and return `status`' error message with the attached stack trace if
1015// available. `status` must have an error.
1016std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
1017 tensorflow::DCheckPyGilState();
1018 DCHECK(!status.ok());
1019
1020 std::vector<tensorflow::StackFrame> stack_trace =
1021 tensorflow::errors::GetStackTrace(status);
1022
1023 if (stack_trace.empty()) return status.error_message();
1024
1025 PyObject* linecache = PyImport_ImportModule("linecache");
1026 PyObject* getline =
1027 PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
1028 DCHECK(getline);
1029
1030 std::ostringstream result;
1031 result << "Exception originated from\n\n";
1032
1033 for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1034 PyObject* line_str_obj = PyObject_CallFunction(
1035 getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1036 stack_frame.line_number);
1037 tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1038 tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1039 result << " File \"" << stack_frame.file_name << "\", line "
1040 << stack_frame.line_number << ", in " << stack_frame.function_name
1041 << '\n';
1042
1043 if (!line_str.empty()) result << " " << line_str << '\n';
1044 Py_XDECREF(line_str_obj);
1045 }
1046
1047 Py_DecRef(getline);
1048 Py_DecRef(linecache);
1049
1050 result << '\n' << status.error_message();
1051 return result.str();
1052}
1053
1054namespace tensorflow {
1055
1056int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1057 if (status->status.ok()) return 0;
1058 const char* msg = TF_Message(status);
1059 if (exception == nullptr) {
1060 tensorflow::mutex_lock l(exception_class_mutex);
1061 if (exception_class != nullptr) {
1062 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1063 for (const auto& payload :
1064 tensorflow::errors::GetPayloads(status->status)) {
1065 PyDict_SetItem(payloads.get(),
1066 PyBytes_FromString(payload.first.c_str()),
1067 PyBytes_FromString(payload.second.c_str()));
1068 }
1069 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1070 "siO", FormatErrorStatusStackTrace(status->status).c_str(),
1071 TF_GetCode(status), payloads.get()));
1072 if (PyErr_Occurred()) {
1073 // NOTE: This hides the actual error (i.e. the reason `status` was not
1074 // TF_OK), but there is nothing we can do at this point since we can't
1075 // generate a reasonable error from the status.
1076 // Consider adding a message explaining this.
1077 return -1;
1078 }
1079 PyErr_SetObject(exception_class, val.get());
1080 return -1;
1081 } else {
1082 exception = PyExc_RuntimeError;
1083 }
1084 }
  // Maybe update an already-set exception.
1086 PyErr_SetString(exception, msg);
1087 return -1;
1088}
1089
1090} // namespace tensorflow
1091
1092int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1093 PyObject* exception) {
1094 if (status.ok()) return 0;
1095 const char* msg = status.error_message().c_str();
1096 if (exception == nullptr) {
1097 tensorflow::mutex_lock l(exception_class_mutex);
1098 if (exception_class != nullptr) {
1099 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1100 for (const auto& element : tensorflow::errors::GetPayloads(status)) {
1101 PyDict_SetItem(payloads.get(),
1102 PyBytes_FromString(element.first.c_str()),
1103 PyBytes_FromString(element.second.c_str()));
1104 }
1105 tensorflow::Safe_PyObjectPtr val(
1106 Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(),
1107 status.code(), payloads.get()));
1108 PyErr_SetObject(exception_class, val.get());
1109 return -1;
1110 } else {
1111 exception = PyExc_RuntimeError;
1112 }
1113 }
  // Maybe update an already-set exception.
1115 PyErr_SetString(exception, msg);
1116 return -1;
1117}
1118
1119const char* TFE_GetPythonString(PyObject* o) {
1120#if PY_MAJOR_VERSION >= 3
1121 if (PyBytes_Check(o)) {
1122 return PyBytes_AsString(o);
1123 } else {
1124 return PyUnicode_AsUTF8(o);
1125 }
1126#else
1127 return PyBytes_AsString(o);
1128#endif
1129}
1130
1131int64_t get_uid() { return _uid++; }
1132
1133PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1134
1135void TFE_DeleteContextCapsule(PyObject* context) {
1136 TFE_Context* ctx =
1137 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1138 auto op = ReleaseThreadLocalOp(ctx);
1139 op.reset();
1140 TFE_DeleteContext(ctx);
1141}
1142
1143static int64_t MakeInt(PyObject* integer) {
1144#if PY_MAJOR_VERSION >= 3
1145 return PyLong_AsLong(integer);
1146#else
1147 return PyInt_AsLong(integer);
1148#endif
1149}
1150
1151static int64_t FastTensorId(PyObject* tensor) {
1152 if (EagerTensor_CheckExact(tensor)) {
1153 return PyEagerTensor_ID(tensor);
1154 }
1155 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1156 if (id_field == nullptr) {
1157 return -1;
1158 }
1159 int64_t id = MakeInt(id_field);
1160 Py_DECREF(id_field);
1161 return id;
1162}
1163
1164namespace tensorflow {
1165DataType PyTensor_DataType(PyObject* tensor) {
1166 if (EagerTensor_CheckExact(tensor)) {
1167 return PyEagerTensor_Dtype(tensor);
1168 } else {
1169#if PY_MAJOR_VERSION < 3
1170 // Python 2.x:
1171 static PyObject* dtype_attr = PyString_InternFromString("dtype");
1172 static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1173#else
1174 // Python 3.x:
1175 static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1176 static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1177#endif
1178 Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1179 if (!dtype_field) {
1180 return DT_INVALID;
1181 }
1182
1183 Safe_PyObjectPtr enum_field(
1184 PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1185 if (!enum_field) {
1186 return DT_INVALID;
1187 }
1188
1189 return static_cast<DataType>(MakeInt(enum_field.get()));
1190 }
1191}
1192} // namespace tensorflow
1193
1194class PyTapeTensor {
1195 public:
1196 PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1197 const tensorflow::TensorShape& shape)
1198 : id_(id), dtype_(dtype), shape_(shape) {}
1199 PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1200 : id_(id), dtype_(dtype), shape_(shape) {
1201 Py_INCREF(absl::get<1>(shape_));
1202 }
1203 PyTapeTensor(const PyTapeTensor& other) {
1204 id_ = other.id_;
1205 dtype_ = other.dtype_;
1206 shape_ = other.shape_;
1207 if (shape_.index() == 1) {
1208 Py_INCREF(absl::get<1>(shape_));
1209 }
1210 }
1211
1212 ~PyTapeTensor() {
1213 if (shape_.index() == 1) {
1214 Py_DECREF(absl::get<1>(shape_));
1215 }
1216 }
1217 PyObject* GetShape() const;
1218 PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
1219 int64_t GetID() const { return id_; }
1220 tensorflow::DataType GetDType() const { return dtype_; }
1221
1222 PyObject* OnesLike() const;
1223 PyObject* ZerosLike() const;
1224
1225 private:
1226 int64_t id_;
1227 tensorflow::DataType dtype_;
1228
1229 // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1230 // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1231 // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1232 // tensors.
1233 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1234};
1235
1236static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1237
1238class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1239 PyTapeTensor> {
1240 public:
1241 explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1242 Py_INCREF(py_vspace_);
1243 }
1244
1245 tensorflow::Status Initialize() {
1246 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1247 if (num_elements_ == nullptr) {
1248 return tensorflow::errors::InvalidArgument("invalid vspace");
1249 }
1250 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1251 if (aggregate_fn_ == nullptr) {
1252 return tensorflow::errors::InvalidArgument("invalid vspace");
1253 }
1254 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1255 if (zeros_fn_ == nullptr) {
1256 return tensorflow::errors::InvalidArgument("invalid vspace");
1257 }
1258 zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1259 if (zeros_like_fn_ == nullptr) {
1260 return tensorflow::errors::InvalidArgument("invalid vspace");
1261 }
1262 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1263 if (ones_fn_ == nullptr) {
1264 return tensorflow::errors::InvalidArgument("invalid vspace");
1265 }
1266 ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1267 if (ones_like_fn_ == nullptr) {
1268 return tensorflow::errors::InvalidArgument("invalid vspace");
1269 }
1270 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1271 if (graph_shape_fn_ == nullptr) {
1272 return tensorflow::errors::InvalidArgument("invalid vspace");
1273 }
1274 return ::tensorflow::OkStatus();
1275 }
1276
1277 ~PyVSpace() override {
1278 Py_XDECREF(num_elements_);
1279 Py_XDECREF(aggregate_fn_);
1280 Py_XDECREF(zeros_fn_);
1281 Py_XDECREF(zeros_like_fn_);
1282 Py_XDECREF(ones_fn_);
1283 Py_XDECREF(ones_like_fn_);
1284 Py_XDECREF(graph_shape_fn_);
1285
1286 Py_DECREF(py_vspace_);
1287 }
1288
1289 int64_t NumElements(PyObject* tensor) const final {
1290 if (EagerTensor_CheckExact(tensor)) {
1291 return PyEagerTensor_NumElements(tensor);
1292 }
1293 PyObject* arglist =
1294 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1295 PyObject* result = PyEval_CallObject(num_elements_, arglist);
1296 Py_DECREF(arglist);
1297 if (result == nullptr) {
1298 // The caller detects whether a python exception has been raised.
1299 return -1;
1300 }
1301 int64_t r = MakeInt(result);
1302 Py_DECREF(result);
1303 return r;
1304 }
1305
1306 PyObject* AggregateGradients(
1307 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1308 PyObject* list = PyList_New(gradient_tensors.size());
1309 for (int i = 0; i < gradient_tensors.size(); ++i) {
1310 // Note: stealing a reference to the gradient tensors.
1311 CHECK(gradient_tensors[i] != nullptr);
1312 CHECK(gradient_tensors[i] != Py_None);
1313 PyList_SET_ITEM(list, i,
1314 reinterpret_cast<PyObject*>(gradient_tensors[i]));
1315 }
1316 PyObject* arglist = Py_BuildValue("(O)", list);
1317 CHECK(arglist != nullptr);
1318 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1319 Py_DECREF(arglist);
1320 Py_DECREF(list);
1321 return result;
1322 }
1323
1324 int64_t TensorId(PyObject* tensor) const final {
1325 return FastTensorId(tensor);
1326 }
1327
1328 void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1329
1330 PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1331 if (PyErr_Occurred()) {
1332 return nullptr;
1333 }
1334 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1335 PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1336 Py_DECREF(arg_list);
1337 return result;
1338 }
1339
1340 PyObject* OnesLike(PyObject* tensor) const {
1341 if (PyErr_Occurred()) {
1342 return nullptr;
1343 }
1344 return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1345 }
1346
1347 // Builds a tensor filled with ones with the same shape and dtype as `t`.
1348 Status BuildOnesLike(const PyTapeTensor& t,
1349 PyObject** result) const override {
1350 *result = t.OnesLike();
1351 return ::tensorflow::OkStatus();
1352 }
1353
1354 PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1355 if (PyErr_Occurred()) {
1356 return nullptr;
1357 }
1358 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1359 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1360 Py_DECREF(arg_list);
1361 return result;
1362 }
1363
1364 PyObject* ZerosLike(PyObject* tensor) const {
1365 if (PyErr_Occurred()) {
1366 return nullptr;
1367 }
1368 return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1369 }
1370
1371 PyObject* GraphShape(PyObject* tensor) const {
1372 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1373 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1374 Py_DECREF(arg_list);
1375 return result;
1376 }
1377
1378 tensorflow::Status CallBackwardFunction(
1379 const string& op_type, PyBackwardFunction* backward_function,
1380 const std::vector<int64_t>& unneeded_gradients,
1381 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1382 absl::Span<PyObject*> result) const final {
1383 PyObject* grads = PyTuple_New(output_gradients.size());
1384 for (int i = 0; i < output_gradients.size(); ++i) {
1385 if (output_gradients[i] == nullptr) {
1386 Py_INCREF(Py_None);
1387 PyTuple_SET_ITEM(grads, i, Py_None);
1388 } else {
1389 PyTuple_SET_ITEM(grads, i,
1390 reinterpret_cast<PyObject*>(output_gradients[i]));
1391 }
1392 }
1393 PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1394 Py_DECREF(grads);
1395 if (py_result == nullptr) {
1396 return tensorflow::errors::Internal("gradient function threw exceptions");
1397 }
1398 PyObject* seq =
1399 PySequence_Fast(py_result, "expected a sequence of gradients");
1400 if (seq == nullptr) {
1401 return tensorflow::errors::InvalidArgument(
1402 "gradient function did not return a list");
1403 }
1404 int len = PySequence_Fast_GET_SIZE(seq);
1405 if (len != result.size()) {
1406 return tensorflow::errors::Internal(
1407 "Recorded operation '", op_type,
1408 "' returned too few gradients. Expected ", result.size(),
1409 " but received ", len);
1410 }
1411 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1412 VLOG(1) << "Gradient length is " << len;
1413 for (int i = 0; i < len; ++i) {
1414 PyObject* item = seq_array[i];
1415 if (item == Py_None) {
1416 result[i] = nullptr;
1417 } else {
1418 Py_INCREF(item);
1419 result[i] = item;
1420 }
1421 }
1422 Py_DECREF(seq);
1423 Py_DECREF(py_result);
1424 return ::tensorflow::OkStatus();
1425 }
1426
1427 void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1428
1429 PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1430 return TapeTensorFromTensor(tensor);
1431 }
1432
1433 private:
1434 PyObject* py_vspace_;
1435
1436 PyObject* num_elements_;
1437 PyObject* aggregate_fn_;
1438 PyObject* zeros_fn_;
1439 PyObject* zeros_like_fn_;
1440 PyObject* ones_fn_;
1441 PyObject* ones_like_fn_;
1442 PyObject* graph_shape_fn_;
1443};
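
// PyVSpace::Initialize above expects `py_vspace_` to expose the following
// callable attributes; the signatures are inferred from the call sites in
// this class, so treat them as an approximation rather than a spec:
//   num_elements_fn(tensor) -> int
//   aggregate_fn(list_of_gradients) -> gradient
//   zeros_fn(shape, dtype) -> tensor
//   zeros_like_fn(tensor) -> tensor
//   ones_fn(shape, dtype) -> tensor
//   ones_like_fn(tensor) -> tensor
//   graph_shape_fn(tensor) -> shape tensor
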
1444PyVSpace* py_vspace = nullptr;
1445
1446bool HasAccumulator();
1447
1448PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1449 if (py_vspace != nullptr) {
1450 if (HasAccumulator()) {
1451 // Accumulators reference py_vspace, so we can't swap it out while one is
1452 // active. This is unlikely to ever happen.
1453 MaybeRaiseExceptionFromStatus(
1454 tensorflow::errors::Internal(
1455 "Can't change the vspace implementation while a "
1456 "forward accumulator is active."),
1457 nullptr);
1458 }
1459 delete py_vspace;
1460 }
1461
1462 py_vspace = new PyVSpace(e);
1463 auto status = py_vspace->Initialize();
1464 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1465 delete py_vspace;
1466 return nullptr;
1467 }
1468
1469 Py_RETURN_NONE;
1470}
1471
1472PyObject* PyTapeTensor::GetShape() const {
1473 if (shape_.index() == 0) {
1474 auto& shape = absl::get<0>(shape_);
1475 PyObject* py_shape = PyTuple_New(shape.dims());
1476 for (int i = 0; i < shape.dims(); ++i) {
1477 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1478 }
1479
1480 return py_shape;
1481 }
1482
1483 return py_vspace->GraphShape(absl::get<1>(shape_));
1484}
1485
1486PyObject* PyTapeTensor::OnesLike() const {
1487 if (shape_.index() == 1) {
1488 PyObject* tensor = absl::get<1>(shape_);
1489 return py_vspace->OnesLike(tensor);
1490 }
1491 PyObject* py_shape = GetShape();
1492 PyObject* dtype_field = GetPyDType();
1493 PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1494 Py_DECREF(dtype_field);
1495 Py_DECREF(py_shape);
1496 return result;
1497}
1498
1499PyObject* PyTapeTensor::ZerosLike() const {
1500 if (GetDType() == tensorflow::DT_RESOURCE) {
1501 // Gradient functions for ops which return resource tensors accept
1502 // None. This is the behavior of py_vspace->Zeros, but checking here avoids
1503 // issues with ZerosLike.
1504 Py_RETURN_NONE;
1505 }
1506 if (shape_.index() == 1) {
1507 PyObject* tensor = absl::get<1>(shape_);
1508 return py_vspace->ZerosLike(tensor);
1509 }
1510 PyObject* py_shape = GetShape();
1511 PyObject* dtype_field = GetPyDType();
1512 PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1513 Py_DECREF(dtype_field);
1514 Py_DECREF(py_shape);
1515 return result;
1516}
1517
1518// Keeps track of all variables that have been accessed during execution.
1519class VariableWatcher {
1520 public:
1521 VariableWatcher() {}
1522
1523 ~VariableWatcher() {
1524 for (const IdAndVariable& v : watched_variables_) {
1525 Py_DECREF(v.variable);
1526 }
1527 }
1528
1529 int64_t WatchVariable(PyObject* v) {
1530 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1531 if (handle == nullptr) {
1532 return -1;
1533 }
1534 int64_t id = FastTensorId(handle.get());
1535
1536 tensorflow::mutex_lock l(watched_variables_mu_);
1537 auto insert_result = watched_variables_.emplace(id, v);
1538
1539 if (insert_result.second) {
1540 // Only increment the reference count if we aren't already watching this
1541 // variable.
1542 Py_INCREF(v);
1543 }
1544
1545 return id;
1546 }
1547
1548 PyObject* GetVariablesAsPyTuple() {
1549 tensorflow::mutex_lock l(watched_variables_mu_);
1550 PyObject* result = PyTuple_New(watched_variables_.size());
1551 Py_ssize_t pos = 0;
1552 for (const IdAndVariable& id_and_variable : watched_variables_) {
1553 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1554 Py_INCREF(id_and_variable.variable);
1555 }
1556 return result;
1557 }
1558
1559 private:
  // We store an IdAndVariable in the set since the set needs to be locked
  // during insert, but we should not call back into Python during insert, to
  // avoid deadlocking with the GIL.
1563 struct IdAndVariable {
1564 int64_t id;
1565 PyObject* variable;
1566
1567 IdAndVariable(int64_t id, PyObject* variable)
1568 : id(id), variable(variable) {}
1569 };
1570 struct CompareById {
1571 bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1572 return lhs.id < rhs.id;
1573 }
1574 };
1575
1576 tensorflow::mutex watched_variables_mu_;
1577 std::set<IdAndVariable, CompareById> watched_variables_
1578 TF_GUARDED_BY(watched_variables_mu_);
1579};
1580
1581class GradientTape
1582 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1583 PyTapeTensor> {
1584 public:
1585 explicit GradientTape(bool persistent, bool watch_accessed_variables)
1586 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1587 PyTapeTensor>(persistent),
1588 watch_accessed_variables_(watch_accessed_variables) {}
1589
1590 virtual ~GradientTape() {}
1591
1592 void VariableAccessed(PyObject* v) {
1593 if (watch_accessed_variables_) {
1594 WatchVariable(v);
1595 }
1596 }
1597
1598 void WatchVariable(PyObject* v) {
1599 int64_t id = variable_watcher_.WatchVariable(v);
1600
1601 if (!PyErr_Occurred()) {
1602 this->Watch(id);
1603 }
1604 }
1605
1606 PyObject* GetVariablesAsPyTuple() {
1607 return variable_watcher_.GetVariablesAsPyTuple();
1608 }
1609
1610 private:
1611 bool watch_accessed_variables_;
1612 VariableWatcher variable_watcher_;
1613};
1614
1615typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1616 PyTapeTensor>
1617 ForwardAccumulator;
1618
1619// Incremented when a GradientTape or accumulator is newly added to a set, and
1620// used to enforce an ordering between them.
1621std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1622
1623typedef struct {
1624 PyObject_HEAD
1625 /* Type-specific fields go here. */
1626 GradientTape* tape;
1627 // A nesting order between GradientTapes and ForwardAccumulators, used to
1628 // ensure that GradientTapes do not watch the products of outer
1629 // ForwardAccumulators.
1630 int64_t nesting_id;
1631} TFE_Py_Tape;
1632
1633static void TFE_Py_Tape_Delete(PyObject* tape) {
1634 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1635 Py_TYPE(tape)->tp_free(tape);
1636}
1637
1638static PyTypeObject TFE_Py_Tape_Type = {
1639 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1640 sizeof(TFE_Py_Tape), /* tp_basicsize */
1641 0, /* tp_itemsize */
1642 &TFE_Py_Tape_Delete, /* tp_dealloc */
1643#if PY_VERSION_HEX < 0x03080000
1644 nullptr, /* tp_print */
1645#else
1646 0, /* tp_vectorcall_offset */
1647#endif
1648 nullptr, /* tp_getattr */
1649 nullptr, /* tp_setattr */
1650 nullptr, /* tp_reserved */
1651 nullptr, /* tp_repr */
1652 nullptr, /* tp_as_number */
1653 nullptr, /* tp_as_sequence */
1654 nullptr, /* tp_as_mapping */
1655 nullptr, /* tp_hash */
1656 nullptr, /* tp_call */
1657 nullptr, /* tp_str */
1658 nullptr, /* tp_getattro */
1659 nullptr, /* tp_setattro */
1660 nullptr, /* tp_as_buffer */
1661 Py_TPFLAGS_DEFAULT, /* tp_flags */
1662 "TFE_Py_Tape objects", /* tp_doc */
1663};
1664
1665typedef struct {
1666 PyObject_HEAD
1667 /* Type-specific fields go here. */
1668 ForwardAccumulator* accumulator;
1669 // A nesting order between GradientTapes and ForwardAccumulators, used to
1670 // ensure that GradientTapes do not watch the products of outer
1671 // ForwardAccumulators.
1672 int64_t nesting_id;
1673} TFE_Py_ForwardAccumulator;
1674
1675static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1676 delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1677 Py_TYPE(accumulator)->tp_free(accumulator);
1678}
1679
1680static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1681 PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
1682 sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
1683 0, /* tp_itemsize */
1684 &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
1685#if PY_VERSION_HEX < 0x03080000
1686 nullptr, /* tp_print */
1687#else
1688 0, /* tp_vectorcall_offset */
1689#endif
1690 nullptr, /* tp_getattr */
1691 nullptr, /* tp_setattr */
1692 nullptr, /* tp_reserved */
1693 nullptr, /* tp_repr */
1694 nullptr, /* tp_as_number */
1695 nullptr, /* tp_as_sequence */
1696 nullptr, /* tp_as_mapping */
1697 nullptr, /* tp_hash */
1698 nullptr, /* tp_call */
1699 nullptr, /* tp_str */
1700 nullptr, /* tp_getattro */
1701 nullptr, /* tp_setattro */
1702 nullptr, /* tp_as_buffer */
1703 Py_TPFLAGS_DEFAULT, /* tp_flags */
1704 "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1705};
1706
1707typedef struct {
1708 PyObject_HEAD
1709 /* Type-specific fields go here. */
1710 VariableWatcher* variable_watcher;
1711} TFE_Py_VariableWatcher;
1712
1713static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1714 delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1715 ->variable_watcher;
1716 Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1717}
1718
1719static PyTypeObject TFE_Py_VariableWatcher_Type = {
1720 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
1721 sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
1722 0, /* tp_itemsize */
1723 &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
1724#if PY_VERSION_HEX < 0x03080000
1725 nullptr, /* tp_print */
1726#else
1727 0, /* tp_vectorcall_offset */
1728#endif
1729 nullptr, /* tp_getattr */
1730 nullptr, /* tp_setattr */
1731 nullptr, /* tp_reserved */
1732 nullptr, /* tp_repr */
1733 nullptr, /* tp_as_number */
1734 nullptr, /* tp_as_sequence */
1735 nullptr, /* tp_as_mapping */
1736 nullptr, /* tp_hash */
1737 nullptr, /* tp_call */
1738 nullptr, /* tp_str */
1739 nullptr, /* tp_getattro */
1740 nullptr, /* tp_setattro */
1741 nullptr, /* tp_as_buffer */
1742 Py_TPFLAGS_DEFAULT, /* tp_flags */
1743 "TFE_Py_VariableWatcher objects", /* tp_doc */
1744};
1745
// Note: in the current design no mutex is needed here because of the Python
// GIL, which is always held when any TFE_Py_* method is called. We should
// revisit this if/when we decide not to hold the GIL while manipulating the
// tape stack.
1750tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1751 thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1752 tape_set;
1753 thread_local ThreadLocalDestructionMarker marker;
1754 if (!marker.alive) {
1755 // This thread is being destroyed. It is unsafe to access tape_set.
1756 return nullptr;
1757 }
1758 if (tape_set == nullptr) {
1759 tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1760 }
1761 return tape_set.get();
1762}
1763
1764tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1765GetVariableWatcherSet() {
1766 thread_local std::unique_ptr<
1767 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1768 variable_watcher_set;
1769 thread_local ThreadLocalDestructionMarker marker;
1770 if (!marker.alive) {
1771 // This thread is being destroyed. It is unsafe to access
1772 // variable_watcher_set.
1773 return nullptr;
1774 }
1775 if (variable_watcher_set == nullptr) {
1776 variable_watcher_set.reset(
1777 new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1778 }
1779 return variable_watcher_set.get();
1780}
1781
1782// A linked hash set, where iteration is in insertion order.
1783//
1784// Nested accumulators rely on op recording happening in insertion order, so an
1785// unordered data structure like CompactPointerSet is not suitable. Outer
1786// accumulators need to observe operations first so they know to watch the inner
1787// accumulator's jvp computation.
1788//
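// For example, if an outer accumulator is pushed before an inner one, forward
// iteration visits the outer accumulator first, while rbegin()/rend() visit
// the innermost accumulator first (which TFE_Py_PackJVPs relies on).
//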
1789// Not thread safe.
1790class AccumulatorSet {
1791 public:
1792 // Returns true if `element` was newly inserted, false if it already exists.
1793 bool insert(TFE_Py_ForwardAccumulator* element) {
1794 if (map_.find(element) != map_.end()) {
1795 return false;
1796 }
1797 ListType::iterator it = ordered_.insert(ordered_.end(), element);
1798 map_.insert(std::make_pair(element, it));
1799 return true;
1800 }
1801
1802 void erase(TFE_Py_ForwardAccumulator* element) {
1803 MapType::iterator existing = map_.find(element);
1804 if (existing == map_.end()) {
1805 return;
1806 }
1807 ListType::iterator list_position = existing->second;
1808 map_.erase(existing);
1809 ordered_.erase(list_position);
1810 }
1811
1812 bool empty() const { return ordered_.empty(); }
1813
1814 size_t size() const { return ordered_.size(); }
1815
1816 private:
1817 typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1818 typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1819 ListType::iterator>
1820 MapType;
1821
1822 public:
1823 typedef ListType::const_iterator const_iterator;
1824 typedef ListType::const_reverse_iterator const_reverse_iterator;
1825
1826 const_iterator begin() const { return ordered_.begin(); }
1827 const_iterator end() const { return ordered_.end(); }
1828
1829 const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1830 const_reverse_iterator rend() const { return ordered_.rend(); }
1831
1832 private:
1833 MapType map_;
1834 ListType ordered_;
1835};
1836
1837AccumulatorSet* GetAccumulatorSet() {
1838 thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
1839 thread_local ThreadLocalDestructionMarker marker;
1840 if (!marker.alive) {
1841 // This thread is being destroyed. It is unsafe to access accumulator_set.
1842 return nullptr;
1843 }
1844 if (accumulator_set == nullptr) {
1845 accumulator_set.reset(new AccumulatorSet);
1846 }
1847 return accumulator_set.get();
1848}
1849
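// These helpers dereference the thread-local sets directly and therefore
// assume the calling thread is not currently being destroyed (in which case
// the getters above return nullptr).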
1850inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1851
1852inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1853
1854inline bool HasAccumulatorOrTape() {
1855 return HasGradientTape() || HasAccumulator();
1856}
1857
1858// A safe copy of a set, used for tapes and accumulators. The copy is not
1859// affected by other python threads changing the set of active tapes.
1860template <typename ContainerType>
1861class SafeSetCopy {
1862 public:
1863 explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
1864 for (auto* member : set_copy_) {
1865 Py_INCREF(member);
1866 }
1867 }
1868
1869 ~SafeSetCopy() {
1870 for (auto* member : set_copy_) {
1871 Py_DECREF(member);
1872 }
1873 }
1874
1875 typename ContainerType::const_iterator begin() const {
1876 return set_copy_.begin();
1877 }
1878
1879 typename ContainerType::const_iterator end() const { return set_copy_.end(); }
1880
1881 bool empty() const { return set_copy_.empty(); }
1882 size_t size() const { return set_copy_.size(); }
1883
1884 protected:
1885 ContainerType set_copy_;
1886};
1887
1888class SafeTapeSet
1889 : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1890 public:
1891 SafeTapeSet()
1892 : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1893 *GetTapeSet()) {}
1894};
1895
1896class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1897 public:
1898 SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1899
1900 typename AccumulatorSet::const_reverse_iterator rbegin() const {
1901 return set_copy_.rbegin();
1902 }
1903
1904 typename AccumulatorSet::const_reverse_iterator rend() const {
1905 return set_copy_.rend();
1906 }
1907};
1908
1909class SafeVariableWatcherSet
1910 : public SafeSetCopy<
1911 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1912 public:
1913 SafeVariableWatcherSet()
1914 : SafeSetCopy<
1915 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1916 *GetVariableWatcherSet()) {}
1917};
1918
1919bool* ThreadTapeIsStopped() {
1920 thread_local bool thread_tape_is_stopped{false};
1921 return &thread_tape_is_stopped;
1922}
1923
1924void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1925
1926void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1927
1928PyObject* TFE_Py_TapeSetIsStopped() {
1929 if (*ThreadTapeIsStopped()) {
1930 Py_RETURN_TRUE;
1931 }
1932 Py_RETURN_FALSE;
1933}
1934
1935PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1936 PyObject* watch_accessed_variables) {
1937 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1938 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1939 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1940 tape->tape = new GradientTape(persistent == Py_True,
1941 watch_accessed_variables == Py_True);
1942 Py_INCREF(tape);
1943 tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1944 GetTapeSet()->insert(tape);
1945 return reinterpret_cast<PyObject*>(tape);
1946}
1947
1948void TFE_Py_TapeSetAdd(PyObject* tape) {
1949 Py_INCREF(tape);
1950 TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1951 if (!GetTapeSet()->insert(tfe_tape).second) {
1952 // Already exists in the tape set.
1953 Py_DECREF(tape);
1954 } else {
1955 tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1956 }
1957}
1958
1959PyObject* TFE_Py_TapeSetIsEmpty() {
1960 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1961 Py_RETURN_TRUE;
1962 }
1963 Py_RETURN_FALSE;
1964}
1965
1966void TFE_Py_TapeSetRemove(PyObject* tape) {
1967 auto* stack = GetTapeSet();
1968 if (stack != nullptr) {
1969 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1970 }
1971 // We kept a reference to the tape in the set to ensure it wouldn't get
1972 // deleted under us; cleaning it up here.
1973 Py_DECREF(tape);
1974}
1975
1976static std::vector<int64_t> MakeIntList(PyObject* list) {
1977 if (list == Py_None) {
1978 return {};
1979 }
1980 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1981 if (seq == nullptr) {
1982 return {};
1983 }
1984 int len = PySequence_Size(list);
1985 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1986 std::vector<int64_t> tensor_ids;
1987 tensor_ids.reserve(len);
1988 for (int i = 0; i < len; ++i) {
1989 PyObject* item = seq_array[i];
1990#if PY_MAJOR_VERSION >= 3
1991 if (PyLong_Check(item)) {
1992#else
1993 if (PyLong_Check(item) || PyInt_Check(item)) {
1994#endif
1995 int64_t id = MakeInt(item);
1996 tensor_ids.push_back(id);
1997 } else {
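      // Entries that are not Python ints (e.g. None dimensions from a
      // partially known shape) are recorded as -1.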
1998 tensor_ids.push_back(-1);
1999 }
2000 }
2001 Py_DECREF(seq);
2002 return tensor_ids;
2003}
2004
2005// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
2006// null. Returns true on success and false on a Python exception.
2007bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
2008 std::vector<tensorflow::DataType>* dtypes) {
2009 tensorflow::Safe_PyObjectPtr seq(
2010 PySequence_Fast(tensors, "expected a sequence"));
2011 if (seq == nullptr) {
2012 return false;
2013 }
2014 int len = PySequence_Fast_GET_SIZE(seq.get());
2015 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2016 tensor_ids->reserve(len);
2017 dtypes->reserve(len);
2018 for (int i = 0; i < len; ++i) {
2019 PyObject* item = seq_array[i];
2020 tensor_ids->push_back(FastTensorId(item));
2021 dtypes->push_back(tensorflow::PyTensor_DataType(item));
2022 }
2023 return true;
2024}
2025
2026bool TapeCouldPossiblyRecord(PyObject* tensors) {
2027 if (tensors == Py_None) {
2028 return false;
2029 }
2030 if (*ThreadTapeIsStopped()) {
2031 return false;
2032 }
2033 if (!HasAccumulatorOrTape()) {
2034 return false;
2035 }
2036 return true;
2037}
2038
2039bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
2040
2041bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
2042
2043PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
2044 if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
2045 Py_RETURN_FALSE;
2046 }
2047 // TODO(apassos) consider not building a list and changing the API to check
2048 // each tensor individually.
2049 std::vector<int64_t> tensor_ids;
2050 std::vector<tensorflow::DataType> dtypes;
2051 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2052 return nullptr;
2053 }
2054 auto& tape_set = *GetTapeSet();
2055 for (TFE_Py_Tape* tape : tape_set) {
2056 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2057 Py_RETURN_TRUE;
2058 }
2059 }
2060
2061 Py_RETURN_FALSE;
2062}
2063
2064PyObject* TFE_Py_ForwardAccumulatorPushState() {
2065 auto& forward_accumulators = *GetAccumulatorSet();
2066 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2067 accumulator->accumulator->PushState();
2068 }
2069 Py_RETURN_NONE;
2070}
2071
2072PyObject* TFE_Py_ForwardAccumulatorPopState() {
2073 auto& forward_accumulators = *GetAccumulatorSet();
2074 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2075 accumulator->accumulator->PopState();
2076 }
2077 Py_RETURN_NONE;
2078}
2079
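// Returns 0 if nothing could record for `tensors`, 1 if exactly one
// non-persistent watcher would record (only first-order gradients are
// available), and 2 if higher-order gradients are possible (a persistent tape
// or multiple watchers).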
2080PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2081 if (!TapeCouldPossiblyRecord(tensors)) {
2082 return GetPythonObjectFromInt(0);
2083 }
2084 std::vector<int64_t> tensor_ids;
2085 std::vector<tensorflow::DataType> dtypes;
2086 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2087 return nullptr;
2088 }
2089
2090 // If there is a persistent tape watching, or if there are multiple tapes
2091 // watching, we'll return immediately indicating that higher-order tape
2092 // gradients are possible.
2093 bool some_tape_watching = false;
2094 if (CouldBackprop()) {
2095 auto& tape_set = *GetTapeSet();
2096 for (TFE_Py_Tape* tape : tape_set) {
2097 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2098 if (tape->tape->IsPersistent() || some_tape_watching) {
2099 // Either this is the second tape watching, or this tape is
2100 // persistent: higher-order gradients are possible.
2101 return GetPythonObjectFromInt(2);
2102 }
2103 some_tape_watching = true;
2104 }
2105 }
2106 }
2107 if (CouldForwardprop()) {
2108 auto& forward_accumulators = *GetAccumulatorSet();
2109 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2110 if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2111 if (some_tape_watching) {
2112 // This is the second tape watching: higher-order gradients are
2113 // possible. Note that there's no equivalent of persistence for
2114 // forward-mode.
2115 return GetPythonObjectFromInt(2);
2116 }
2117 some_tape_watching = true;
2118 }
2119 }
2120 }
2121 if (some_tape_watching) {
    // Exactly one non-persistent tape or accumulator is watching. The user can
    // request first-order gradients but won't be able to get higher-order tape
    // gradients.
2124 return GetPythonObjectFromInt(1);
2125 } else {
2126 // There are no tapes. The user can't request tape gradients.
2127 return GetPythonObjectFromInt(0);
2128 }
2129}
2130
2131void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2132 if (!CouldBackprop()) {
2133 return;
2134 }
2135 int64_t tensor_id = FastTensorId(tensor);
2136 if (PyErr_Occurred()) {
2137 return;
2138 }
2139 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2140}
2141
2142bool ListContainsNone(PyObject* list) {
2143 if (list == Py_None) return true;
2144 tensorflow::Safe_PyObjectPtr seq(
2145 PySequence_Fast(list, "expected a sequence"));
2146 if (seq == nullptr) {
2147 return false;
2148 }
2149
2150 int len = PySequence_Size(list);
2151 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2152 for (int i = 0; i < len; ++i) {
2153 PyObject* item = seq_array[i];
2154 if (item == Py_None) return true;
2155 }
2156
2157 return false;
2158}
2159
// As an optimization, the tape generally keeps only the shape and dtype of
// tensors, and uses this information to generate ones/zeros tensors. However,
// variant and resource tensors require calling OnesLike/ZerosLike on the
// tensor itself because their gradients do not match their inference
// shape/dtype, so for those the tape keeps a reference to the tensor.
2164bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
2165 return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
2166}
2167
2168static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2169 if (EagerTensor_CheckExact(tensor)) {
2170 tensorflow::ImmediateExecutionTensorHandle* handle =
2171 tensorflow::unwrap(EagerTensor_Handle(tensor));
2172 int64_t id = PyEagerTensor_ID(tensor);
2173 tensorflow::DataType dtype =
2174 static_cast<tensorflow::DataType>(handle->DataType());
2175 if (DTypeNeedsHandleData(dtype)) {
2176 return PyTapeTensor(id, dtype, tensor);
2177 }
2178
2179 tensorflow::TensorShape tensor_shape;
2180 int num_dims;
2181 tensorflow::Status status = handle->NumDims(&num_dims);
2182 if (status.ok()) {
2183 for (int i = 0; i < num_dims; ++i) {
2184 int64_t dim_size;
2185 status = handle->Dim(i, &dim_size);
2186 if (!status.ok()) break;
2187 tensor_shape.AddDim(dim_size);
2188 }
2189 }
2190
2191 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2192 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2193 tensorflow::TensorShape({}));
2194 } else {
2195 return PyTapeTensor(id, dtype, tensor_shape);
2196 }
2197 }
2198 int64_t id = FastTensorId(tensor);
2199 if (PyErr_Occurred()) {
2200 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2201 tensorflow::TensorShape({}));
2202 }
2203 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2204 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2205 Py_DECREF(dtype_object);
2206 tensorflow::DataType dtype =
2207 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2208 Py_DECREF(dtype_enum);
2209 if (PyErr_Occurred()) {
2210 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2211 tensorflow::TensorShape({}));
2212 }
2213 static char _shape_tuple[] = "_shape_tuple";
2214 tensorflow::Safe_PyObjectPtr shape_tuple(
2215 PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2216 if (PyErr_Occurred()) {
2217 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2218 tensorflow::TensorShape({}));
2219 }
2220
2221 if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
2222 return PyTapeTensor(id, dtype, tensor);
2223 }
2224
2225 auto l = MakeIntList(shape_tuple.get());
  // Replace -1, which represents accidental Nones that can occur in graph mode
  // and can cause errors in shape construction, with 0.
2228 for (auto& c : l) {
2229 if (c < 0) {
2230 c = 0;
2231 }
2232 }
2233 tensorflow::TensorShape shape(l);
2234 return PyTapeTensor(id, dtype, shape);
2235}
2236
2237// Populates output_info from output_seq, which must come from PySequence_Fast.
2238//
2239// Does not take ownership of output_seq. Returns true on success and false if a
2240// Python exception has been set.
2241bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2242 std::vector<PyTapeTensor>* output_info) {
2243 Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2244 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2245 output_info->reserve(output_len);
2246 for (Py_ssize_t i = 0; i < output_len; ++i) {
2247 output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2248 if (PyErr_Occurred() != nullptr) {
2249 return false;
2250 }
2251 }
2252 return true;
2253}
2254
2255std::vector<int64_t> MakeTensorIDList(PyObject* tensors) {
2256 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2257 if (seq == nullptr) {
2258 return {};
2259 }
2260 int len = PySequence_Fast_GET_SIZE(seq);
2261 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2262 std::vector<int64_t> list;
2263 list.reserve(len);
2264 for (int i = 0; i < len; ++i) {
2265 PyObject* tensor = seq_array[i];
2266 list.push_back(FastTensorId(tensor));
2267 if (PyErr_Occurred()) {
2268 Py_DECREF(seq);
2269 return list;
2270 }
2271 }
2272 Py_DECREF(seq);
2273 return list;
2274}
2275
2276void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2277 if (!CouldBackprop()) {
2278 return;
2279 }
2280 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2281 tape->tape->VariableAccessed(variable);
2282 }
2283}
2284
2285void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2286 if (!CouldBackprop()) {
2287 return;
2288 }
2289 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2290}
2291
2292PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2293 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2294}
2295
2296PyObject* TFE_Py_VariableWatcherNew() {
2297 TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2298 if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2299 TFE_Py_VariableWatcher* variable_watcher =
2300 PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2301 variable_watcher->variable_watcher = new VariableWatcher();
2302 Py_INCREF(variable_watcher);
2303 GetVariableWatcherSet()->insert(variable_watcher);
2304 return reinterpret_cast<PyObject*>(variable_watcher);
2305}
2306
2307void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2308 auto* stack = GetVariableWatcherSet();
2309 stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2310 // We kept a reference to the variable watcher in the set to ensure it
2311 // wouldn't get deleted under us; cleaning it up here.
2312 Py_DECREF(variable_watcher);
2313}
2314
2315void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2316 for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2317 variable_watcher->variable_watcher->WatchVariable(variable);
2318 }
2319}
2320
2321PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2322 return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2323 ->variable_watcher->GetVariablesAsPyTuple();
2324}
2325
2326namespace {
2327std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2328 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2329 if (seq == nullptr) {
2330 return {};
2331 }
2332 int len = PySequence_Fast_GET_SIZE(seq);
2333 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2334 std::vector<tensorflow::DataType> list;
2335 list.reserve(len);
2336 for (int i = 0; i < len; ++i) {
2337 PyObject* tensor = seq_array[i];
2338 list.push_back(tensorflow::PyTensor_DataType(tensor));
2339 }
2340 Py_DECREF(seq);
2341 return list;
2342}
2343
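// Weakref callback invoked when a watched tensor is garbage collected: drops
// any JVPs the active accumulators hold for that tensor's id. Releases the
// weakref object and the tensor id object it owns.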
2344PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2345 PyObject* weak_tensor_ref) {
2346 auto* accumulator_set = GetAccumulatorSet();
2347 if (accumulator_set != nullptr) {
2348 int64_t parsed_tensor_id = MakeInt(tensor_id);
2349 for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) {
2350 accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2351 }
2352 }
2353 Py_DECREF(weak_tensor_ref);
2354 Py_DECREF(tensor_id);
2355 Py_INCREF(Py_None);
2356 return Py_None;
2357}
2358
2359static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2360 "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2361 METH_O, "ForwardAccumulatorDeleteGradient"};
2362
2363void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
2364 tensorflow::Safe_PyObjectPtr callback(
2365 PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2366 PyLong_FromLong(tensor_id)));
2367 // We need to keep a reference to the weakref active if we want our callback
2368 // called. The callback itself now owns the weakref object and the tensor ID
2369 // object.
2370 PyWeakref_NewRef(tensor, callback.get());
2371}
2372
2373void TapeSetRecordBackprop(
2374 const string& op_type, const std::vector<PyTapeTensor>& output_info,
2375 const std::vector<int64_t>& input_ids,
2376 const std::vector<tensorflow::DataType>& input_dtypes,
2377 const std::function<PyBackwardFunction*()>& backward_function_getter,
2378 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2379 tensorflow::uint64 max_gradient_tape_id) {
2380 if (!CouldBackprop()) {
2381 return;
2382 }
2383 for (TFE_Py_Tape* tape : SafeTapeSet()) {
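    // Skip tapes created after the accumulator that produced these outputs
    // (nesting_id >= max_gradient_tape_id) so that they do not record the
    // accumulator's jvp computation.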
2384 if (tape->nesting_id < max_gradient_tape_id) {
2385 tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2386 backward_function_getter,
2387 backward_function_killer);
2388 }
2389 }
2390}
2391
2392bool TapeSetRecordForwardprop(
2393 const string& op_type, PyObject* output_seq,
2394 const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2395 const std::vector<int64_t>& input_ids,
2396 const std::vector<tensorflow::DataType>& input_dtypes,
2397 const std::function<PyBackwardFunction*()>& backward_function_getter,
2398 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2399 const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2400 PyObject* forwardprop_output_indices,
2401 tensorflow::uint64* max_gradient_tape_id) {
2402 *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2403 if (!CouldForwardprop()) {
2404 return true;
2405 }
2406 auto accumulator_set = SafeAccumulatorSet();
2407 tensorflow::Safe_PyObjectPtr input_seq(
2408 PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2409 if (input_seq == nullptr || PyErr_Occurred()) return false;
2410 Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2411 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2412 for (int i = 0; i < output_info.size(); ++i) {
2413 RegisterForwardAccumulatorCleanup(output_seq_array[i],
2414 output_info[i].GetID());
2415 }
2416 if (forwardprop_output_indices != nullptr &&
2417 forwardprop_output_indices != Py_None) {
2418 tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2419 forwardprop_output_indices, "Expected a sequence of indices"));
2420 if (indices_fast == nullptr || PyErr_Occurred()) {
2421 return false;
2422 }
2423 if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2424 accumulator_set.size()) {
2425 MaybeRaiseExceptionFromStatus(
2426 tensorflow::errors::Internal(
2427 "Accumulators were added or removed from the active set "
2428 "between packing and unpacking."),
2429 nullptr);
2430 }
2431 PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2432 Py_ssize_t accumulator_index = 0;
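    // `forwardprop_output_indices` has one entry per active accumulator,
    // ordered innermost-first to match the reverse iteration below (this is
    // the ordering produced by TFE_Py_PackJVPs).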
2433 for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2434 it != accumulator_set.rend(); ++it, ++accumulator_index) {
2435 tensorflow::Safe_PyObjectPtr jvp_index_seq(
2436 PySequence_Fast(indices_fast_array[accumulator_index],
2437 "Expected a sequence of jvp indices."));
2438 if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2439 return false;
2440 }
2441 Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2442 PyObject** jvp_index_seq_array =
2443 PySequence_Fast_ITEMS(jvp_index_seq.get());
2444 for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2445 PyObject* tuple = jvp_index_seq_array[jvp_index];
2446 int64_t primal_tensor_id =
2447 output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2448 (*it)->accumulator->Watch(
2449 primal_tensor_id,
2450 output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2451 }
2452 }
2453 } else {
2454 std::vector<PyTapeTensor> input_info;
2455 input_info.reserve(input_len);
2456 PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2457 for (Py_ssize_t i = 0; i < input_len; ++i) {
2458 input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2459 }
2460 for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2461 tensorflow::Status status = accumulator->accumulator->Accumulate(
2462 op_type, input_info, output_info, input_ids, input_dtypes,
2463 forward_function, backward_function_getter, backward_function_killer);
2464 if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
2465 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2466 return false;
2467 }
2468 if (accumulator->accumulator->BusyAccumulating()) {
2469 // Ensure inner accumulators don't see outer accumulators' jvps. This
2470 // mostly happens on its own, with some potentially surprising
2471 // exceptions, so the blanket policy is for consistency.
2472 *max_gradient_tape_id = accumulator->nesting_id;
2473 break;
2474 }
2475 }
2476 }
2477 return true;
2478}
2479
2480PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2481 PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2482 for (int i = 0; i < input_tangents.size(); ++i) {
2483 PyObject* element;
2484 if (input_tangents[i] == nullptr) {
2485 element = Py_None;
2486 } else {
2487 element = input_tangents[i];
2488 }
2489 Py_INCREF(element);
2490 PyTuple_SET_ITEM(py_input_tangents, i, element);
2491 }
2492 return py_input_tangents;
2493}
2494
2495tensorflow::Status ParseTangentOutputs(
2496 PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2497 if (user_output == Py_None) {
2498 // No connected gradients.
2499 return ::tensorflow::OkStatus();
2500 }
2501 tensorflow::Safe_PyObjectPtr fast_result(
2502 PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2503 if (fast_result == nullptr) {
2504 return tensorflow::errors::InvalidArgument(
2505 "forward gradient function did not return a sequence.");
2506 }
2507 int len = PySequence_Fast_GET_SIZE(fast_result.get());
2508 PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2509 output_tangents->reserve(len);
2510 for (int i = 0; i < len; ++i) {
2511 PyObject* item = fast_result_array[i];
2512 if (item == Py_None) {
2513 output_tangents->push_back(nullptr);
2514 } else {
2515 Py_INCREF(item);
2516 output_tangents->push_back(item);
2517 }
2518 }
2519 return ::tensorflow::OkStatus();
2520}
2521
2522// Calls the registered forward_gradient_function, computing `output_tangents`
2523// from `input_tangents`. `output_tangents` must not be null.
2524//
2525// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2526// the forward function is being called.
2527tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2528 PyObject* inputs, PyObject* results,
2529 const std::vector<PyObject*>& input_tangents,
2530 std::vector<PyObject*>* output_tangents,
2531 bool use_batch) {
2532 if (forward_gradient_function == nullptr) {
2533 return tensorflow::errors::Internal(
2534 "No forward gradient function registered.");
2535 }
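  // The registered callback is invoked as forward_gradient_function(op_name,
  // attrs, inputs, results, input_tangents, use_batch) and is expected to
  // return a sequence of output tangents (or None if none are connected).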
2536 tensorflow::Safe_PyObjectPtr py_input_tangents(
2537 TangentsAsPyTuple(input_tangents));
2538
2539 // Normalize the input sequence to a tuple so it works with function
2540 // caching; otherwise it may be an opaque _InputList object.
2541 tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2542 PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2543 tensorflow::Safe_PyObjectPtr callback_args(
2544 Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2545 py_input_tangents.get(), to_batch));
2546 tensorflow::Safe_PyObjectPtr py_result(
2547 PyObject_CallObject(forward_gradient_function, callback_args.get()));
2548 if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw an exception");
2551 }
2552 return ParseTangentOutputs(py_result.get(), output_tangents);
2553}
2554
2555// Like CallJVPFunction, but calls a pre-bound forward function.
2556// These are passed in from a record_gradient argument.
2557tensorflow::Status CallOpSpecificJVPFunction(
2558 PyObject* op_specific_forward_function,
2559 const std::vector<PyObject*>& input_tangents,
2560 std::vector<PyObject*>* output_tangents) {
2561 tensorflow::Safe_PyObjectPtr py_input_tangents(
2562 TangentsAsPyTuple(input_tangents));
2563
2564 tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2565 op_specific_forward_function, py_input_tangents.get()));
2566 if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw an exception");
2569 }
2570 return ParseTangentOutputs(py_result.get(), output_tangents);
2571}
2572
2573bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2574 if (PyBytes_Check(op_type)) {
2575 *op_type_string = PyBytes_AsString(op_type);
2576 } else if (PyUnicode_Check(op_type)) {
2577#if PY_MAJOR_VERSION >= 3
2578 *op_type_string = PyUnicode_AsUTF8(op_type);
2579#else
2580 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2581 if (py_str == nullptr) {
2582 return false;
2583 }
2584 *op_type_string = PyBytes_AS_STRING(py_str);
2585 Py_DECREF(py_str);
2586#endif
2587 } else {
2588 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2589 return false;
2590 }
2591 return true;
2592}
2593
2594bool TapeSetRecordOperation(
2595 PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2596 const std::vector<int64_t>& input_ids,
2597 const std::vector<tensorflow::DataType>& input_dtypes,
2598 const std::function<PyBackwardFunction*()>& backward_function_getter,
2599 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2600 const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2601 std::vector<PyTapeTensor> output_info;
2602 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2603 output_tensors, "expected a sequence of integer tensor ids"));
2604 if (PyErr_Occurred() ||
2605 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2606 return false;
2607 }
2608 string op_type_str;
2609 if (!ParseOpTypeString(op_type, &op_type_str)) {
2610 return false;
2611 }
2612 tensorflow::uint64 max_gradient_tape_id;
2613 if (!TapeSetRecordForwardprop(
2614 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2615 input_dtypes, backward_function_getter, backward_function_killer,
2616 forward_function, nullptr /* No special-cased jvps. */,
2617 &max_gradient_tape_id)) {
2618 return false;
2619 }
2620 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2621 backward_function_getter, backward_function_killer,
2622 max_gradient_tape_id);
2623 return true;
2624}
2625} // namespace
2626
2627PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2628 PyObject* output_tensors,
2629 PyObject* input_tensors,
2630 PyObject* backward_function,
2631 PyObject* forward_function) {
2632 if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2633 Py_RETURN_NONE;
2634 }
2635 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2636 if (PyErr_Occurred()) return nullptr;
2637
2638 std::vector<tensorflow::DataType> input_dtypes =
2639 MakeTensorDtypeList(input_tensors);
2640 if (PyErr_Occurred()) return nullptr;
2641
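  // Each call to the getter takes a new reference to `backward_function`; the
  // killer releases that reference along with the wrapper, keeping the Python
  // callable alive for as long as any tape holds the recorded operation.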
2642 std::function<PyBackwardFunction*()> backward_function_getter(
2643 [backward_function]() {
2644 Py_INCREF(backward_function);
2645 PyBackwardFunction* function = new PyBackwardFunction(
2646 [backward_function](PyObject* out_grads,
2647 const std::vector<int64_t>& unused) {
2648 return PyObject_CallObject(backward_function, out_grads);
2649 });
2650 return function;
2651 });
2652 std::function<void(PyBackwardFunction*)> backward_function_killer(
2653 [backward_function](PyBackwardFunction* py_backward_function) {
2654 Py_DECREF(backward_function);
2655 delete py_backward_function;
2656 });
2657
2658 if (forward_function == Py_None) {
2659 if (!TapeSetRecordOperation(
2660 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2661 backward_function_getter, backward_function_killer,
2662 nullptr /* No special-cased forward function */)) {
2663 return nullptr;
2664 }
2665 } else {
2666 tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2667 [forward_function](const std::vector<PyObject*>& input_tangents,
2668 std::vector<PyObject*>* output_tangents,
2669 bool use_batch = false) {
2670 return CallOpSpecificJVPFunction(forward_function, input_tangents,
2671 output_tangents);
2672 });
2673 if (!TapeSetRecordOperation(
2674 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2675 backward_function_getter, backward_function_killer,
2676 &wrapped_forward_function)) {
2677 return nullptr;
2678 }
2679 }
2680 Py_RETURN_NONE;
2681}
2682
2683PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2684 PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2685 PyObject* backward_function, PyObject* forwardprop_output_indices) {
2686 if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2687 Py_RETURN_NONE;
2688 }
2689 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2690 if (PyErr_Occurred()) return nullptr;
2691
2692 std::vector<tensorflow::DataType> input_dtypes =
2693 MakeTensorDtypeList(input_tensors);
2694 if (PyErr_Occurred()) return nullptr;
2695
2696 std::function<PyBackwardFunction*()> backward_function_getter(
2697 [backward_function]() {
2698 Py_INCREF(backward_function);
2699 PyBackwardFunction* function = new PyBackwardFunction(
2700 [backward_function](PyObject* out_grads,
2701 const std::vector<int64_t>& unused) {
2702 return PyObject_CallObject(backward_function, out_grads);
2703 });
2704 return function;
2705 });
2706 std::function<void(PyBackwardFunction*)> backward_function_killer(
2707 [backward_function](PyBackwardFunction* py_backward_function) {
2708 Py_DECREF(backward_function);
2709 delete py_backward_function;
2710 });
2711 std::vector<PyTapeTensor> output_info;
2712 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2713 output_tensors, "expected a sequence of integer tensor ids"));
2714 if (PyErr_Occurred() ||
2715 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2716 return nullptr;
2717 }
2718 string op_type_str;
2719 if (!ParseOpTypeString(op_type, &op_type_str)) {
2720 return nullptr;
2721 }
2722 tensorflow::uint64 max_gradient_tape_id;
2723 if (!TapeSetRecordForwardprop(
2724 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2725 input_dtypes, backward_function_getter, backward_function_killer,
2726 nullptr /* no special-cased forward function */,
2727 forwardprop_output_indices, &max_gradient_tape_id)) {
2728 return nullptr;
2729 }
2730 Py_RETURN_NONE;
2731}
2732
2733PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2734 PyObject* output_tensors,
2735 PyObject* input_tensors,
2736 PyObject* backward_function) {
2737 if (!CouldBackprop()) {
2738 Py_RETURN_NONE;
2739 }
2740 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2741 if (PyErr_Occurred()) return nullptr;
2742
2743 std::vector<tensorflow::DataType> input_dtypes =
2744 MakeTensorDtypeList(input_tensors);
2745 if (PyErr_Occurred()) return nullptr;
2746
2747 std::function<PyBackwardFunction*()> backward_function_getter(
2748 [backward_function]() {
2749 Py_INCREF(backward_function);
2750 PyBackwardFunction* function = new PyBackwardFunction(
2751 [backward_function](PyObject* out_grads,
2752 const std::vector<int64_t>& unused) {
2753 return PyObject_CallObject(backward_function, out_grads);
2754 });
2755 return function;
2756 });
2757 std::function<void(PyBackwardFunction*)> backward_function_killer(
2758 [backward_function](PyBackwardFunction* py_backward_function) {
2759 Py_DECREF(backward_function);
2760 delete py_backward_function;
2761 });
2762 std::vector<PyTapeTensor> output_info;
2763 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2764 output_tensors, "expected a sequence of integer tensor ids"));
2765 if (PyErr_Occurred() ||
2766 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2767 return nullptr;
2768 }
2769 string op_type_str;
2770 if (!ParseOpTypeString(op_type, &op_type_str)) {
2771 return nullptr;
2772 }
2773 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2774 backward_function_getter, backward_function_killer,
2775 // No filtering based on relative ordering with forward
2776 // accumulators.
2777 std::numeric_limits<tensorflow::uint64>::max());
2778 Py_RETURN_NONE;
2779}
2780
2781void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2782 auto* tape_set = GetTapeSet();
2783 if (tape_set == nullptr) {
2784 // Current thread is being destructed, and the tape set has already
2785 // been cleared.
2786 return;
2787 }
2788 for (TFE_Py_Tape* tape : *tape_set) {
2789 tape->tape->DeleteTrace(tensor_id);
2790 }
2791}
2792
2793std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2794 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2795 if (seq == nullptr) {
2796 return {};
2797 }
2798 int len = PySequence_Fast_GET_SIZE(seq);
2799 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2800 std::vector<PyObject*> list(seq_array, seq_array + len);
2801 Py_DECREF(seq);
2802 return list;
2803}
2804
2805PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2806 PyObject* sources, PyObject* output_gradients,
2807 PyObject* sources_raw,
2808 PyObject* unconnected_gradients,
2809 TF_Status* status) {
2810 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2811 if (!tape_obj->tape->IsPersistent()) {
2812 auto* tape_set = GetTapeSet();
2813 if (tape_set->find(tape_obj) != tape_set->end()) {
2814 PyErr_SetString(PyExc_RuntimeError,
2815 "gradient() cannot be invoked within the "
2816 "GradientTape context (i.e., while operations are being "
2817 "recorded). Either move the call to gradient() to be "
2818 "outside the 'with tf.GradientTape' block, or "
2819 "use a persistent tape: "
                    "'with tf.GradientTape(persistent=True)'");
2821 return nullptr;
2822 }
2823 }
2824
2825 std::vector<int64_t> target_vec = MakeTensorIDList(target);
2826 if (PyErr_Occurred()) {
2827 return nullptr;
2828 }
2829 std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
2830 if (PyErr_Occurred()) {
2831 return nullptr;
2832 }
2833 tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
2834 sources_vec.end());
2835
2836 tensorflow::Safe_PyObjectPtr seq =
2837 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2838 int len = PySequence_Fast_GET_SIZE(seq.get());
2839 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2840 std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
2841 for (int i = 0; i < len; ++i) {
2842 int64_t target_id = target_vec[i];
2843 if (sources_set.find(target_id) != sources_set.end()) {
2844 auto tensor = seq_array[i];
2845 source_tensors_that_are_targets.insert(
2846 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2847 }
2848 if (PyErr_Occurred()) {
2849 return nullptr;
2850 }
2851 }
2852 if (PyErr_Occurred()) {
2853 return nullptr;
2854 }
2855
2856 std::vector<PyObject*> outgrad_vec;
2857 if (output_gradients != Py_None) {
2858 outgrad_vec = MakeTensorList(output_gradients);
2859 if (PyErr_Occurred()) {
2860 return nullptr;
2861 }
2862 for (PyObject* tensor : outgrad_vec) {
2863 // Calling the backward function will eat a reference to the tensors in
2864 // outgrad_vec, so we need to increase their reference count.
2865 Py_INCREF(tensor);
2866 }
2867 }
2868 std::vector<PyObject*> result(sources_vec.size());
2869 status->status = tape_obj->tape->ComputeGradient(
2870 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2871 outgrad_vec, absl::MakeSpan(result));
2872 if (!status->status.ok()) {
2873 if (PyErr_Occurred()) {
2874 // Do not propagate the erroneous status as that would swallow the
2875 // exception which caused the problem.
2876 status->status = ::tensorflow::OkStatus();
2877 }
2878 return nullptr;
2879 }
2880
2881 bool unconnected_gradients_zero =
2882 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2883 std::vector<PyObject*> sources_obj;
2884 if (unconnected_gradients_zero) {
2885 // Uses the "raw" sources here so it can properly make a zeros tensor even
2886 // if there are resource variables as sources.
2887 sources_obj = MakeTensorList(sources_raw);
2888 }
2889
2890 if (!result.empty()) {
2891 PyObject* py_result = PyList_New(result.size());
2892 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2893 for (int i = 0; i < result.size(); ++i) {
2894 if (result[i] == nullptr) {
2895 if (unconnected_gradients_zero) {
2896 // generate a zeros tensor in the shape of sources[i]
2897 tensorflow::DataType dtype =
2898 tensorflow::PyTensor_DataType(sources_obj[i]);
2899 PyTapeTensor tensor =
2900 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2901 result[i] = tensor.ZerosLike();
2902 } else {
2903 Py_INCREF(Py_None);
2904 result[i] = Py_None;
2905 }
2906 } else if (seen_results.find(result[i]) != seen_results.end()) {
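        // The same gradient object can be returned for multiple sources;
        // PyList_SET_ITEM below steals a reference, so duplicates need an
        // extra reference here.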
2907 Py_INCREF(result[i]);
2908 }
2909 seen_results.insert(result[i]);
2910 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2911 }
2912 return py_result;
2913 }
2914 return PyList_New(0);
2915}
2916
2917PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2918 TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2919 if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2920 TFE_Py_ForwardAccumulator* accumulator =
2921 PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2922 if (py_vspace == nullptr) {
2923 MaybeRaiseExceptionFromStatus(
2924 tensorflow::errors::Internal(
2925 "ForwardAccumulator requires a PyVSpace to be registered."),
2926 nullptr);
2927 }
2928 accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2929 return reinterpret_cast<PyObject*>(accumulator);
2930}
2931
2932PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2933 TFE_Py_ForwardAccumulator* c_accumulator(
2934 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2935 c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2936 if (GetAccumulatorSet()->insert(c_accumulator)) {
2937 Py_INCREF(accumulator);
2938 Py_RETURN_NONE;
2939 } else {
2940 MaybeRaiseExceptionFromStatus(
2941 tensorflow::errors::Internal(
2942 "A ForwardAccumulator was added to the active set twice."),
2943 nullptr);
2944 return nullptr;
2945 }
2946}
2947
2948void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2949 auto* accumulator_set = GetAccumulatorSet();
2950 if (accumulator_set != nullptr) {
2951 accumulator_set->erase(
2952 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2953 }
2954 Py_DECREF(accumulator);
2955}
2956
2957void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2958 PyObject* tangent) {
2959 int64_t tensor_id = FastTensorId(tensor);
2960 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2961 ->accumulator->Watch(tensor_id, tangent);
2962 RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2963}
2964
2965// Returns a new reference to the JVP Tensor.
2966PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2967 PyObject* tensor) {
2968 PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2969 ->accumulator->FetchJVP(FastTensorId(tensor));
2970 if (jvp == nullptr) {
2971 jvp = Py_None;
2972 }
2973 Py_INCREF(jvp);
2974 return jvp;
2975}
2976
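// Packs the JVPs ("tangents") of `tensors` held by the active forward
// accumulators. Returns a 2-tuple: the second element is a list of jvp Tensors
// to append to an op's inputs, and the first has one entry per active
// accumulator (innermost first), each a tuple of
// (primal input index, jvp input index) pairs into the augmented input list.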
2977PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2978 if (!TapeCouldPossiblyRecord(tensors)) {
2979 tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2980 tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2981 return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2982 }
2983 auto& accumulators = *GetAccumulatorSet();
2984 tensorflow::Safe_PyObjectPtr tensors_fast(
2985 PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2986 if (tensors_fast == nullptr || PyErr_Occurred()) {
2987 return nullptr;
2988 }
2989 std::vector<int64_t> augmented_input_ids;
2990 int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2991 PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2992 for (Py_ssize_t position = 0; position < len; ++position) {
2993 PyObject* input = tensors_fast_array[position];
2994 if (input == Py_None) {
2995 continue;
2996 }
2997 tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2998 if (input_dtype == tensorflow::DT_INVALID) {
2999 return nullptr;
3000 }
3001 augmented_input_ids.push_back(FastTensorId(input));
3002 }
3003 if (PyErr_Occurred()) {
3004 return nullptr;
3005 }
3006 // Find the innermost accumulator such that all outer accumulators are
3007 // recording. Any more deeply nested accumulators will not have their JVPs
3008 // saved.
3009 AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
3010 for (; innermost_all_recording != accumulators.end();
3011 ++innermost_all_recording) {
3012 if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
3013 break;
3014 }
3015 }
3016 AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
3017 innermost_all_recording);
3018
3019 bool saving_jvps = false;
3020 tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
3021 std::vector<PyObject*> new_tensors;
3022 Py_ssize_t accumulator_index = 0;
3023 // Start with the innermost accumulators to give outer accumulators a chance
3024 // to find their higher-order JVPs.
3025 for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
3026 it != accumulators.rend(); ++it, ++accumulator_index) {
3027 std::vector<int64_t> new_input_ids;
3028 std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
3029 if (it == reverse_innermost_all_recording) {
3030 saving_jvps = true;
3031 }
3032 if (saving_jvps) {
3033 for (int input_index = 0; input_index < augmented_input_ids.size();
3034 ++input_index) {
3035 int64_t existing_input = augmented_input_ids[input_index];
3036 PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
3037 if (jvp != nullptr) {
3038 new_tensors.push_back(jvp);
3039 new_input_ids.push_back(FastTensorId(jvp));
3040 accumulator_indices.emplace_back(
3041 input_index,
3042 augmented_input_ids.size() + new_input_ids.size() - 1);
3043 }
3044 }
3045 }
3046 tensorflow::Safe_PyObjectPtr accumulator_indices_py(
3047 PyTuple_New(accumulator_indices.size()));
3048 for (int i = 0; i < accumulator_indices.size(); ++i) {
3049 tensorflow::Safe_PyObjectPtr from_index(
3050 GetPythonObjectFromInt(accumulator_indices[i].first));
3051 tensorflow::Safe_PyObjectPtr to_index(
3052 GetPythonObjectFromInt(accumulator_indices[i].second));
3053 PyTuple_SetItem(accumulator_indices_py.get(), i,
3054 PyTuple_Pack(2, from_index.get(), to_index.get()));
3055 }
3056 PyTuple_SetItem(all_indices.get(), accumulator_index,
3057 accumulator_indices_py.release());
3058 augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
3059 new_input_ids.end());
3060 }
3061
3062 tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
3063 for (int i = 0; i < new_tensors.size(); ++i) {
3064 PyObject* jvp = new_tensors[i];
3065 Py_INCREF(jvp);
3066 PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
3067 }
3068 return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
3069}
3070
3071namespace {
3072
3073// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
3074enum FastPathExecuteArgIndex {
3075 FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
3076 FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
3077 FAST_PATH_EXECUTE_ARG_NAME = 2,
3078 FAST_PATH_EXECUTE_ARG_INPUT_START = 3
3079};
3080
3081PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3082#if PY_MAJOR_VERSION >= 3
3083 return PyUnicode_FromStringAndSize(s.data(), s.size());
3084#else
3085 return PyBytes_FromStringAndSize(s.data(), s.size());
3086#endif
3087}
3088
3089bool CheckResourceVariable(PyObject* item) {
3090 if (tensorflow::swig::IsResourceVariable(item)) {
3091 tensorflow::Safe_PyObjectPtr handle(
3092 PyObject_GetAttrString(item, "_handle"));
3093 return EagerTensor_CheckExact(handle.get());
3094 }
3095
3096 return false;
3097}
3098
3099bool IsNumberType(PyObject* item) {
3100#if PY_MAJOR_VERSION >= 3
3101 return PyFloat_Check(item) || PyLong_Check(item);
3102#else
3103 return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3104#endif
3105}
3106
3107bool CheckOneInput(PyObject* item) {
3108 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3109 PyArray_Check(item) || IsNumberType(item)) {
3110 return true;
3111 }
3112
3113 // Sequences are not properly handled. Sequences with purely python numeric
3114 // types work, but sequences with mixes of EagerTensors and python numeric
3115 // types don't work.
3116 // TODO(nareshmodi): fix
3117 return false;
3118}
3119
3120bool CheckInputsOk(PyObject* seq, int start_index,
3121 const tensorflow::OpDef& op_def) {
3122 for (int i = 0; i < op_def.input_arg_size(); i++) {
3123 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3124 if (!op_def.input_arg(i).number_attr().empty() ||
3125 !op_def.input_arg(i).type_list_attr().empty()) {
3126 // This item should be a seq input.
3127 if (!PySequence_Check(item)) {
3128 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3129 << "\", Input \"" << op_def.input_arg(i).name()
3130 << "\" since we expected a sequence, but got "
3131 << item->ob_type->tp_name;
3132 return false;
3133 }
3134 tensorflow::Safe_PyObjectPtr fast_item(
3135 PySequence_Fast(item, "Could not parse sequence."));
3136 if (fast_item.get() == nullptr) {
3137 return false;
3138 }
3139 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3140 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3141 for (Py_ssize_t j = 0; j < len; j++) {
3142 PyObject* inner_item = fast_item_array[j];
3143 if (!CheckOneInput(inner_item)) {
3144 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3145 << "\", Input \"" << op_def.input_arg(i).name()
3146 << "\", Index " << j
3147 << " since we expected an EagerTensor/ResourceVariable, "
3148 "but got "
3149 << inner_item->ob_type->tp_name;
3150 return false;
3151 }
3152 }
3153 } else if (!CheckOneInput(item)) {
3154 VLOG(1)
3155 << "Falling back to slow path for Op \"" << op_def.name()
3156 << "\", Input \"" << op_def.input_arg(i).name()
3157 << "\" since we expected an EagerTensor/ResourceVariable, but got "
3158 << item->ob_type->tp_name;
3159 return false;
3160 }
3161 }
3162
3163 return true;
3164}
3165
3166tensorflow::DataType MaybeGetDType(PyObject* item) {
3167 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3168 return tensorflow::PyTensor_DataType(item);
3169 }
3170
3171 return tensorflow::DT_INVALID;
3172}
3173
3174tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3175 FastPathOpExecInfo* op_exec_info) {
3176 auto cached_it = op_exec_info->cached_dtypes.find(attr);
3177 if (cached_it != op_exec_info->cached_dtypes.end()) {
3178 return cached_it->second;
3179 }
3180
3181 auto it = op_exec_info->attr_to_inputs_map->find(attr);
3182 if (it == op_exec_info->attr_to_inputs_map->end()) {
3183 // No other inputs - this should never happen.
3184 return tensorflow::DT_INVALID;
3185 }
3186
3187 for (const auto& input_info : it->second) {
3188 PyObject* item = PyTuple_GET_ITEM(
3189 op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3190 if (input_info.is_list) {
3191 tensorflow::Safe_PyObjectPtr fast_item(
3192 PySequence_Fast(item, "Unable to allocate"));
3193 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3194 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3195 for (int i = 0; i < len; i++) {
3196 auto dtype = MaybeGetDType(fast_item_array[i]);
3197 if (dtype != tensorflow::DT_INVALID) return dtype;
3198 }
3199 } else {
3200 auto dtype = MaybeGetDType(item);
3201 if (dtype != tensorflow::DT_INVALID) return dtype;
3202 }
3203 }
3204
3205 auto default_it = op_exec_info->default_dtypes->find(attr);
3206 if (default_it != op_exec_info->default_dtypes->end()) {
3207 return default_it->second;
3208 }
3209
3210 return tensorflow::DT_INVALID;
3211}
3212
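// Returns a new tuple with the contents of `seq`, where the elements at
// `indices` are replaced by None.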
3213PyObject* CopySequenceSettingIndicesToNull(
3214 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3215 tensorflow::Safe_PyObjectPtr fast_seq(
3216 PySequence_Fast(seq, "unable to allocate"));
3217 int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3218 PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3219 PyObject* result = PyTuple_New(len);
3220 for (int i = 0; i < len; i++) {
3221 PyObject* item;
3222 if (indices.find(i) != indices.end()) {
3223 item = Py_None;
3224 } else {
3225 item = fast_seq_array[i];
3226 }
3227 Py_INCREF(item);
3228 PyTuple_SET_ITEM(result, i, item);
3229 }
3230 return result;
3231}
3232
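// Records the operation (op_name, inputs, attrs, results) on all active tapes
// and forward accumulators so that gradients can be computed later. Returns
// Py_None on success (including when nothing needs to be recorded), or
// nullptr with a Python error set on failure.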
3233PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3234 PyObject* results,
3235 PyObject* forward_pass_name_scope = nullptr) {
3236 std::vector<int64_t> input_ids = MakeTensorIDList(inputs);
3237 if (PyErr_Occurred()) return nullptr;
3238 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3239 if (PyErr_Occurred()) return nullptr;
3240
3241 bool should_record = false;
3242 for (TFE_Py_Tape* tape : SafeTapeSet()) {
3243 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3244 should_record = true;
3245 break;
3246 }
3247 }
3248 if (!should_record) {
3249 for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3250 if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3251 should_record = true;
3252 break;
3253 }
3254 }
3255 }
3256 if (!should_record) Py_RETURN_NONE;
3257
3258 string c_op_name = TFE_GetPythonString(op_name);
3259
3260 PyObject* op_outputs;
3261 bool op_outputs_tuple_created = false;
3262
3263 if (const auto unused_output_indices =
3264 OpGradientUnusedOutputIndices(c_op_name)) {
3265 if (unused_output_indices->empty()) {
3266 op_outputs = Py_None;
3267 } else {
3268 op_outputs_tuple_created = true;
3269 op_outputs =
3270 CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3271 }
3272 } else {
3273 op_outputs = results;
3274 }
3275
3276 PyObject* op_inputs;
3277 bool op_inputs_tuple_created = false;
3278
3279 if (const auto unused_input_indices =
3280 OpGradientUnusedInputIndices(c_op_name)) {
3281 if (unused_input_indices->empty()) {
3282 op_inputs = Py_None;
3283 } else {
3284 op_inputs_tuple_created = true;
3285 op_inputs =
3286 CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3287 }
3288 } else {
3289 op_inputs = inputs;
3290 }
3291
3292 tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3293 [op_name, attrs, inputs, results](
3294 const std::vector<PyObject*>& input_tangents,
3295 std::vector<PyObject*>* output_tangents, bool use_batch) {
3296 return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3297 output_tangents, use_batch);
3298 });
3299 tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3300 if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3301 c_op_name == "If" || c_op_name == "StatelessIf") {
3302 // Control flow contains non-hashable attributes. Handling them in Python is
3303 // a headache, so instead we'll stay as close to GradientTape's handling as
3304 // possible (a null forward function means the accumulator forwards to a
3305 // tape).
3306 //
3307 // This is safe to do since we'll only see control flow when graph building,
3308 // in which case we can rely on pruning.
3309 forward_function = nullptr;
3310 } else {
3311 forward_function = &py_forward_function;
3312 }
3313
3314 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3315
3316 if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3317
3318 TapeSetRecordOperation(
3319 op_name, inputs, results, input_ids, input_dtypes,
3320 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3321 forward_pass_name_scope]() {
3322 Py_INCREF(op_name);
3323 Py_INCREF(attrs);
3324 Py_INCREF(num_inputs);
3325 Py_INCREF(op_inputs);
3326 Py_INCREF(op_outputs);
3327 Py_INCREF(forward_pass_name_scope);
3328 PyBackwardFunction* function = new PyBackwardFunction(
3329 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3330 forward_pass_name_scope](
3331 PyObject* output_grads,
3332 const std::vector<int64_t>& unneeded_gradients) {
3333 if (PyErr_Occurred()) {
3334 return static_cast<PyObject*>(nullptr);
3335 }
3336 tensorflow::Safe_PyObjectPtr skip_input_indices;
3337 if (!unneeded_gradients.empty()) {
3338 skip_input_indices.reset(
3339 PyTuple_New(unneeded_gradients.size()));
3340 for (int i = 0; i < unneeded_gradients.size(); i++) {
3341 PyTuple_SET_ITEM(
3342 skip_input_indices.get(), i,
3343 GetPythonObjectFromInt(unneeded_gradients[i]));
3344 }
3345 } else {
3346 Py_INCREF(Py_None);
3347 skip_input_indices.reset(Py_None);
3348 }
3349 tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3350 "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3351 output_grads, skip_input_indices.get(),
3352 forward_pass_name_scope));
3353
3354 tensorflow::Safe_PyObjectPtr result(
3355 PyObject_CallObject(gradient_function, callback_args.get()));
3356
3357 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3358
3359 return tensorflow::swig::Flatten(result.get());
3360 });
3361 return function;
3362 },
3363 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3364 forward_pass_name_scope](PyBackwardFunction* backward_function) {
3365 Py_DECREF(op_name);
3366 Py_DECREF(attrs);
3367 Py_DECREF(num_inputs);
3368 Py_DECREF(op_inputs);
3369 Py_DECREF(op_outputs);
3370 Py_DECREF(forward_pass_name_scope);
3371
3372 delete backward_function;
3373 },
3374 forward_function);
3375
3376 Py_DECREF(num_inputs);
3377 if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3378 if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3379
3380 if (PyErr_Occurred()) {
3381 return nullptr;
3382 }
3383
3384 Py_RETURN_NONE;
3385}
3386
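// Notifies the active tapes and variable watchers that the ResourceVariable
// `input` was accessed, unless the variable is not trainable.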
3387void MaybeNotifyVariableAccessed(PyObject* input) {
3388 DCHECK(CheckResourceVariable(input));
3389 DCHECK(PyObject_HasAttrString(input, "_trainable"));
3390
3391 tensorflow::Safe_PyObjectPtr trainable(
3392 PyObject_GetAttrString(input, "_trainable"));
3393 if (trainable.get() == Py_False) return;
3394 TFE_Py_TapeVariableAccessed(input);
3395 TFE_Py_VariableWatcherVariableAccessed(input);
3396}
3397
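// Reads the current value of the ResourceVariable `input` by executing a
// ReadVariableOp eagerly, records a gradient for the read if requested by
// `parent_op_exec_info`, and stores the resulting EagerTensor in `*output`.
// Returns false on failure.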
3398bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3399 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3400 TF_Status* status) {
3401 MaybeNotifyVariableAccessed(input);
3402
3403 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3404 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3405 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3406 return false;
3407
3408 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3409 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3410 return false;
3411
3412 // Set dtype
3413 DCHECK(PyObject_HasAttrString(input, "_dtype"));
3414 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3415 int value;
3416 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3417 return false;
3418 }
3419 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3420
3421 // Get handle
3422 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3423 if (!EagerTensor_CheckExact(handle.get())) return false;
3424 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3425 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3426 return false;
3427
3428 int num_retvals = 1;
3429 TFE_TensorHandle* output_handle;
3430 TFE_Execute(op, &output_handle, &num_retvals, status);
3431 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3432 return false;
3433
3434 // Always create the py object (and correctly DECREF it) from the returned
3435 // value, else the data will leak.
3436 output->reset(EagerTensorFromHandle(output_handle));
3437
3438 // TODO(nareshmodi): Should we run post exec callbacks here?
3439 if (parent_op_exec_info.run_gradient_callback) {
3440 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3441 PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3442
3443 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // Keep alive; PyTuple_SET_ITEM steals the ref.
3445 PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3446
3447 tensorflow::Safe_PyObjectPtr op_string(
3448 GetPythonObjectFromString("ReadVariableOp"));
3449 if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3450 outputs.get())) {
3451 return false;
3452 }
3453 }
3454
3455 return true;
3456}
3457
// Converts `input` into an EagerTensor handle. Supports 3 cases at the
// moment:
// i) input is an EagerTensor - it is passed through unchanged.
// ii) input is a ResourceVariable - its current value is read via
//     ReadVariableOp.
// iii) input is any other python object - it is converted with
//      ConvertToEagerTensor (note, this handling doesn't support packing).
//
// NOTE: dtype_hint_getter provides a dtype hint for case iii (typically taken
// from a similarly typed tensor) and should return DT_INVALID when no hint is
// available; dtype_setter is called with the dtype of the converted tensor.
//
// NOTE: This function sets a python error directly, and returns false on
// failure. TF_Status is only passed since we don't want to have to
// reallocate it.
3471bool ConvertToTensor(
3472 const FastPathOpExecInfo& op_exec_info, PyObject* input,
3473 tensorflow::Safe_PyObjectPtr* output_handle,
3474 // This gets a hint for this particular input.
3475 const std::function<tensorflow::DataType()>& dtype_hint_getter,
3476 // This sets the dtype after conversion is complete.
3477 const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3478 TF_Status* status) {
3479 if (EagerTensor_CheckExact(input)) {
3480 Py_INCREF(input);
3481 output_handle->reset(input);
3482 return true;
3483 } else if (CheckResourceVariable(input)) {
3484 return ReadVariableOp(op_exec_info, input, output_handle, status);
3485 }
3486
3487 // The hint comes from a supposedly similarly typed tensor.
3488 tensorflow::DataType dtype_hint = dtype_hint_getter();
3489
3490 TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3491 op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3492 if (handle == nullptr) {
3493 return tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3494 }
3495
3496 output_handle->reset(EagerTensorFromHandle(handle));
3497 dtype_setter(
3498 static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3499
3500 return true;
3501}
3502
3503// Adds input and type attr to the op, and to the list of flattened
3504// inputs/attrs.
3505bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3506 const bool add_type_attr,
3507 const tensorflow::OpDef::ArgDef& input_arg,
3508 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3509 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3510 TFE_Op* op, TF_Status* status) {
  // Ownership of py_eager_tensor is transferred to flattened_inputs when that
  // vector is provided; otherwise the object is DECREF'd (and destroyed) when
  // it goes out of scope in this function.
3514 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3515
3516 if (!ConvertToTensor(
3517 *op_exec_info, input, &py_eager_tensor,
3518 [&]() {
3519 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3520 return input_arg.type();
3521 }
3522 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3523 },
3524 [&](const tensorflow::DataType dtype) {
3525 op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3526 },
3527 status)) {
3528 return false;
3529 }
3530
3531 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3532
3533 if (add_type_attr && !input_arg.type_attr().empty()) {
3534 auto dtype = TFE_TensorHandleDataType(input_handle);
3535 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3536 if (flattened_attrs != nullptr) {
3537 flattened_attrs->emplace_back(
3538 GetPythonObjectFromString(input_arg.type_attr()));
3539 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3540 }
3541 }
3542
3543 if (flattened_inputs != nullptr) {
3544 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3545 }
3546
3547 TFE_OpAddInput(op, input_handle, status);
3548 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3549 return false;
3550 }
3551
3552 return true;
3553}
3554
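// Returns the device name as a C string, or nullptr if `py_device_name` is
// None.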
3555const char* GetDeviceName(PyObject* py_device_name) {
3556 if (py_device_name != Py_None) {
3557 return TFE_GetPythonString(py_device_name);
3558 }
3559 return nullptr;
3560}
3561
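// Returns true if `seq` is a sequence usable as the value of list attr
// `attr_name`; otherwise sets a TypeError (not a sequence) or a ValueError
// (an ndarray whose rank is not 1) and returns false.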
3562bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3563 if (!PySequence_Check(seq)) {
3564 PyErr_SetString(PyExc_TypeError,
3565 Printf("expected a sequence for attr %s, got %s instead",
3566 attr_name.data(), seq->ob_type->tp_name)
3567 .data());
3568
3569 return false;
3570 }
3571 if (PyArray_Check(seq) &&
3572 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3573 PyErr_SetString(PyExc_ValueError,
3574 Printf("expected a sequence for attr %s, got an ndarray "
3575 "with rank %d instead",
3576 attr_name.data(),
3577 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3578 .data());
3579 return false;
3580 }
3581 return true;
3582}
3583
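// Runs the gradient-recording callback and/or the post-execution op callbacks
// for a fast-path execution, passing the flattened inputs, attrs and results.
// Returns false (with a Python error set) if any callback fails.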
3584bool RunCallbacks(
3585 const FastPathOpExecInfo& op_exec_info, PyObject* args,
3586 int num_inferred_attrs,
3587 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3588 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3589 PyObject* flattened_result) {
3590 DCHECK(op_exec_info.run_callbacks);
3591
3592 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3593 for (int i = 0; i < flattened_inputs.size(); i++) {
3594 PyObject* input = flattened_inputs[i].get();
3595 Py_INCREF(input);
3596 PyTuple_SET_ITEM(inputs.get(), i, input);
3597 }
3598
3599 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3600 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3601 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3602
3603 for (int i = 0; i < num_non_inferred_attrs; i++) {
3604 auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3605 Py_INCREF(attr);
3606 PyTuple_SET_ITEM(attrs.get(), i, attr);
3607 }
3608
3609 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3610 PyObject* attr_or_name =
3611 flattened_attrs.at(i - num_non_inferred_attrs).get();
3612 Py_INCREF(attr_or_name);
3613 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3614 }
3615
3616 if (op_exec_info.run_gradient_callback) {
3617 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3618 flattened_result)) {
3619 return false;
3620 }
3621 }
3622
3623 if (op_exec_info.run_post_exec_callbacks) {
3624 tensorflow::Safe_PyObjectPtr callback_args(
3625 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3626 flattened_result, op_exec_info.name));
3627 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3628 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3629 if (!PyCallable_Check(callback_fn)) {
3630 PyErr_SetString(
3631 PyExc_TypeError,
3632 Printf("expected a function for "
3633 "post execution callback in index %ld, got %s instead",
3634 i, callback_fn->ob_type->tp_name)
3635 .c_str());
3636 return false;
3637 }
3638 PyObject* callback_result =
3639 PyObject_CallObject(callback_fn, callback_args.get());
3640 if (!callback_result) {
3641 return false;
3642 }
3643 Py_DECREF(callback_result);
3644 }
3645 }
3646
3647 return true;
3648}
3649
3650} // namespace
3651
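// Fast-path execution of an eager op. `args` is expected to hold the eager
// context, the op type name and the op's given name at the positions named by
// the FAST_PATH_EXECUTE_ARG_* constants, followed by one entry per input arg
// of the op and then the non-inferred attrs as alternating name/value pairs.
// For illustration only, the generated Python op wrappers call this roughly
// as
//   TFE_Py_FastPathExecute(ctx, "MatMul", name, a, b,
//                          "transpose_a", False, "transpose_b", False)
// (the exact leading arguments are defined by the constants above). Falls
// back to the Python slow path by raising a fallback exception when the
// inputs or attrs cannot be handled here.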
3652PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3653 tensorflow::profiler::TraceMe activity(
3654 "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3655 Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3656 if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3657 PyErr_SetString(
3658 PyExc_ValueError,
3659 Printf("There must be at least %d items in the input tuple.",
3660 FAST_PATH_EXECUTE_ARG_INPUT_START)
3661 .c_str());
3662 return nullptr;
3663 }
3664
3665 FastPathOpExecInfo op_exec_info;
3666
3667 PyObject* py_eager_context =
3668 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3669
3670 // TODO(edoper): Use interned string here
3671 PyObject* eager_context_handle =
3672 PyObject_GetAttrString(py_eager_context, "_context_handle");
3673
3674 TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3675 PyCapsule_GetPointer(eager_context_handle, nullptr));
3676 op_exec_info.ctx = ctx;
3677 op_exec_info.args = args;
3678
3679 if (ctx == nullptr) {
    // The context hasn't been initialized yet; it will be initialized on the
    // slow path.
    RaiseFallbackException(
        "This function does not handle the case where not all inputs are "
        "already EagerTensors.");
3684 return nullptr;
3685 }
3686
3687 auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3688 if (tld == nullptr) {
3689 return nullptr;
3690 }
3691 op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3692 op_exec_info.callbacks = tld->op_callbacks.get();
3693
3694 op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3695 op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3696
3697 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3698 // (similar to benchmark_tf_gradient_function_*). Also consider using an
3699 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3700 // point out problems with heap allocs.
3701 op_exec_info.run_gradient_callback =
3702 !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3703 op_exec_info.run_post_exec_callbacks =
3704 op_exec_info.callbacks != Py_None &&
3705 PyList_Size(op_exec_info.callbacks) > 0;
3706 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3707 op_exec_info.run_post_exec_callbacks;
3708
3709 TF_Status* status = GetStatus();
3710 const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3711 if (op_name == nullptr) {
3712 PyErr_SetString(PyExc_TypeError,
3713 Printf("expected a string for op_name, got %s instead",
3714 op_exec_info.op_name->ob_type->tp_name)
3715 .c_str());
3716 return nullptr;
3717 }
3718
3719 TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3720
3721 auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3722 ReturnStatus(status);
3723 ReturnOp(ctx, op);
3724 });
3725
3726 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3727 return nullptr;
3728 }
3729
3730 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3731 tensorflow::StackTrace::kStackTraceInitialSize));
3732
3733 const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3734 if (op_def == nullptr) return nullptr;
3735
3736 if (args_size <
3737 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3738 PyErr_SetString(
3739 PyExc_ValueError,
3740 Printf("Tuple size smaller than intended. Expected to be at least %d, "
3741 "was %ld",
3742 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3743 args_size)
3744 .c_str());
3745 return nullptr;
3746 }
3747
3748 if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
    RaiseFallbackException(
        "This function does not handle the case where not all inputs are "
        "already EagerTensors.");
3752 return nullptr;
3753 }
3754
3755 op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3756 op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3757
3758 // Mapping of attr name to size - used to calculate the number of values
3759 // to be expected by the TFE_Execute run.
3760 tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;
3761
3762 // Set non-inferred attrs, including setting defaults if the attr is passed in
3763 // as None.
3764 for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3765 i < args_size; i += 2) {
3766 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3767 const char* attr_name = TFE_GetPythonString(py_attr_name);
3768 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3769
3770 // Not creating an index since most of the time there are not more than a
3771 // few attrs.
3772 // TODO(nareshmodi): Maybe include the index as part of the
3773 // OpRegistrationData.
3774 for (const auto& attr : op_def->attr()) {
3775 if (tensorflow::StringPiece(attr_name) == attr.name()) {
3776 SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3777 &attr_list_sizes, status);
3778
3779 if (!status->status.ok()) {
3780 VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3781 << "\" since we are unable to set the value for attr \""
3782 << attr.name() << "\" due to: " << TF_Message(status);
3783 RaiseFallbackException(TF_Message(status));
3784 return nullptr;
3785 }
3786
3787 break;
3788 }
3789 }
3790 }
3791
3792 // Flat attrs and inputs as required by the record_gradient call. The attrs
3793 // here only contain inferred attrs (non-inferred attrs are added directly
3794 // from the input args).
3795 // All items in flattened_attrs and flattened_inputs contain
3796 // Safe_PyObjectPtr - any time something steals a reference to this, it must
3797 // INCREF.
3798 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3799 // directly.
3800 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3801 nullptr;
3802 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3803 nullptr;
3804
3805 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3806 if (op_exec_info.run_callbacks) {
3807 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3808 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3809 }
3810
3811 // Add inferred attrs and inputs.
3812 // The following code might set duplicate type attrs. This will result in
3813 // the CacheKey for the generated AttrBuilder possibly differing from
3814 // those where the type attrs are correctly set. Inconsistent CacheKeys
3815 // for ops means that there might be unnecessarily duplicated kernels.
3816 // TODO(nareshmodi): Fix this.
3817 for (int i = 0; i < op_def->input_arg_size(); i++) {
3818 const auto& input_arg = op_def->input_arg(i);
3819
3820 PyObject* input =
3821 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3822 if (!input_arg.number_attr().empty()) {
3823 // The item is a homogeneous list.
3824 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3825 tensorflow::Safe_PyObjectPtr fast_input(
3826 PySequence_Fast(input, "Could not parse sequence."));
3827 if (fast_input.get() == nullptr) {
3828 return nullptr;
3829 }
3830 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3831 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3832
3833 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3834 if (op_exec_info.run_callbacks) {
3835 flattened_attrs->emplace_back(
3836 GetPythonObjectFromString(input_arg.number_attr()));
3837 flattened_attrs->emplace_back(PyLong_FromLong(len));
3838 }
3839 attr_list_sizes[input_arg.number_attr()] = len;
3840
3841 if (len > 0) {
3842 // First item adds the type attr.
3843 if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3844 flattened_attrs.get(), flattened_inputs.get(), op,
3845 status)) {
3846 return nullptr;
3847 }
3848
3849 for (Py_ssize_t j = 1; j < len; j++) {
3850 // Since the list is homogeneous, we don't need to re-add the attr.
3851 if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3852 input_arg, nullptr /* flattened_attrs */,
3853 flattened_inputs.get(), op, status)) {
3854 return nullptr;
3855 }
3856 }
3857 }
3858 } else if (!input_arg.type_list_attr().empty()) {
3859 // The item is a heterogeneous list.
3860 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3861 return nullptr;
3862 }
3863 tensorflow::Safe_PyObjectPtr fast_input(
3864 PySequence_Fast(input, "Could not parse sequence."));
3865 if (fast_input.get() == nullptr) {
3866 return nullptr;
3867 }
3868 const string& attr_name = input_arg.type_list_attr();
3869 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3870 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3871 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3872 PyObject* py_attr_value = nullptr;
3873 if (op_exec_info.run_callbacks) {
3874 py_attr_value = PyTuple_New(len);
3875 }
3876 for (Py_ssize_t j = 0; j < len; j++) {
3877 PyObject* py_input = fast_input_array[j];
3878 tensorflow::Safe_PyObjectPtr py_eager_tensor;
3879 if (!ConvertToTensor(
3880 op_exec_info, py_input, &py_eager_tensor,
3881 []() { return tensorflow::DT_INVALID; },
3882 [](const tensorflow::DataType dtype) {}, status)) {
3883 return nullptr;
3884 }
3885
3886 TFE_TensorHandle* input_handle =
3887 EagerTensor_Handle(py_eager_tensor.get());
3888
3889 attr_value[j] = TFE_TensorHandleDataType(input_handle);
3890
3891 TFE_OpAddInput(op, input_handle, status);
3892 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3893 return nullptr;
3894 }
3895
3896 if (op_exec_info.run_callbacks) {
3897 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3898
3899 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3900 }
3901 }
3902 if (op_exec_info.run_callbacks) {
3903 flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3904 flattened_attrs->emplace_back(py_attr_value);
3905 }
3906 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3907 attr_value.size());
3908 attr_list_sizes[attr_name] = len;
3909 } else {
      // The item is a single input, not a list.
3911 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3912 flattened_attrs.get(), flattened_inputs.get(), op,
3913 status)) {
3914 return nullptr;
3915 }
3916 }
3917 }
3918
3919 int64_t num_outputs = 0;
3920 for (int i = 0; i < op_def->output_arg_size(); i++) {
3921 const auto& output_arg = op_def->output_arg(i);
3922 int64_t delta = 1;
3923 if (!output_arg.number_attr().empty()) {
3924 delta = attr_list_sizes[output_arg.number_attr()];
3925 } else if (!output_arg.type_list_attr().empty()) {
3926 delta = attr_list_sizes[output_arg.type_list_attr()];
3927 }
3928 if (delta < 0) {
3929 RaiseFallbackException(
3930 "Attributes suggest that the size of an output list is less than 0");
3931 return nullptr;
3932 }
3933 num_outputs += delta;
3934 }
3935
  // If the number of retvals does not fit in an int32, we error out.
3937 if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3938 PyErr_SetString(
3939 PyExc_ValueError,
3940 Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3941 return nullptr;
3942 }
3943 int num_retvals = num_outputs;
3944
3945 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3946
3947 Py_BEGIN_ALLOW_THREADS;
3948 TFE_Execute(op, retvals.data(), &num_retvals, status);
3949 Py_END_ALLOW_THREADS;
3950
3951 if (!status->status.ok()) {
3952 // Augment the status with the op_name for easier debugging similar to
3953 // TFE_Py_Execute.
3954 status->status = tensorflow::errors::CreateWithUpdatedMessage(
3955 status->status, tensorflow::strings::StrCat(
3956 TF_Message(status), " [Op:",
3957 TFE_GetPythonString(op_exec_info.op_name), "]"));
3958 tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3959 return nullptr;
3960 }
3961
3962 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3963 for (int i = 0; i < num_retvals; ++i) {
3964 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3965 }
3966
3967 if (op_exec_info.run_callbacks) {
3968 if (!RunCallbacks(
3969 op_exec_info, args,
3970 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3971 *flattened_inputs, *flattened_attrs, flat_result.get())) {
3972 return nullptr;
3973 }
3974 }
3975
3976 // Unflatten results.
3977 if (op_def->output_arg_size() == 0) {
3978 Py_RETURN_NONE;
3979 }
3980
3981 if (op_def->output_arg_size() == 1) {
3982 if (!op_def->output_arg(0).number_attr().empty() ||
3983 !op_def->output_arg(0).type_list_attr().empty()) {
3984 return flat_result.release();
3985 } else {
3986 auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3987 Py_INCREF(result);
3988 return result;
3989 }
3990 }
3991
  // Unflatten the results so there is one entry per output arg; the Python
  // caller turns this list into a namedtuple.
3993 PyObject* result = PyList_New(op_def->output_arg_size());
3994 int flat_result_index = 0;
3995 for (int i = 0; i < op_def->output_arg_size(); i++) {
3996 if (!op_def->output_arg(i).number_attr().empty()) {
3997 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3998 PyObject* inner_list = PyList_New(list_length);
3999 for (int j = 0; j < list_length; j++) {
4000 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4001 Py_INCREF(obj);
4002 PyList_SET_ITEM(inner_list, j, obj);
4003 }
4004 PyList_SET_ITEM(result, i, inner_list);
4005 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
4006 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
4007 PyObject* inner_list = PyList_New(list_length);
4008 for (int j = 0; j < list_length; j++) {
4009 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4010 Py_INCREF(obj);
4011 PyList_SET_ITEM(inner_list, j, obj);
4012 }
4013 PyList_SET_ITEM(result, i, inner_list);
4014 } else {
4015 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4016 Py_INCREF(obj);
4017 PyList_SET_ITEM(result, i, obj);
4018 }
4019 }
4020 return result;
4021}
4022
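// Python-visible wrapper around RecordGradient that returns None immediately
// when the tape is stopped or no tape/accumulator is active.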
4023PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
4024 PyObject* attrs, PyObject* results,
4025 PyObject* forward_pass_name_scope) {
4026 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
4027 Py_RETURN_NONE;
4028 }
4029
4030 return RecordGradient(op_name, inputs, attrs, results,
4031 forward_pass_name_scope);
4032}
4033
// Prints incoming messages directly to Python's stdout using Python's
// C API. This is necessary in Jupyter notebooks and Colab, where messages
// written to the C stdout don't reach the notebook cell outputs, but calls
// to Python's stdout do.
4038void PrintToPythonStdout(const char* msg) {
4039 if (Py_IsInitialized()) {
4040 PyGILState_STATE py_threadstate;
4041 py_threadstate = PyGILState_Ensure();
4042
4043 string string_msg = msg;
4044 // PySys_WriteStdout truncates strings over 1000 bytes, so
4045 // we write the message in chunks small enough to not be truncated.
    constexpr size_t kChunkSize = 900;
    const size_t len = string_msg.length();
    for (size_t i = 0; i < len; i += kChunkSize) {
      PySys_WriteStdout("%s", string_msg.substr(i, kChunkSize).c_str());
    }
4051
4052 // Force flushing to make sure print newlines aren't interleaved in
4053 // some colab environments
4054 PyRun_SimpleString("import sys; sys.stdout.flush()");
4055
4056 PyGILState_Release(py_threadstate);
4057 }
4058}
4059
4060// Register PrintToPythonStdout as a log listener, to allow
4061// printing in colabs and jupyter notebooks to work.
4062void TFE_Py_EnableInteractivePythonLogging() {
4063 static bool enabled_interactive_logging = false;
4064 if (!enabled_interactive_logging) {
4065 enabled_interactive_logging = true;
4066 TF_RegisterLogListener(PrintToPythonStdout);
4067 }
4068}
4069
4070namespace {
4071// TODO(mdan): Clean this. Maybe by decoupling context lifetime from Python GC?
4072// Weak reference to the Python Context (see tensorflow/python/eager/context.py)
4073// object currently active. This object is opaque and wrapped inside a Python
4074// Capsule. However, the EagerContext object it holds is tracked by the
4075// global_c_eager_context object.
4076// Also see common_runtime/eager/context.cc.
4077PyObject* global_py_eager_context = nullptr;
4078} // namespace
4079
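// Stores a weak reference to the currently active Python eager Context
// object. Returns None on success, or nullptr if the weak reference could not
// be created.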
4080PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4081 Py_XDECREF(global_py_eager_context);
4082 global_py_eager_context = PyWeakref_NewRef(py_context, nullptr);
4083 if (global_py_eager_context == nullptr) {
4084 return nullptr;
4085 }
4086 Py_RETURN_NONE;
4087}
4088
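// Returns a new reference to the active Python eager Context object, or sets
// a RuntimeError and returns nullptr if it was never set or has already been
// destroyed.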
4089PyObject* GetPyEagerContext() {
4090 if (global_py_eager_context == nullptr) {
4091 PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4092 return nullptr;
4093 }
4094 PyObject* py_context = PyWeakref_GET_OBJECT(global_py_eager_context);
4095 if (py_context == Py_None) {
4096 PyErr_SetString(PyExc_RuntimeError,
4097 "Python eager context has been destroyed");
4098 return nullptr;
4099 }
4100 Py_INCREF(py_context);
4101 return py_context;
4102}
4103
4104namespace {
4105
4106// Default values for thread_local_data fields.
4107struct EagerContextThreadLocalDataDefaults {
4108 tensorflow::Safe_PyObjectPtr is_eager;
4109 tensorflow::Safe_PyObjectPtr device_spec;
4110};
4111
4112// Maps each py_eager_context object to its thread_local_data.
4113//
4114// Note: we need to use the python Context object as the key here (and not
4115// its handle object), because the handle object isn't created until the
4116// context is initialized; but thread_local_data is potentially accessed
4117// before then.
4118using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4119 PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4120thread_local EagerContextThreadLocalDataMap*
4121 eager_context_thread_local_data_map = nullptr;
4122
4123// Maps each py_eager_context object to default values.
4124using EagerContextThreadLocalDataDefaultsMap =
4125 absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4126EagerContextThreadLocalDataDefaultsMap*
4127 eager_context_thread_local_data_defaults = nullptr;
4128
4129} // namespace
4130
4131namespace tensorflow {
4132
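// Registers the defaults (an `is_eager` callable and a `device_spec`) used to
// lazily initialize per-thread data for `py_eager_context`. Should be called
// exactly once per eager Context object.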
4133void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4134 PyObject* is_eager,
4135 PyObject* device_spec) {
4136 DCheckPyGilState();
4137 if (eager_context_thread_local_data_defaults == nullptr) {
4138 absl::LeakCheckDisabler disabler;
4139 eager_context_thread_local_data_defaults =
4140 new EagerContextThreadLocalDataDefaultsMap();
4141 }
4142 if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4143 PyErr_SetString(PyExc_AssertionError,
4144 "MakeEagerContextThreadLocalData may not be called "
4145 "twice on the same eager Context object.");
4146 }
4147
4148 auto& defaults =
4149 (*eager_context_thread_local_data_defaults)[py_eager_context];
4150 Py_INCREF(is_eager);
4151 defaults.is_eager.reset(is_eager);
4152 Py_INCREF(device_spec);
4153 defaults.device_spec.reset(device_spec);
4154}
4155
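// Returns the thread-local data for `py_eager_context`, creating it on first
// use in the current thread and initializing it from the registered defaults.
// Returns nullptr with a Python error set if MakeEagerContextThreadLocalData
// was never called for this context.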
4156EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4157 PyObject* py_eager_context) {
4158 if (eager_context_thread_local_data_defaults == nullptr) {
4159 PyErr_SetString(PyExc_AssertionError,
4160 "MakeEagerContextThreadLocalData must be called "
4161 "before GetEagerContextThreadLocalData.");
4162 return nullptr;
4163 }
4164 auto defaults =
4165 eager_context_thread_local_data_defaults->find(py_eager_context);
4166 if (defaults == eager_context_thread_local_data_defaults->end()) {
4167 PyErr_SetString(PyExc_AssertionError,
4168 "MakeEagerContextThreadLocalData must be called "
4169 "before GetEagerContextThreadLocalData.");
4170 return nullptr;
4171 }
4172
4173 if (eager_context_thread_local_data_map == nullptr) {
4174 absl::LeakCheckDisabler disabler;
4175 eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4176 }
4177 auto& thread_local_data =
4178 (*eager_context_thread_local_data_map)[py_eager_context];
4179
4180 if (!thread_local_data) {
4181 thread_local_data.reset(new EagerContextThreadLocalData());
4182
4183 Safe_PyObjectPtr is_eager(
4184 PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4185 if (!is_eager) return nullptr;
4186 thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4187
4188#if PY_MAJOR_VERSION >= 3
4189 PyObject* scope_name = PyUnicode_FromString("");
4190#else
4191 PyObject* scope_name = PyString_FromString("");
4192#endif
4193 thread_local_data->scope_name.reset(scope_name);
4194
4195#if PY_MAJOR_VERSION >= 3
4196 PyObject* device_name = PyUnicode_FromString("");
4197#else
4198 PyObject* device_name = PyString_FromString("");
4199#endif
4200 thread_local_data->device_name.reset(device_name);
4201
4202 Py_INCREF(defaults->second.device_spec.get());
4203 thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4204
4205 Py_INCREF(Py_None);
4206 thread_local_data->function_call_options.reset(Py_None);
4207
4208 Py_INCREF(Py_None);
4209 thread_local_data->executor.reset(Py_None);
4210
4211 thread_local_data->op_callbacks.reset(PyList_New(0));
4212 }
4213 return thread_local_data.get();
4214}
4215
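// Removes the registered defaults and the current thread's thread-local data
// for `py_eager_context`.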
4216void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4217 DCheckPyGilState();
4218 if (eager_context_thread_local_data_defaults) {
4219 eager_context_thread_local_data_defaults->erase(py_eager_context);
4220 }
4221 if (eager_context_thread_local_data_map) {
4222 eager_context_thread_local_data_map->erase(py_eager_context);
4223 }
4224}
4225
4226} // namespace tensorflow
4227