1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <atomic> |
17 | #include <cstring> |
18 | #include <unordered_map> |
19 | |
20 | #include "absl/debugging/leak_check.h" |
21 | #include "absl/strings/str_cat.h" |
22 | #include "absl/strings/str_replace.h" |
23 | #include "absl/types/variant.h" |
24 | #include "tensorflow/c/c_api.h" |
25 | #include "tensorflow/c/c_api_internal.h" |
26 | #include "tensorflow/c/eager/c_api.h" |
27 | #include "tensorflow/c/eager/c_api_internal.h" |
28 | #include "tensorflow/c/eager/tape.h" |
29 | #include "tensorflow/c/eager/tfe_context_internal.h" |
30 | #include "tensorflow/c/eager/tfe_op_internal.h" |
31 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
32 | #include "tensorflow/c/tf_status.h" |
33 | #include "tensorflow/core/framework/types.pb.h" |
34 | #include "tensorflow/core/lib/core/errors.h" |
35 | #include "tensorflow/core/lib/gtl/cleanup.h" |
36 | #include "tensorflow/core/lib/gtl/compactptrset.h" |
37 | #include "tensorflow/core/lib/gtl/flatmap.h" |
38 | #include "tensorflow/core/lib/gtl/flatset.h" |
39 | #include "tensorflow/core/lib/strings/strcat.h" |
40 | #include "tensorflow/core/lib/strings/stringprintf.h" |
41 | #include "tensorflow/core/platform/casts.h" |
42 | #include "tensorflow/core/platform/errors.h" |
43 | #include "tensorflow/core/platform/mutex.h" |
44 | #include "tensorflow/core/platform/protobuf.h" |
45 | #include "tensorflow/core/platform/status.h" |
46 | #include "tensorflow/core/platform/statusor.h" |
47 | #include "tensorflow/core/platform/types.h" |
48 | #include "tensorflow/core/profiler/lib/traceme.h" |
49 | #include "tensorflow/core/util/managed_stack_trace.h" |
50 | #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" |
51 | #include "tensorflow/python/eager/pywrap_tensor.h" |
52 | #include "tensorflow/python/eager/pywrap_tfe.h" |
53 | #include "tensorflow/python/lib/core/py_util.h" |
54 | #include "tensorflow/python/lib/core/safe_ptr.h" |
55 | #include "tensorflow/python/util/stack_trace.h" |
56 | #include "tensorflow/python/util/util.h" |
57 | |
58 | using tensorflow::Status; |
59 | using tensorflow::string; |
60 | using tensorflow::strings::Printf; |
61 | |
62 | namespace { |
63 | // NOTE: Items are retrieved from and returned to these unique_ptrs, and they |
64 | // act as arenas. This is important if the same thread requests 2 items without |
65 | // releasing one. |
66 | // The following sequence of events on the same thread will still succeed: |
67 | // - GetOp <- Returns existing. |
68 | // - GetOp <- Allocates and returns a new pointer. |
69 | // - ReleaseOp <- Sets the item in the unique_ptr. |
70 | // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one. |
71 | // This occurs when a PyFunc kernel is run. This behavior makes it safe in that |
72 | // case, as well as the case where python decides to reuse the underlying |
73 | // C++ thread in 2 python threads case. |
// Custom deleter so unique_ptr<TFE_Op, OpDeleter> releases ops through the
// C API instead of calling `delete` on an opaque type.
struct OpDeleter {
  void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
};
// Per-thread cache of at most one reusable eager op per context; see the
// arena note above for why items are moved out of and back into this map.
thread_local std::unordered_map<TFE_Context*,
                                std::unique_ptr<TFE_Op, OpDeleter>>
    thread_local_eager_operation_map;  // NOLINT
// Per-thread cached TF_Status, reused across fast-path calls.
// NOTE(review): the default deleter runs `delete` on TF_Status rather than
// TF_DeleteStatus -- presumably TF_Status is an ordinary struct; confirm.
thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
    nullptr;
82 | |
83 | std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) { |
84 | auto it = thread_local_eager_operation_map.find(ctx); |
85 | if (it == thread_local_eager_operation_map.end()) { |
86 | return nullptr; |
87 | } |
88 | return std::move(it->second); |
89 | } |
90 | |
91 | TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, |
92 | const char* raw_device_name, TF_Status* status) { |
93 | auto op = ReleaseThreadLocalOp(ctx); |
94 | if (!op) { |
95 | op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation())); |
96 | } |
97 | status->status = |
98 | tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name); |
99 | if (!status->status.ok()) { |
100 | op.reset(); |
101 | } |
102 | return op.release(); |
103 | } |
104 | |
105 | void ReturnOp(TFE_Context* ctx, TFE_Op* op) { |
106 | if (op) { |
107 | tensorflow::unwrap(op)->Clear(); |
108 | thread_local_eager_operation_map[ctx].reset(op); |
109 | } |
110 | } |
111 | |
112 | TF_Status* ReleaseThreadLocalStatus() { |
113 | if (thread_local_tf_status == nullptr) { |
114 | return nullptr; |
115 | } |
116 | return thread_local_tf_status.release(); |
117 | } |
118 | |
// Describes one op input that derives its dtype from a shared "type" attr:
// the input's position and whether it is a list-valued input.
struct InputInfo {
  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}

  // Index of the input in the OpDef's input_arg list.
  int i;
  // True when the input has a number_attr (i.e. it is a list of tensors).
  bool is_list = false;
};
125 | |
// Takes in output gradients, returns input gradients. The second argument
// carries the ids of the output gradient tensors (see tape usage below).
typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
    PyBackwardFunction;

// Maps a "type" attr name to every input whose dtype that attr determines,
// so a missing dtype can be inferred from a sibling input.
using AttrToInputsMap =
    tensorflow::gtl::FlatMap<string,
                             tensorflow::gtl::InlinedVector<InputInfo, 4>>;
133 | |
134 | tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() { |
135 | static auto* all_attr_to_input_maps = |
136 | new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>; |
137 | return all_attr_to_input_maps; |
138 | } |
139 | |
140 | // This function doesn't use a lock, since we depend on the GIL directly. |
141 | AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) { |
142 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4 |
143 | DCHECK(PyGILState_Check()) |
144 | << "This function needs to hold the GIL when called." ; |
145 | #endif |
146 | auto* all_attr_to_input_maps = GetAllAttrToInputsMaps(); |
147 | auto* output = |
148 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name()); |
149 | if (output != nullptr) { |
150 | return output; |
151 | } |
152 | |
153 | std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap); |
154 | |
155 | // Store a list of InputIndex -> List of corresponding inputs. |
156 | for (int i = 0; i < op_def.input_arg_size(); i++) { |
157 | if (!op_def.input_arg(i).type_attr().empty()) { |
158 | auto it = m->find(op_def.input_arg(i).type_attr()); |
159 | if (it == m->end()) { |
160 | it = m->insert({op_def.input_arg(i).type_attr(), {}}).first; |
161 | } |
162 | it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty()); |
163 | } |
164 | } |
165 | |
166 | auto* retval = m.get(); |
167 | (*all_attr_to_input_maps)[op_def.name()] = m.release(); |
168 | |
169 | return retval; |
170 | } |
171 | |
172 | // This function doesn't use a lock, since we depend on the GIL directly. |
173 | tensorflow::gtl::FlatMap< |
174 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>* |
175 | GetAllAttrToDefaultsMaps() { |
176 | static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap< |
177 | string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>; |
178 | return all_attr_to_defaults_maps; |
179 | } |
180 | |
181 | tensorflow::gtl::FlatMap<string, tensorflow::DataType>* |
182 | GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) { |
183 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4 |
184 | DCHECK(PyGILState_Check()) |
185 | << "This function needs to hold the GIL when called." ; |
186 | #endif |
187 | auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps(); |
188 | auto* output = |
189 | tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name()); |
190 | if (output != nullptr) { |
191 | return output; |
192 | } |
193 | |
194 | auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>; |
195 | |
196 | for (const auto& attr : op_def.attr()) { |
197 | if (attr.type() == "type" && attr.has_default_value()) { |
198 | new_map->insert({attr.name(), attr.default_value().type()}); |
199 | } |
200 | } |
201 | |
202 | (*all_attr_to_defaults_maps)[op_def.name()] = new_map; |
203 | |
204 | return new_map; |
205 | } |
206 | |
// Bag of state threaded through the fast-path op execution helpers.
// NOTE(review): the PyObject* members appear to be borrowed references held
// only for the duration of the call -- confirm against the fast-path caller.
struct FastPathOpExecInfo {
  TFE_Context* ctx;
  const char* device_name;

  // Whether any callback (gradient or post-exec) must run for this op.
  bool run_callbacks;
  bool run_post_exec_callbacks;
  bool run_gradient_callback;

  // The op name of the main op being executed.
  PyObject* name;
  // The op type name of the main op being executed.
  PyObject* op_name;
  PyObject* callbacks;

  // All the args passed into the FastPathOpExecInfo.
  PyObject* args;

  // DTypes can come from another input that has the same attr. So build that
  // map.
  const AttrToInputsMap* attr_to_inputs_map;
  const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
  // Dtypes already resolved during this call, keyed by attr name.
  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};
230 | |
// Generates a scalar attr parser `fn_name(key, py_value, status, out)` that
// converts `py_value` to `type` via `parse_fn` when `check_fn` accepts it.
// Returns true on success; otherwise sets an INVALID_ARGUMENT status naming
// the attr and the offending Python type, and returns false.
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                     \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,   \
               type* value) {                                              \
    if (check_fn(py_value)) {                                              \
      *value = static_cast<type>(parse_fn(py_value));                      \
      return true;                                                         \
    } else {                                                               \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                            \
                   tensorflow::strings::StrCat(                            \
                       "Expecting " #type " value for attr ", key,         \
                       ", got ", py_value->ob_type->tp_name)               \
                       .c_str());                                          \
      return false;                                                        \
    }                                                                      \
  }

// On Python 3 all integers are PyLong; on Python 2 plain ints are PyInt
// (the 64-bit parser for Python 2 is hand-written below).
#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
255 | |
#if PY_MAJOR_VERSION < 3
// Python 2 variant of ParseInt64Value: accepts both `int` and `long`.
// Returns true on success; on failure sets INVALID_ARGUMENT and returns false.
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
                     int64_t* value) {
  if (PyInt_Check(py_value)) {
    *value = static_cast<int64_t>(PyInt_AsLong(py_value));
    return true;
  } else if (PyLong_Check(py_value)) {
    *value = static_cast<int64_t>(PyLong_AsLong(py_value));
    return true;
  }
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}
#endif
274 | |
275 | Py_ssize_t TensorShapeNumDims(PyObject* value) { |
276 | const auto size = PySequence_Size(value); |
277 | if (size == -1) { |
278 | // TensorShape.__len__ raises an error in the scenario where the shape is an |
279 | // unknown, which needs to be cleared. |
280 | // TODO(nareshmodi): ensure that this is actually a TensorShape. |
281 | PyErr_Clear(); |
282 | } |
283 | return size; |
284 | } |
285 | |
// Returns true if `py_value` is a Python integer (PyLong on Python 3;
// PyInt or PyLong on Python 2).
bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value) || PyLong_Check(py_value);
#endif
}
293 | |
294 | // This function considers a Dimension._value of None to be valid, and sets the |
295 | // value to be -1 in that case. |
296 | bool ParseDimensionValue(const string& key, PyObject* py_value, |
297 | TF_Status* status, int64_t* value) { |
298 | if (IsInteger(py_value)) { |
299 | return ParseInt64Value(key, py_value, status, value); |
300 | } |
301 | |
302 | tensorflow::Safe_PyObjectPtr dimension_value( |
303 | PyObject_GetAttrString(py_value, "_value" )); |
304 | if (dimension_value == nullptr) { |
305 | PyErr_Clear(); |
306 | TF_SetStatus( |
307 | status, TF_INVALID_ARGUMENT, |
308 | tensorflow::strings::StrCat("Expecting a Dimension for attr " , key, |
309 | ", got " , py_value->ob_type->tp_name) |
310 | .c_str()); |
311 | return false; |
312 | } |
313 | |
314 | if (dimension_value.get() == Py_None) { |
315 | *value = -1; |
316 | return true; |
317 | } |
318 | |
319 | return ParseInt64Value(key, dimension_value.get(), status, value); |
320 | } |
321 | |
// Parses a Python bytes (or, on Python 3, str) object into a StringPiece.
// NOTE: the returned StringPiece views memory owned by `py_value`; it is only
// valid while `py_value` stays alive. Returns true on success; on a
// non-string input sets INVALID_ARGUMENT and returns false.
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      tensorflow::StringPiece* value) {
  if (PyBytes_Check(py_value)) {
    Py_ssize_t size = 0;
    char* buf = nullptr;
    // Propagates the Python error without setting `status` (callers that use
    // this as a type probe rely on that).
    if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    Py_ssize_t size = 0;
    // UTF-8 buffer is cached on the unicode object; no copy is made here.
    const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
    if (buf == nullptr) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr " , key,
                                  ", got " , py_value->ob_type->tp_name)
          .c_str());
  return false;
}
347 | |
348 | bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status, |
349 | unsigned char* value) { |
350 | if (PyBool_Check(py_value)) { |
351 | *value = PyObject_IsTrue(py_value); |
352 | return true; |
353 | } |
354 | TF_SetStatus( |
355 | status, TF_INVALID_ARGUMENT, |
356 | tensorflow::strings::StrCat("Expecting bool value for attr " , key, |
357 | ", got " , py_value->ob_type->tp_name) |
358 | .c_str()); |
359 | return false; |
360 | } |
361 | |
362 | // The passed in py_value is expected to be an object of the python type |
363 | // dtypes.DType or an int. |
364 | bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, |
365 | int* value) { |
366 | if (IsInteger(py_value)) { |
367 | return ParseIntValue(key, py_value, status, value); |
368 | } |
369 | |
370 | tensorflow::Safe_PyObjectPtr py_type_enum( |
371 | PyObject_GetAttrString(py_value, "_type_enum" )); |
372 | if (py_type_enum == nullptr) { |
373 | PyErr_Clear(); |
374 | TF_SetStatus( |
375 | status, TF_INVALID_ARGUMENT, |
376 | tensorflow::strings::StrCat("Expecting a DType.dtype for attr " , key, |
377 | ", got " , py_value->ob_type->tp_name) |
378 | .c_str()); |
379 | return false; |
380 | } |
381 | |
382 | return ParseIntValue(key, py_type_enum.get(), status, value); |
383 | } |
384 | |
// Parses the Python sequence `py_list` and sets the list-valued attr `key` of
// kind `type` on `op`. On success records the list length in
// `attr_list_sizes` (when non-null) and returns true; on failure sets
// `status` and returns false.
bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
                   PyObject* py_list, TF_AttrType type,
                   tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
                   TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr " , key,
                                    ", got " , py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

// Parses each element of `py_list` into a freshly allocated `values` array of
// c_type using `parse_fn`; returns false (with `status` set) on the first
// element that fails to parse.
#define PARSE_LIST(c_type, parse_fn)                                        \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);                 \
  for (int i = 0; i < num_values; ++i) {                                    \
    tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));     \
    if (py_value == nullptr) {                                              \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                             \
                   tensorflow::strings::StrCat(                             \
                       "Expecting sequence of " #c_type " for attr ", key,  \
                       ", got ", py_list->ob_type->tp_name)                 \
                       .c_str());                                           \
      return false;                                                         \
    } else if (!parse_fn(key, py_value.get(), status, &values[i])) {        \
      return false;                                                         \
    }                                                                       \
  }

  if (type == TF_ATTR_STRING) {
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::StringPiece value;
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      // The StringPieces view buffers still owned by the sequence elements;
      // `py_list` keeps those elements alive through the call below.
      if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
      values[i] = value.data();
      lengths[i] = value.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() != Py_None) {
        if (!PySequence_Check(py_value.get())) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element" , i,
                  " of attr " , key, ", got " , py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value.get());
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() == Py_None) {
        // None means a fully-unknown shape (rank -1).
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value.get());
        if (size == -1) {
          // Unknown rank; treated the same as None.
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          tensorflow::Safe_PyObjectPtr inner_py_value(
              PySequence_ITEM(py_value.get(), j));
          if (inner_py_value.get() == Py_None) {
            // Unknown dimension within a known-rank shape.
            *offset = -1;
          } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                          offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      // Allow:
      // (1) String function name, OR
      // (2) A Python object with a .name attribute
      //     (A crude test for being a
      //     tensorflow.python.framework.function._DefinedFunction)
      //     (which is what the various "defun" or "Defun" decorators do).
      // And in the future also allow an object that can encapsulate
      // the function name and its attribute values.
      tensorflow::StringPiece func_name;
      if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
        // NOTE(review): `name_attr` is a new reference that is never
        // decref'ed; it must outlive `func_name` (which views its buffer),
        // but could be released after TFE_NewOp below -- confirm.
        PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name" );
        if (name_attr == nullptr ||
            !ParseStringValue(key, name_attr, status, &func_name)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "unable to set function value attribute from a " ,
                  py_value.get()->ob_type->tp_name,
                  " object. If you think this is an error, please file an "
                  "issue at "
                  "https://github.com/tensorflow/tensorflow/issues/new" )
                  .c_str());
          return false;
        }
      }
      // NOTE(review): the ops created here are never deleted after the attr
      // is set -- looks like an intentional (or at least long-standing) leak;
      // verify TFE_OpSetAttrFunctionList's ownership contract.
      funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
      if (!status->status.ok()) return false;
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    if (!status->status.ok()) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr " , key,
                                             " has unhandled list type " , type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}
544 | |
545 | TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, |
546 | TF_Status* status) { |
547 | TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); |
548 | for (const auto& attr : func.attr()) { |
549 | if (!status->status.ok()) return nullptr; |
550 | SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); |
551 | if (!status->status.ok()) return nullptr; |
552 | } |
553 | return func_op; |
554 | } |
555 | |
556 | void SetOpAttrListDefault( |
557 | TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, |
558 | const char* key, TF_AttrType type, |
559 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, |
560 | TF_Status* status) { |
561 | if (type == TF_ATTR_STRING) { |
562 | int num_values = attr.default_value().list().s_size(); |
563 | std::unique_ptr<const void*[]> values(new const void*[num_values]); |
564 | std::unique_ptr<size_t[]> lengths(new size_t[num_values]); |
565 | (*attr_list_sizes)[key] = num_values; |
566 | for (int i = 0; i < num_values; i++) { |
567 | const string& v = attr.default_value().list().s(i); |
568 | values[i] = v.data(); |
569 | lengths[i] = v.size(); |
570 | } |
571 | TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); |
572 | } else if (type == TF_ATTR_INT) { |
573 | int num_values = attr.default_value().list().i_size(); |
574 | std::unique_ptr<int64_t[]> values(new int64_t[num_values]); |
575 | (*attr_list_sizes)[key] = num_values; |
576 | for (int i = 0; i < num_values; i++) { |
577 | values[i] = attr.default_value().list().i(i); |
578 | } |
579 | TFE_OpSetAttrIntList(op, key, values.get(), num_values); |
580 | } else if (type == TF_ATTR_FLOAT) { |
581 | int num_values = attr.default_value().list().f_size(); |
582 | std::unique_ptr<float[]> values(new float[num_values]); |
583 | (*attr_list_sizes)[key] = num_values; |
584 | for (int i = 0; i < num_values; i++) { |
585 | values[i] = attr.default_value().list().f(i); |
586 | } |
587 | TFE_OpSetAttrFloatList(op, key, values.get(), num_values); |
588 | } else if (type == TF_ATTR_BOOL) { |
589 | int num_values = attr.default_value().list().b_size(); |
590 | std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]); |
591 | (*attr_list_sizes)[key] = num_values; |
592 | for (int i = 0; i < num_values; i++) { |
593 | values[i] = attr.default_value().list().b(i); |
594 | } |
595 | TFE_OpSetAttrBoolList(op, key, values.get(), num_values); |
596 | } else if (type == TF_ATTR_TYPE) { |
597 | int num_values = attr.default_value().list().type_size(); |
598 | std::unique_ptr<int[]> values(new int[num_values]); |
599 | (*attr_list_sizes)[key] = num_values; |
600 | for (int i = 0; i < num_values; i++) { |
601 | values[i] = attr.default_value().list().type(i); |
602 | } |
603 | TFE_OpSetAttrTypeList(op, key, |
604 | reinterpret_cast<const TF_DataType*>(values.get()), |
605 | attr.default_value().list().type_size()); |
606 | } else if (type == TF_ATTR_SHAPE) { |
607 | int num_values = attr.default_value().list().shape_size(); |
608 | (*attr_list_sizes)[key] = num_values; |
609 | int total_dims = 0; |
610 | for (int i = 0; i < num_values; ++i) { |
611 | if (!attr.default_value().list().shape(i).unknown_rank()) { |
612 | total_dims += attr.default_value().list().shape(i).dim_size(); |
613 | } |
614 | } |
615 | // Allocate a buffer that can fit all of the dims together. |
616 | std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); |
617 | // Copy the input dims into the buffer and set dims to point to |
618 | // the start of each list's dims. |
619 | std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); |
620 | std::unique_ptr<int[]> num_dims(new int[num_values]); |
621 | int64_t* offset = buffer.get(); |
622 | for (int i = 0; i < num_values; ++i) { |
623 | const auto& shape = attr.default_value().list().shape(i); |
624 | if (shape.unknown_rank()) { |
625 | dims[i] = nullptr; |
626 | num_dims[i] = -1; |
627 | } else { |
628 | for (int j = 0; j < shape.dim_size(); j++) { |
629 | *offset = shape.dim(j).size(); |
630 | ++offset; |
631 | } |
632 | } |
633 | } |
634 | TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, |
635 | status); |
636 | } else if (type == TF_ATTR_FUNC) { |
637 | int num_values = attr.default_value().list().func_size(); |
638 | (*attr_list_sizes)[key] = num_values; |
639 | std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); |
640 | for (int i = 0; i < num_values; i++) { |
641 | funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); |
642 | } |
643 | TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); |
644 | } else { |
645 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
646 | "Lists of tensors are not yet implemented for default valued " |
647 | "attributes for an operation." ); |
648 | } |
649 | } |
650 | |
// Parses `py_value` and sets the scalar attr `key` of kind `type` on `op`.
// Integer attrs are additionally recorded in `attr_list_sizes` (when
// non-null) because they may determine output list lengths. Returns true on
// success; on failure sets `status` and returns false.
bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
                     PyObject* py_value, TF_AttrType type,
                     tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
                     TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    tensorflow::StringPiece value;
    if (!ParseStringValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrString(op, key, value.data(), value.size());
  } else if (type == TF_ATTR_INT) {
    int64_t value;
    if (!ParseInt64Value(key, py_value, status, &value)) return false;
    TFE_OpSetAttrInt(op, key, value);
    // attr_list_sizes is set for all int attributes (since at this point we are
    // not aware if that attribute might be used to calculate the size of an
    // output list or not).
    if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
  } else if (type == TF_ATTR_FLOAT) {
    float value;
    if (!ParseFloatValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrFloat(op, key, value);
  } else if (type == TF_ATTR_BOOL) {
    unsigned char value;
    if (!ParseBoolValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrBool(op, key, value);
  } else if (type == TF_ATTR_TYPE) {
    int value;
    if (!ParseTypeValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
  } else if (type == TF_ATTR_SHAPE) {
    if (py_value == Py_None) {
      // None means a fully-unknown shape (rank -1).
      TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    } else {
      if (!PySequence_Check(py_value)) {
        TF_SetStatus(status, TF_INVALID_ARGUMENT,
                     tensorflow::strings::StrCat(
                         "Expecting None or sequence value for attr" , key,
                         ", got " , py_value->ob_type->tp_name)
                         .c_str());
        return false;
      }
      const auto num_dims = TensorShapeNumDims(py_value);
      if (num_dims == -1) {
        // Unknown rank; treated the same as None.
        TFE_OpSetAttrShape(op, key, nullptr, -1, status);
        return true;
      }
      std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
      for (int i = 0; i < num_dims; ++i) {
        tensorflow::Safe_PyObjectPtr inner_py_value(
            PySequence_ITEM(py_value, i));
        // If an error is generated when iterating through object, we can
        // sometimes get a nullptr.
        if (inner_py_value.get() == Py_None) {
          // Unknown dimension within a known-rank shape.
          dims[i] = -1;
        } else if (inner_py_value.get() == nullptr ||
                   !ParseDimensionValue(key, inner_py_value.get(), status,
                                        &dims[i])) {
          return false;
        }
      }
      TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    }
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    // Allow:
    // (1) String function name, OR
    // (2) A Python object with a .name attribute
    //     (A crude test for being a
    //     tensorflow.python.framework.function._DefinedFunction)
    //     (which is what the various "defun" or "Defun" decorators do).
    // And in the future also allow an object that can encapsulate
    // the function name and its attribute values.
    tensorflow::StringPiece func_name;
    if (!ParseStringValue(key, py_value, status, &func_name)) {
      // NOTE(review): `name_attr` is a new reference that is never
      // decref'ed; `func_name` views its buffer, so it must stay alive at
      // least until TFE_OpSetAttrFunctionName below -- confirm whether the
      // leak is intentional.
      PyObject* name_attr = PyObject_GetAttrString(py_value, "name" );
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr, status, &func_name)) {
        TF_SetStatus(
            status, TF_INVALID_ARGUMENT,
            tensorflow::strings::StrCat(
                "unable to set function value attribute from a " ,
                py_value->ob_type->tp_name,
                " object. If you think this is an error, please file an issue "
                "at https://github.com/tensorflow/tensorflow/issues/new" )
                .c_str());
        return false;
      }
    }
    // Clear any error recorded by the failed string probe above.
    TF_SetStatus(status, TF_OK, "" );
    TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        tensorflow::strings::StrCat("Attr " , key, " has unhandled type " , type)
            .c_str());
    return false;
  }
  return true;
}
749 | |
750 | void SetOpAttrScalarDefault( |
751 | TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, |
752 | const char* attr_name, |
753 | tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes, |
754 | TF_Status* status) { |
755 | SetOpAttrValueScalar(ctx, op, default_value, attr_name, status); |
756 | if (default_value.value_case() == tensorflow::AttrValue::kI) { |
757 | (*attr_list_sizes)[attr_name] = default_value.i(); |
758 | } |
759 | } |
760 | |
// Sets on `op` every attr found in the flat (key, value, key, value, ...)
// tuple `attrs`, skipping the first `start_index` tuple elements. A value of
// Py_None for `attrs` is a no-op. Errors are reported through `out_status`.
// start_index is the index at which the Tuple/List attrs will start getting
// processed.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  // Keys and values must come in pairs.
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length." );
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    // PyTuple_GET_ITEM returns borrowed references; no decref needed.
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    // The op's registered OpDef decides whether this attr is list-valued.
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (!out_status->status.ok()) return;
    if (is_list != 0) {
      if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}
794 | |
// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
// `attr_list_sizes` collects the lengths of list-typed attrs (and integer
// attr values) so subsequent attrs can resolve their sizes.
void SetOpAttrWithDefaults(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* attr_name, PyObject* attr_value,
    tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
    TF_Status* status) {
  unsigned char is_list = 0;
  const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
  if (!status->status.ok()) return;
  if (attr_value == Py_None) {
    // No explicit value given: use the default declared in the AttrDef.
    if (is_list != 0) {
      SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
                           status);
    } else {
      SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
                             attr_list_sizes, status);
    }
  } else {
    // Explicit value provided by the caller.
    if (is_list != 0) {
      SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                    status);
    } else {
      SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                      status);
    }
  }
}
825 | |
// Converts a C int to a new-reference Python integer object, using the
// appropriate constructor for the running Python major version.
PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_FromLong(num);
#else
  return PyInt_FromLong(num);
#endif
}
833 | |
// Python subclass of Exception that is created on not ok Status.
// Guarded by exception_class_mutex; registered via
// TFE_Py_RegisterExceptionClass.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback.
// Registered via TFE_Py_RegisterFallbackExceptionClass.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;

// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;

// Monotonic counter backing get_uid() / TFE_Py_UID().
static std::atomic<int64_t> _uid;

// This struct is responsible for marking thread_local storage as destroyed.
// Access to the `alive` field in already-destroyed ThreadLocalDestructionMarker
// is safe because it's a trivial type, so long as nobody creates a new
// thread_local in the space where now-destroyed marker used to be.
// Hopefully creating new thread_locals while destructing a thread is rare.
struct ThreadLocalDestructionMarker {
  ~ThreadLocalDestructionMarker() { alive = false; }
  bool alive = true;
};
858 | |
859 | } // namespace |
860 | |
861 | TF_Status* GetStatus() { |
862 | TF_Status* maybe_status = ReleaseThreadLocalStatus(); |
863 | if (maybe_status) { |
864 | TF_SetStatus(maybe_status, TF_OK, "" ); |
865 | return maybe_status; |
866 | } else { |
867 | return TF_NewStatus(); |
868 | } |
869 | } |
870 | |
// Returns `status` to the thread-local cache (reset to OK first) so a later
// GetStatus() call can reuse the allocation. Takes ownership of `status`.
void ReturnStatus(TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  thread_local_tf_status.reset(status);
}
875 | |
// Executes op `op_name` eagerly. Thin wrapper over TFE_Py_ExecuteCancelable
// with no cancellation manager.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
                           /*cancellation_manager=*/nullptr, outputs,
                           out_status);
}
884 | |
// Executes op `op_name` on `device_name` with the given inputs and attrs,
// writing results into `outputs` (resized to the actual output count on
// success). On failure the op name is appended to the error message in
// `out_status`. The GIL is released around the actual kernel execution.
void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                              const char* op_name,
                              TFE_InputTensorHandles* inputs, PyObject* attrs,
                              TFE_CancellationManager* cancellation_manager,
                              TFE_OutputTensorHandles* outputs,
                              TF_Status* out_status) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);

  TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);

  // Ensure the cached op is returned to the context on every exit path.
  auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
  if (!out_status->status.ok()) return;

  // Attach the current Python stack trace for better error reporting.
  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
    TFE_OpAddInput(op, inputs->at(i), out_status);
  }
  if (cancellation_manager && out_status->status.ok()) {
    TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
  }
  if (out_status->status.ok()) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  // Release the GIL while the (potentially long-running) kernel executes.
  Py_BEGIN_ALLOW_THREADS;

  int num_outputs = outputs->size();

  if (out_status->status.ok()) {
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
  }

  if (out_status->status.ok()) {
    outputs->resize(num_outputs);
  } else {
    // Append the op name to the error message for easier debugging.
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }

  Py_END_ALLOW_THREADS;
}
930 | |
931 | PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { |
932 | tensorflow::mutex_lock l(exception_class_mutex); |
933 | if (exception_class != nullptr) { |
934 | Py_DECREF(exception_class); |
935 | } |
936 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { |
937 | exception_class = nullptr; |
938 | PyErr_SetString(PyExc_TypeError, |
939 | "TFE_Py_RegisterExceptionClass: " |
940 | "Registered class should be subclass of Exception." ); |
941 | return nullptr; |
942 | } |
943 | |
944 | Py_INCREF(e); |
945 | exception_class = e; |
946 | Py_RETURN_NONE; |
947 | } |
948 | |
949 | PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { |
950 | if (fallback_exception_class != nullptr) { |
951 | Py_DECREF(fallback_exception_class); |
952 | } |
953 | if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { |
954 | fallback_exception_class = nullptr; |
955 | PyErr_SetString(PyExc_TypeError, |
956 | "TFE_Py_RegisterFallbackExceptionClass: " |
957 | "Registered class should be subclass of Exception." ); |
958 | return nullptr; |
959 | } else { |
960 | Py_INCREF(e); |
961 | fallback_exception_class = e; |
962 | Py_RETURN_NONE; |
963 | } |
964 | } |
965 | |
966 | PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) { |
967 | if (gradient_function != nullptr) { |
968 | Py_DECREF(gradient_function); |
969 | } |
970 | if (!PyCallable_Check(e)) { |
971 | gradient_function = nullptr; |
972 | PyErr_SetString(PyExc_TypeError, |
973 | "TFE_Py_RegisterGradientFunction: " |
974 | "Registered object should be function." ); |
975 | return nullptr; |
976 | } else { |
977 | Py_INCREF(e); |
978 | gradient_function = e; |
979 | Py_RETURN_NONE; |
980 | } |
981 | } |
982 | |
983 | PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) { |
984 | if (forward_gradient_function != nullptr) { |
985 | Py_DECREF(forward_gradient_function); |
986 | } |
987 | if (!PyCallable_Check(e)) { |
988 | forward_gradient_function = nullptr; |
989 | PyErr_SetString(PyExc_TypeError, |
990 | "TFE_Py_RegisterJVPFunction: " |
991 | "Registered object should be function." ); |
992 | return nullptr; |
993 | } else { |
994 | Py_INCREF(e); |
995 | forward_gradient_function = e; |
996 | Py_RETURN_NONE; |
997 | } |
998 | } |
999 | |
1000 | void RaiseFallbackException(const char* message) { |
1001 | if (fallback_exception_class != nullptr) { |
1002 | PyErr_SetString(fallback_exception_class, message); |
1003 | return; |
1004 | } |
1005 | |
1006 | PyErr_SetString( |
1007 | PyExc_RuntimeError, |
1008 | tensorflow::strings::StrCat( |
1009 | "Fallback exception type not set, attempting to fallback due to " , |
1010 | message) |
1011 | .data()); |
1012 | } |
1013 | |
1014 | // Format and return `status`' error message with the attached stack trace if |
1015 | // available. `status` must have an error. |
1016 | std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { |
1017 | tensorflow::DCheckPyGilState(); |
1018 | DCHECK(!status.ok()); |
1019 | |
1020 | std::vector<tensorflow::StackFrame> stack_trace = |
1021 | tensorflow::errors::GetStackTrace(status); |
1022 | |
1023 | if (stack_trace.empty()) return status.error_message(); |
1024 | |
1025 | PyObject* linecache = PyImport_ImportModule("linecache" ); |
1026 | PyObject* getline = |
1027 | PyObject_GetAttr(linecache, PyUnicode_FromString("getline" )); |
1028 | DCHECK(getline); |
1029 | |
1030 | std::ostringstream result; |
1031 | result << "Exception originated from\n\n" ; |
1032 | |
1033 | for (const tensorflow::StackFrame& stack_frame : stack_trace) { |
1034 | PyObject* line_str_obj = PyObject_CallFunction( |
1035 | getline, const_cast<char*>("si" ), stack_frame.file_name.c_str(), |
1036 | stack_frame.line_number); |
1037 | tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj); |
1038 | tensorflow::str_util::RemoveWhitespaceContext(&line_str); |
1039 | result << " File \"" << stack_frame.file_name << "\", line " |
1040 | << stack_frame.line_number << ", in " << stack_frame.function_name |
1041 | << '\n'; |
1042 | |
1043 | if (!line_str.empty()) result << " " << line_str << '\n'; |
1044 | Py_XDECREF(line_str_obj); |
1045 | } |
1046 | |
1047 | Py_DecRef(getline); |
1048 | Py_DecRef(linecache); |
1049 | |
1050 | result << '\n' << status.error_message(); |
1051 | return result.str(); |
1052 | } |
1053 | |
1054 | namespace tensorflow { |
1055 | |
1056 | int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { |
1057 | if (status->status.ok()) return 0; |
1058 | const char* msg = TF_Message(status); |
1059 | if (exception == nullptr) { |
1060 | tensorflow::mutex_lock l(exception_class_mutex); |
1061 | if (exception_class != nullptr) { |
1062 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); |
1063 | for (const auto& payload : |
1064 | tensorflow::errors::GetPayloads(status->status)) { |
1065 | PyDict_SetItem(payloads.get(), |
1066 | PyBytes_FromString(payload.first.c_str()), |
1067 | PyBytes_FromString(payload.second.c_str())); |
1068 | } |
1069 | tensorflow::Safe_PyObjectPtr val(Py_BuildValue( |
1070 | "siO" , FormatErrorStatusStackTrace(status->status).c_str(), |
1071 | TF_GetCode(status), payloads.get())); |
1072 | if (PyErr_Occurred()) { |
1073 | // NOTE: This hides the actual error (i.e. the reason `status` was not |
1074 | // TF_OK), but there is nothing we can do at this point since we can't |
1075 | // generate a reasonable error from the status. |
1076 | // Consider adding a message explaining this. |
1077 | return -1; |
1078 | } |
1079 | PyErr_SetObject(exception_class, val.get()); |
1080 | return -1; |
1081 | } else { |
1082 | exception = PyExc_RuntimeError; |
1083 | } |
1084 | } |
1085 | // May be update already set exception. |
1086 | PyErr_SetString(exception, msg); |
1087 | return -1; |
1088 | } |
1089 | |
1090 | } // namespace tensorflow |
1091 | |
1092 | int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, |
1093 | PyObject* exception) { |
1094 | if (status.ok()) return 0; |
1095 | const char* msg = status.error_message().c_str(); |
1096 | if (exception == nullptr) { |
1097 | tensorflow::mutex_lock l(exception_class_mutex); |
1098 | if (exception_class != nullptr) { |
1099 | tensorflow::Safe_PyObjectPtr payloads(PyDict_New()); |
1100 | for (const auto& element : tensorflow::errors::GetPayloads(status)) { |
1101 | PyDict_SetItem(payloads.get(), |
1102 | PyBytes_FromString(element.first.c_str()), |
1103 | PyBytes_FromString(element.second.c_str())); |
1104 | } |
1105 | tensorflow::Safe_PyObjectPtr val( |
1106 | Py_BuildValue("siO" , FormatErrorStatusStackTrace(status).c_str(), |
1107 | status.code(), payloads.get())); |
1108 | PyErr_SetObject(exception_class, val.get()); |
1109 | return -1; |
1110 | } else { |
1111 | exception = PyExc_RuntimeError; |
1112 | } |
1113 | } |
1114 | // May be update already set exception. |
1115 | PyErr_SetString(exception, msg); |
1116 | return -1; |
1117 | } |
1118 | |
// Returns the contents of a Python bytes object (or, on Python 3, the UTF-8
// representation of a str object). The returned buffer is owned by `o` and
// only valid while `o` is alive; returns nullptr with a Python error set on
// conversion failure.
const char* TFE_GetPythonString(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  if (PyBytes_Check(o)) {
    return PyBytes_AsString(o);
  } else {
    return PyUnicode_AsUTF8(o);
  }
#else
  return PyBytes_AsString(o);
#endif
}
1130 | |
// Returns a process-wide unique, monotonically increasing id.
int64_t get_uid() { return _uid++; }

// Python-visible wrapper around get_uid().
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1134 | |
1135 | void TFE_DeleteContextCapsule(PyObject* context) { |
1136 | TFE_Context* ctx = |
1137 | reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr)); |
1138 | auto op = ReleaseThreadLocalOp(ctx); |
1139 | op.reset(); |
1140 | TFE_DeleteContext(ctx); |
1141 | } |
1142 | |
// Converts a Python integer object to int64_t.
// NOTE(review): PyLong_AsLong returns a C `long`, which is 32 bits on LLP64
// platforms (e.g. Windows) — values above 2^31-1 would truncate there;
// confirm callers never pass larger values on such platforms.
static int64_t MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}
1150 | |
1151 | static int64_t FastTensorId(PyObject* tensor) { |
1152 | if (EagerTensor_CheckExact(tensor)) { |
1153 | return PyEagerTensor_ID(tensor); |
1154 | } |
1155 | PyObject* id_field = PyObject_GetAttrString(tensor, "_id" ); |
1156 | if (id_field == nullptr) { |
1157 | return -1; |
1158 | } |
1159 | int64_t id = MakeInt(id_field); |
1160 | Py_DECREF(id_field); |
1161 | return id; |
1162 | } |
1163 | |
namespace tensorflow {
// Returns the DataType of `tensor`: the fast path for exact EagerTensors,
// otherwise reads tensor.dtype._type_enum. Returns DT_INVALID (with a Python
// error set) when either attribute lookup fails.
DataType PyTensor_DataType(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_Dtype(tensor);
  } else {
#if PY_MAJOR_VERSION < 3
    // Python 2.x:
    // Interned attribute names are created once and reused across calls.
    static PyObject* dtype_attr = PyString_InternFromString("dtype");
    static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
#else
    // Python 3.x:
    static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
    static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
#endif
    Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
    if (!dtype_field) {
      return DT_INVALID;
    }

    Safe_PyObjectPtr enum_field(
        PyObject_GetAttr(dtype_field.get(), type_enum_attr));
    if (!enum_field) {
      return DT_INVALID;
    }

    return static_cast<DataType>(MakeInt(enum_field.get()));
  }
}
}  // namespace tensorflow
1193 | |
1194 | class PyTapeTensor { |
1195 | public: |
1196 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, |
1197 | const tensorflow::TensorShape& shape) |
1198 | : id_(id), dtype_(dtype), shape_(shape) {} |
1199 | PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape) |
1200 | : id_(id), dtype_(dtype), shape_(shape) { |
1201 | Py_INCREF(absl::get<1>(shape_)); |
1202 | } |
1203 | PyTapeTensor(const PyTapeTensor& other) { |
1204 | id_ = other.id_; |
1205 | dtype_ = other.dtype_; |
1206 | shape_ = other.shape_; |
1207 | if (shape_.index() == 1) { |
1208 | Py_INCREF(absl::get<1>(shape_)); |
1209 | } |
1210 | } |
1211 | |
1212 | ~PyTapeTensor() { |
1213 | if (shape_.index() == 1) { |
1214 | Py_DECREF(absl::get<1>(shape_)); |
1215 | } |
1216 | } |
1217 | PyObject* GetShape() const; |
1218 | PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); } |
1219 | int64_t GetID() const { return id_; } |
1220 | tensorflow::DataType GetDType() const { return dtype_; } |
1221 | |
1222 | PyObject* OnesLike() const; |
1223 | PyObject* ZerosLike() const; |
1224 | |
1225 | private: |
1226 | int64_t id_; |
1227 | tensorflow::DataType dtype_; |
1228 | |
1229 | // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that |
1230 | // PyObject is the tensor itself. This is used to support tf.shape(tensor) for |
1231 | // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype |
1232 | // tensors. |
1233 | absl::variant<tensorflow::TensorShape, PyObject*> shape_; |
1234 | }; |
1235 | |
1236 | static PyTapeTensor TapeTensorFromTensor(PyObject* tensor); |
1237 | |
1238 | class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, |
1239 | PyTapeTensor> { |
1240 | public: |
1241 | explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { |
1242 | Py_INCREF(py_vspace_); |
1243 | } |
1244 | |
1245 | tensorflow::Status Initialize() { |
1246 | num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn" ); |
1247 | if (num_elements_ == nullptr) { |
1248 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1249 | } |
1250 | aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn" ); |
1251 | if (aggregate_fn_ == nullptr) { |
1252 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1253 | } |
1254 | zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn" ); |
1255 | if (zeros_fn_ == nullptr) { |
1256 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1257 | } |
1258 | zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn" ); |
1259 | if (zeros_like_fn_ == nullptr) { |
1260 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1261 | } |
1262 | ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn" ); |
1263 | if (ones_fn_ == nullptr) { |
1264 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1265 | } |
1266 | ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn" ); |
1267 | if (ones_like_fn_ == nullptr) { |
1268 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1269 | } |
1270 | graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn" ); |
1271 | if (graph_shape_fn_ == nullptr) { |
1272 | return tensorflow::errors::InvalidArgument("invalid vspace" ); |
1273 | } |
1274 | return ::tensorflow::OkStatus(); |
1275 | } |
1276 | |
1277 | ~PyVSpace() override { |
1278 | Py_XDECREF(num_elements_); |
1279 | Py_XDECREF(aggregate_fn_); |
1280 | Py_XDECREF(zeros_fn_); |
1281 | Py_XDECREF(zeros_like_fn_); |
1282 | Py_XDECREF(ones_fn_); |
1283 | Py_XDECREF(ones_like_fn_); |
1284 | Py_XDECREF(graph_shape_fn_); |
1285 | |
1286 | Py_DECREF(py_vspace_); |
1287 | } |
1288 | |
1289 | int64_t NumElements(PyObject* tensor) const final { |
1290 | if (EagerTensor_CheckExact(tensor)) { |
1291 | return PyEagerTensor_NumElements(tensor); |
1292 | } |
1293 | PyObject* arglist = |
1294 | Py_BuildValue("(O)" , reinterpret_cast<PyObject*>(tensor)); |
1295 | PyObject* result = PyEval_CallObject(num_elements_, arglist); |
1296 | Py_DECREF(arglist); |
1297 | if (result == nullptr) { |
1298 | // The caller detects whether a python exception has been raised. |
1299 | return -1; |
1300 | } |
1301 | int64_t r = MakeInt(result); |
1302 | Py_DECREF(result); |
1303 | return r; |
1304 | } |
1305 | |
1306 | PyObject* AggregateGradients( |
1307 | tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { |
1308 | PyObject* list = PyList_New(gradient_tensors.size()); |
1309 | for (int i = 0; i < gradient_tensors.size(); ++i) { |
1310 | // Note: stealing a reference to the gradient tensors. |
1311 | CHECK(gradient_tensors[i] != nullptr); |
1312 | CHECK(gradient_tensors[i] != Py_None); |
1313 | PyList_SET_ITEM(list, i, |
1314 | reinterpret_cast<PyObject*>(gradient_tensors[i])); |
1315 | } |
1316 | PyObject* arglist = Py_BuildValue("(O)" , list); |
1317 | CHECK(arglist != nullptr); |
1318 | PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); |
1319 | Py_DECREF(arglist); |
1320 | Py_DECREF(list); |
1321 | return result; |
1322 | } |
1323 | |
1324 | int64_t TensorId(PyObject* tensor) const final { |
1325 | return FastTensorId(tensor); |
1326 | } |
1327 | |
1328 | void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } |
1329 | |
1330 | PyObject* Ones(PyObject* shape, PyObject* dtype) const { |
1331 | if (PyErr_Occurred()) { |
1332 | return nullptr; |
1333 | } |
1334 | PyObject* arg_list = Py_BuildValue("OO" , shape, dtype); |
1335 | PyObject* result = PyEval_CallObject(ones_fn_, arg_list); |
1336 | Py_DECREF(arg_list); |
1337 | return result; |
1338 | } |
1339 | |
1340 | PyObject* OnesLike(PyObject* tensor) const { |
1341 | if (PyErr_Occurred()) { |
1342 | return nullptr; |
1343 | } |
1344 | return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL); |
1345 | } |
1346 | |
1347 | // Builds a tensor filled with ones with the same shape and dtype as `t`. |
1348 | Status BuildOnesLike(const PyTapeTensor& t, |
1349 | PyObject** result) const override { |
1350 | *result = t.OnesLike(); |
1351 | return ::tensorflow::OkStatus(); |
1352 | } |
1353 | |
1354 | PyObject* Zeros(PyObject* shape, PyObject* dtype) const { |
1355 | if (PyErr_Occurred()) { |
1356 | return nullptr; |
1357 | } |
1358 | PyObject* arg_list = Py_BuildValue("OO" , shape, dtype); |
1359 | PyObject* result = PyEval_CallObject(zeros_fn_, arg_list); |
1360 | Py_DECREF(arg_list); |
1361 | return result; |
1362 | } |
1363 | |
1364 | PyObject* ZerosLike(PyObject* tensor) const { |
1365 | if (PyErr_Occurred()) { |
1366 | return nullptr; |
1367 | } |
1368 | return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL); |
1369 | } |
1370 | |
1371 | PyObject* GraphShape(PyObject* tensor) const { |
1372 | PyObject* arg_list = Py_BuildValue("(O)" , tensor); |
1373 | PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list); |
1374 | Py_DECREF(arg_list); |
1375 | return result; |
1376 | } |
1377 | |
1378 | tensorflow::Status CallBackwardFunction( |
1379 | const string& op_type, PyBackwardFunction* backward_function, |
1380 | const std::vector<int64_t>& unneeded_gradients, |
1381 | tensorflow::gtl::ArraySlice<PyObject*> output_gradients, |
1382 | absl::Span<PyObject*> result) const final { |
1383 | PyObject* grads = PyTuple_New(output_gradients.size()); |
1384 | for (int i = 0; i < output_gradients.size(); ++i) { |
1385 | if (output_gradients[i] == nullptr) { |
1386 | Py_INCREF(Py_None); |
1387 | PyTuple_SET_ITEM(grads, i, Py_None); |
1388 | } else { |
1389 | PyTuple_SET_ITEM(grads, i, |
1390 | reinterpret_cast<PyObject*>(output_gradients[i])); |
1391 | } |
1392 | } |
1393 | PyObject* py_result = (*backward_function)(grads, unneeded_gradients); |
1394 | Py_DECREF(grads); |
1395 | if (py_result == nullptr) { |
1396 | return tensorflow::errors::Internal("gradient function threw exceptions" ); |
1397 | } |
1398 | PyObject* seq = |
1399 | PySequence_Fast(py_result, "expected a sequence of gradients" ); |
1400 | if (seq == nullptr) { |
1401 | return tensorflow::errors::InvalidArgument( |
1402 | "gradient function did not return a list" ); |
1403 | } |
1404 | int len = PySequence_Fast_GET_SIZE(seq); |
1405 | if (len != result.size()) { |
1406 | return tensorflow::errors::Internal( |
1407 | "Recorded operation '" , op_type, |
1408 | "' returned too few gradients. Expected " , result.size(), |
1409 | " but received " , len); |
1410 | } |
1411 | PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
1412 | VLOG(1) << "Gradient length is " << len; |
1413 | for (int i = 0; i < len; ++i) { |
1414 | PyObject* item = seq_array[i]; |
1415 | if (item == Py_None) { |
1416 | result[i] = nullptr; |
1417 | } else { |
1418 | Py_INCREF(item); |
1419 | result[i] = item; |
1420 | } |
1421 | } |
1422 | Py_DECREF(seq); |
1423 | Py_DECREF(py_result); |
1424 | return ::tensorflow::OkStatus(); |
1425 | } |
1426 | |
1427 | void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } |
1428 | |
1429 | PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final { |
1430 | return TapeTensorFromTensor(tensor); |
1431 | } |
1432 | |
1433 | private: |
1434 | PyObject* py_vspace_; |
1435 | |
1436 | PyObject* num_elements_; |
1437 | PyObject* aggregate_fn_; |
1438 | PyObject* zeros_fn_; |
1439 | PyObject* zeros_like_fn_; |
1440 | PyObject* ones_fn_; |
1441 | PyObject* ones_like_fn_; |
1442 | PyObject* graph_shape_fn_; |
1443 | }; |
// Process-wide Python vspace bridge; owned here and replaced by
// TFE_Py_RegisterVSpace.
PyVSpace* py_vspace = nullptr;

// Forward declaration; reports whether any forward accumulator is active.
bool HasAccumulator();
1447 | |
1448 | PyObject* TFE_Py_RegisterVSpace(PyObject* e) { |
1449 | if (py_vspace != nullptr) { |
1450 | if (HasAccumulator()) { |
1451 | // Accumulators reference py_vspace, so we can't swap it out while one is |
1452 | // active. This is unlikely to ever happen. |
1453 | MaybeRaiseExceptionFromStatus( |
1454 | tensorflow::errors::Internal( |
1455 | "Can't change the vspace implementation while a " |
1456 | "forward accumulator is active." ), |
1457 | nullptr); |
1458 | } |
1459 | delete py_vspace; |
1460 | } |
1461 | |
1462 | py_vspace = new PyVSpace(e); |
1463 | auto status = py_vspace->Initialize(); |
1464 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
1465 | delete py_vspace; |
1466 | return nullptr; |
1467 | } |
1468 | |
1469 | Py_RETURN_NONE; |
1470 | } |
1471 | |
1472 | PyObject* PyTapeTensor::GetShape() const { |
1473 | if (shape_.index() == 0) { |
1474 | auto& shape = absl::get<0>(shape_); |
1475 | PyObject* py_shape = PyTuple_New(shape.dims()); |
1476 | for (int i = 0; i < shape.dims(); ++i) { |
1477 | PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); |
1478 | } |
1479 | |
1480 | return py_shape; |
1481 | } |
1482 | |
1483 | return py_vspace->GraphShape(absl::get<1>(shape_)); |
1484 | } |
1485 | |
1486 | PyObject* PyTapeTensor::OnesLike() const { |
1487 | if (shape_.index() == 1) { |
1488 | PyObject* tensor = absl::get<1>(shape_); |
1489 | return py_vspace->OnesLike(tensor); |
1490 | } |
1491 | PyObject* py_shape = GetShape(); |
1492 | PyObject* dtype_field = GetPyDType(); |
1493 | PyObject* result = py_vspace->Ones(py_shape, dtype_field); |
1494 | Py_DECREF(dtype_field); |
1495 | Py_DECREF(py_shape); |
1496 | return result; |
1497 | } |
1498 | |
1499 | PyObject* PyTapeTensor::ZerosLike() const { |
1500 | if (GetDType() == tensorflow::DT_RESOURCE) { |
1501 | // Gradient functions for ops which return resource tensors accept |
1502 | // None. This is the behavior of py_vspace->Zeros, but checking here avoids |
1503 | // issues with ZerosLike. |
1504 | Py_RETURN_NONE; |
1505 | } |
1506 | if (shape_.index() == 1) { |
1507 | PyObject* tensor = absl::get<1>(shape_); |
1508 | return py_vspace->ZerosLike(tensor); |
1509 | } |
1510 | PyObject* py_shape = GetShape(); |
1511 | PyObject* dtype_field = GetPyDType(); |
1512 | PyObject* result = py_vspace->Zeros(py_shape, dtype_field); |
1513 | Py_DECREF(dtype_field); |
1514 | Py_DECREF(py_shape); |
1515 | return result; |
1516 | } |
1517 | |
// Keeps track of all variables that have been accessed during execution.
// Holds one Python reference per watched variable, released in the
// destructor.
class VariableWatcher {
 public:
  VariableWatcher() {}

  ~VariableWatcher() {
    // Release the reference taken in WatchVariable for each watched variable.
    for (const IdAndVariable& v : watched_variables_) {
      Py_DECREF(v.variable);
    }
  }

  // Starts watching `v`, keyed by the id of its `handle` tensor. Returns the
  // id, or -1 (with a Python error set) when the handle attribute cannot be
  // read.
  int64_t WatchVariable(PyObject* v) {
    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
    if (handle == nullptr) {
      return -1;
    }
    int64_t id = FastTensorId(handle.get());

    tensorflow::mutex_lock l(watched_variables_mu_);
    auto insert_result = watched_variables_.emplace(id, v);

    if (insert_result.second) {
      // Only increment the reference count if we aren't already watching this
      // variable.
      Py_INCREF(v);
    }

    return id;
  }

  // Returns a new tuple of (new references to) all watched variables, in
  // ascending id order (the set is ordered by CompareById).
  PyObject* GetVariablesAsPyTuple() {
    tensorflow::mutex_lock l(watched_variables_mu_);
    PyObject* result = PyTuple_New(watched_variables_.size());
    Py_ssize_t pos = 0;
    for (const IdAndVariable& id_and_variable : watched_variables_) {
      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
      Py_INCREF(id_and_variable.variable);
    }
    return result;
  }

 private:
  // We store an IdAndVariable in the map since the map needs to be locked
  // during insert, but should not call back into python during insert to avoid
  // deadlocking with the GIL.
  struct IdAndVariable {
    int64_t id;
    PyObject* variable;

    IdAndVariable(int64_t id, PyObject* variable)
        : id(id), variable(variable) {}
  };
  // Orders watched variables by tensor id so iteration is deterministic.
  struct CompareById {
    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
      return lhs.id < rhs.id;
    }
  };

  tensorflow::mutex watched_variables_mu_;
  std::set<IdAndVariable, CompareById> watched_variables_
      TF_GUARDED_BY(watched_variables_mu_);
};
1580 | |
1581 | class GradientTape |
1582 | : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, |
1583 | PyTapeTensor> { |
1584 | public: |
1585 | explicit GradientTape(bool persistent, bool watch_accessed_variables) |
1586 | : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, |
1587 | PyTapeTensor>(persistent), |
1588 | watch_accessed_variables_(watch_accessed_variables) {} |
1589 | |
1590 | virtual ~GradientTape() {} |
1591 | |
1592 | void VariableAccessed(PyObject* v) { |
1593 | if (watch_accessed_variables_) { |
1594 | WatchVariable(v); |
1595 | } |
1596 | } |
1597 | |
1598 | void WatchVariable(PyObject* v) { |
1599 | int64_t id = variable_watcher_.WatchVariable(v); |
1600 | |
1601 | if (!PyErr_Occurred()) { |
1602 | this->Watch(id); |
1603 | } |
1604 | } |
1605 | |
1606 | PyObject* GetVariablesAsPyTuple() { |
1607 | return variable_watcher_.GetVariablesAsPyTuple(); |
1608 | } |
1609 | |
1610 | private: |
1611 | bool watch_accessed_variables_; |
1612 | VariableWatcher variable_watcher_; |
1613 | }; |
1614 | |
// Forward-mode autodiff accumulator specialized for Python tensors.
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
                                              PyTapeTensor>
    ForwardAccumulator;

// Incremented when a GradientTape or accumulator is newly added to a set, and
// used to enforce an ordering between them.
std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1622 | |
// Python object wrapping an owned C++ GradientTape.
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  GradientTape* tape;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_Tape;

// tp_dealloc slot: frees the owned C++ tape, then releases the Python object.
static void TFE_Py_Tape_Delete(PyObject* tape) {
  delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
  Py_TYPE(tape)->tp_free(tape);
}
1637 | |
// Python type object for TFE_Py_Tape. Only tp_dealloc is customized; all
// other slots keep the CPython defaults. The tp_print slot was removed in
// Python 3.8 and its position reused for tp_vectorcall_offset, hence the #if.
static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,               /* tp_getattr */
    nullptr,               /* tp_setattr */
    nullptr,               /* tp_reserved */
    nullptr,               /* tp_repr */
    nullptr,               /* tp_as_number */
    nullptr,               /* tp_as_sequence */
    nullptr,               /* tp_as_mapping */
    nullptr,               /* tp_hash */
    nullptr,               /* tp_call */
    nullptr,               /* tp_str */
    nullptr,               /* tp_getattro */
    nullptr,               /* tp_setattro */
    nullptr,               /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,    /* tp_flags */
    "TFE_Py_Tape objects", /* tp_doc */
};
1664 | |
// Python object wrapping a C++ ForwardAccumulator.
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  // Owning pointer; deleted in TFE_Py_ForwardAccumulatorDelete (tp_dealloc).
  ForwardAccumulator* accumulator;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_ForwardAccumulator;
1674 | |
1675 | static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) { |
1676 | delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator; |
1677 | Py_TYPE(accumulator)->tp_free(accumulator); |
1678 | } |
1679 | |
// Python type object for TFE_Py_ForwardAccumulator; only tp_dealloc is
// customized.
static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator" , /* tp_name */
    sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
    0, /* tp_itemsize */
    &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "TFE_Py_ForwardAccumulator objects" , /* tp_doc */
};
1706 | |
// Python object wrapping a C++ VariableWatcher (tracks variable accesses
// without recording gradients).
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  // Owning pointer; deleted in TFE_Py_VariableWatcher_Delete (tp_dealloc).
  VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;
1712 | |
1713 | static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) { |
1714 | delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) |
1715 | ->variable_watcher; |
1716 | Py_TYPE(variable_watcher)->tp_free(variable_watcher); |
1717 | } |
1718 | |
// Python type object for TFE_Py_VariableWatcher; only tp_dealloc is
// customized. tp_new is filled in by TFE_Py_VariableWatcherNew.
static PyTypeObject TFE_Py_VariableWatcher_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher" , /* tp_name */
    sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
    0, /* tp_itemsize */
    &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "TFE_Py_VariableWatcher objects" , /* tp_doc */
};
1745 | |
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
//
// Returns this thread's set of active tapes, creating it lazily. Returns
// nullptr when the thread is being destroyed; callers must check for that.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
  thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
      tape_set;
  // The marker's destructor flips `alive` to false, signalling that
  // thread-local destruction has started and `tape_set` may already be gone.
  thread_local ThreadLocalDestructionMarker marker;
  if (!marker.alive) {
    // This thread is being destroyed. It is unsafe to access tape_set.
    return nullptr;
  }
  if (tape_set == nullptr) {
    tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
  }
  return tape_set.get();
}
1763 | |
// Returns this thread's set of active variable watchers, creating it lazily.
// Returns nullptr when the thread is being destroyed (see GetTapeSet).
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet() {
  thread_local std::unique_ptr<
      tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
      variable_watcher_set;
  // Destruction sentinel; see the comment in GetTapeSet.
  thread_local ThreadLocalDestructionMarker marker;
  if (!marker.alive) {
    // This thread is being destroyed. It is unsafe to access
    // variable_watcher_set.
    return nullptr;
  }
  if (variable_watcher_set == nullptr) {
    variable_watcher_set.reset(
        new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
  }
  return variable_watcher_set.get();
}
1781 | |
1782 | // A linked hash set, where iteration is in insertion order. |
1783 | // |
1784 | // Nested accumulators rely on op recording happening in insertion order, so an |
1785 | // unordered data structure like CompactPointerSet is not suitable. Outer |
1786 | // accumulators need to observe operations first so they know to watch the inner |
1787 | // accumulator's jvp computation. |
1788 | // |
1789 | // Not thread safe. |
1790 | class AccumulatorSet { |
1791 | public: |
1792 | // Returns true if `element` was newly inserted, false if it already exists. |
1793 | bool insert(TFE_Py_ForwardAccumulator* element) { |
1794 | if (map_.find(element) != map_.end()) { |
1795 | return false; |
1796 | } |
1797 | ListType::iterator it = ordered_.insert(ordered_.end(), element); |
1798 | map_.insert(std::make_pair(element, it)); |
1799 | return true; |
1800 | } |
1801 | |
1802 | void erase(TFE_Py_ForwardAccumulator* element) { |
1803 | MapType::iterator existing = map_.find(element); |
1804 | if (existing == map_.end()) { |
1805 | return; |
1806 | } |
1807 | ListType::iterator list_position = existing->second; |
1808 | map_.erase(existing); |
1809 | ordered_.erase(list_position); |
1810 | } |
1811 | |
1812 | bool empty() const { return ordered_.empty(); } |
1813 | |
1814 | size_t size() const { return ordered_.size(); } |
1815 | |
1816 | private: |
1817 | typedef std::list<TFE_Py_ForwardAccumulator*> ListType; |
1818 | typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*, |
1819 | ListType::iterator> |
1820 | MapType; |
1821 | |
1822 | public: |
1823 | typedef ListType::const_iterator const_iterator; |
1824 | typedef ListType::const_reverse_iterator const_reverse_iterator; |
1825 | |
1826 | const_iterator begin() const { return ordered_.begin(); } |
1827 | const_iterator end() const { return ordered_.end(); } |
1828 | |
1829 | const_reverse_iterator rbegin() const { return ordered_.rbegin(); } |
1830 | const_reverse_iterator rend() const { return ordered_.rend(); } |
1831 | |
1832 | private: |
1833 | MapType map_; |
1834 | ListType ordered_; |
1835 | }; |
1836 | |
// Returns this thread's set of active forward accumulators, creating it
// lazily. Returns nullptr when the thread is being destroyed (see
// GetTapeSet).
AccumulatorSet* GetAccumulatorSet() {
  thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
  // Destruction sentinel; see the comment in GetTapeSet.
  thread_local ThreadLocalDestructionMarker marker;
  if (!marker.alive) {
    // This thread is being destroyed. It is unsafe to access accumulator_set.
    return nullptr;
  }
  if (accumulator_set == nullptr) {
    accumulator_set.reset(new AccumulatorSet);
  }
  return accumulator_set.get();
}
1849 | |
1850 | inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); } |
1851 | |
1852 | inline bool HasGradientTape() { return !GetTapeSet()->empty(); } |
1853 | |
1854 | inline bool HasAccumulatorOrTape() { |
1855 | return HasGradientTape() || HasAccumulator(); |
1856 | } |
1857 | |
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
//
// Holds a strong reference (Py_INCREF) to every member for the lifetime of
// the copy, so entries cannot be deallocated while iterating.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  // Non-copyable: the destructor DECREFs every member, so an implicit copy
  // would decrement each member's refcount twice.
  SafeSetCopy(const SafeSetCopy&) = delete;
  SafeSetCopy& operator=(const SafeSetCopy&) = delete;

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};
1887 | |
// Snapshot of this thread's active tape set with strong references held for
// the duration of the iteration.
class SafeTapeSet
    : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
 public:
  SafeTapeSet()
      : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
            *GetTapeSet()) {}
};
1895 | |
// Snapshot of this thread's active accumulators; additionally exposes
// reverse iteration (newest-inserted first), which callers use to visit
// inner accumulators before outer ones.
class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
 public:
  SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}

  typename AccumulatorSet::const_reverse_iterator rbegin() const {
    return set_copy_.rbegin();
  }

  typename AccumulatorSet::const_reverse_iterator rend() const {
    return set_copy_.rend();
  }
};
1908 | |
// Snapshot of this thread's active variable watchers with strong references
// held for the duration of the iteration.
class SafeVariableWatcherSet
    : public SafeSetCopy<
          tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
 public:
  SafeVariableWatcherSet()
      : SafeSetCopy<
            tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
            *GetVariableWatcherSet()) {}
};
1918 | |
// Per-thread flag recording whether tape/accumulator recording is temporarily
// disabled (e.g. inside a stop_recording section). Returned by pointer so
// callers can both read and toggle it.
bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped = false;
  return &thread_tape_is_stopped;
}
1923 | |
// Disables all tape/accumulator recording on the current thread.
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

// Re-enables recording on the current thread.
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

// Returns a Python bool: whether recording is currently stopped on this
// thread.
PyObject* TFE_Py_TapeSetIsStopped() {
  if (*ThreadTapeIsStopped()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
1934 | |
1935 | PyObject* TFE_Py_TapeSetNew(PyObject* persistent, |
1936 | PyObject* watch_accessed_variables) { |
1937 | TFE_Py_Tape_Type.tp_new = PyType_GenericNew; |
1938 | if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; |
1939 | TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); |
1940 | tape->tape = new GradientTape(persistent == Py_True, |
1941 | watch_accessed_variables == Py_True); |
1942 | Py_INCREF(tape); |
1943 | tape->nesting_id = tape_nesting_id_counter.fetch_add(1); |
1944 | GetTapeSet()->insert(tape); |
1945 | return reinterpret_cast<PyObject*>(tape); |
1946 | } |
1947 | |
1948 | void TFE_Py_TapeSetAdd(PyObject* tape) { |
1949 | Py_INCREF(tape); |
1950 | TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape); |
1951 | if (!GetTapeSet()->insert(tfe_tape).second) { |
1952 | // Already exists in the tape set. |
1953 | Py_DECREF(tape); |
1954 | } else { |
1955 | tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1); |
1956 | } |
1957 | } |
1958 | |
// Returns a Python bool: true when nothing could record on this thread —
// either recording is stopped, or no tape and no accumulator is active.
PyObject* TFE_Py_TapeSetIsEmpty() {
  if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
1965 | |
1966 | void TFE_Py_TapeSetRemove(PyObject* tape) { |
1967 | auto* stack = GetTapeSet(); |
1968 | if (stack != nullptr) { |
1969 | stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape)); |
1970 | } |
1971 | // We kept a reference to the tape in the set to ensure it wouldn't get |
1972 | // deleted under us; cleaning it up here. |
1973 | Py_DECREF(tape); |
1974 | } |
1975 | |
1976 | static std::vector<int64_t> MakeIntList(PyObject* list) { |
1977 | if (list == Py_None) { |
1978 | return {}; |
1979 | } |
1980 | PyObject* seq = PySequence_Fast(list, "expected a sequence" ); |
1981 | if (seq == nullptr) { |
1982 | return {}; |
1983 | } |
1984 | int len = PySequence_Size(list); |
1985 | PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
1986 | std::vector<int64_t> tensor_ids; |
1987 | tensor_ids.reserve(len); |
1988 | for (int i = 0; i < len; ++i) { |
1989 | PyObject* item = seq_array[i]; |
1990 | #if PY_MAJOR_VERSION >= 3 |
1991 | if (PyLong_Check(item)) { |
1992 | #else |
1993 | if (PyLong_Check(item) || PyInt_Check(item)) { |
1994 | #endif |
1995 | int64_t id = MakeInt(item); |
1996 | tensor_ids.push_back(id); |
1997 | } else { |
1998 | tensor_ids.push_back(-1); |
1999 | } |
2000 | } |
2001 | Py_DECREF(seq); |
2002 | return tensor_ids; |
2003 | } |
2004 | |
// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
                           std::vector<tensorflow::DataType>* dtypes) {
  // RAII reference to the fast-sequence view of `tensors`.
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(tensors, "expected a sequence" ));
  if (seq == nullptr) {
    return false;
  }
  int len = PySequence_Fast_GET_SIZE(seq.get());
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  tensor_ids->reserve(len);
  dtypes->reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
    tensor_ids->push_back(FastTensorId(item));
    dtypes->push_back(tensorflow::PyTensor_DataType(item));
  }
  return true;
}
2025 | |
2026 | bool TapeCouldPossiblyRecord(PyObject* tensors) { |
2027 | if (tensors == Py_None) { |
2028 | return false; |
2029 | } |
2030 | if (*ThreadTapeIsStopped()) { |
2031 | return false; |
2032 | } |
2033 | if (!HasAccumulatorOrTape()) { |
2034 | return false; |
2035 | } |
2036 | return true; |
2037 | } |
2038 | |
// Backprop recording is possible: recording not stopped and a tape is active.
bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }

// Forwardprop recording is possible: recording not stopped and an accumulator
// is active.
bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
2042 | |
// Returns a Python bool: whether any active gradient tape would record an op
// consuming `tensors`. Returns nullptr (exception set) if `tensors` cannot be
// interpreted as a sequence of tensors.
PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
    Py_RETURN_FALSE;
  }
  // TODO(apassos) consider not building a list and changing the API to check
  // each tensor individually.
  std::vector<int64_t> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }
  auto& tape_set = *GetTapeSet();
  for (TFE_Py_Tape* tape : tape_set) {
    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
      Py_RETURN_TRUE;
    }
  }

  Py_RETURN_FALSE;
}
2063 | |
// Pushes a state frame on every active forward accumulator on this thread.
// Always returns Py_None.
PyObject* TFE_Py_ForwardAccumulatorPushState() {
  auto& forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PushState();
  }
  Py_RETURN_NONE;
}

// Pops the most recent state frame from every active forward accumulator on
// this thread. Always returns Py_None.
PyObject* TFE_Py_ForwardAccumulatorPopState() {
  auto& forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PopState();
  }
  Py_RETURN_NONE;
}
2079 | |
// Classifies what kind of gradients are obtainable for `tensors`, returned as
// a Python int:
//   0 - nothing would record; no tape gradients available.
//   1 - exactly one non-persistent recorder: first-order gradients only.
//   2 - higher-order gradients possible (a persistent tape, or two or more
//       recorders watching).
// Returns nullptr (exception set) if `tensors` is not a tensor sequence.
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors)) {
    return GetPythonObjectFromInt(0);
  }
  std::vector<int64_t> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }

  // If there is a persistent tape watching, or if there are multiple tapes
  // watching, we'll return immediately indicating that higher-order tape
  // gradients are possible.
  bool some_tape_watching = false;
  if (CouldBackprop()) {
    auto& tape_set = *GetTapeSet();
    for (TFE_Py_Tape* tape : tape_set) {
      if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
        if (tape->tape->IsPersistent() || some_tape_watching) {
          // Either this is the second tape watching, or this tape is
          // persistent: higher-order gradients are possible.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (CouldForwardprop()) {
    auto& forward_accumulators = *GetAccumulatorSet();
    for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
      if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
        if (some_tape_watching) {
          // This is the second tape watching: higher-order gradients are
          // possible. Note that there's no equivalent of persistence for
          // forward-mode.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (some_tape_watching) {
    // There's exactly one non-persistent tape. The user can request first-order
    // gradients but won't be able to get higher-order tape gradients.
    return GetPythonObjectFromInt(1);
  } else {
    // There are no tapes. The user can't request tape gradients.
    return GetPythonObjectFromInt(0);
  }
}
2130 | |
// Marks `tensor` as watched by `tape`. No-op when backprop recording is
// impossible or the tensor's id cannot be fetched (exception left set).
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
  if (!CouldBackprop()) {
    return;
  }
  int64_t tensor_id = FastTensorId(tensor);
  // FastTensorId may raise; propagate by returning with the exception set.
  if (PyErr_Occurred()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
2141 | |
2142 | bool ListContainsNone(PyObject* list) { |
2143 | if (list == Py_None) return true; |
2144 | tensorflow::Safe_PyObjectPtr seq( |
2145 | PySequence_Fast(list, "expected a sequence" )); |
2146 | if (seq == nullptr) { |
2147 | return false; |
2148 | } |
2149 | |
2150 | int len = PySequence_Size(list); |
2151 | PyObject** seq_array = PySequence_Fast_ITEMS(seq.get()); |
2152 | for (int i = 0; i < len; ++i) { |
2153 | PyObject* item = seq_array[i]; |
2154 | if (item == Py_None) return true; |
2155 | } |
2156 | |
2157 | return false; |
2158 | } |
2159 | |
2160 | // As an optimization, the tape generally keeps only the shape and dtype of |
2161 | // tensors, and uses this information to generate ones/zeros tensors. However, |
2162 | // some tensors require OnesLike/ZerosLike because their gradients do not match |
2163 | // their inference shape/dtype. |
2164 | bool DTypeNeedsHandleData(tensorflow::DataType dtype) { |
2165 | return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE; |
2166 | } |
2167 | |
// Builds a PyTapeTensor (id + dtype + shape, or id + dtype + the tensor
// itself when handle data is needed) from either an EagerTensor or a generic
// Python tensor-like object. On any Python error a placeholder tensor with
// dtype 0 and scalar shape is returned, with the exception left set for the
// caller to observe.
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    // Fast path: read id/dtype/shape directly off the eager handle.
    tensorflow::ImmediateExecutionTensorHandle* handle =
        tensorflow::unwrap(EagerTensor_Handle(tensor));
    int64_t id = PyEagerTensor_ID(tensor);
    tensorflow::DataType dtype =
        static_cast<tensorflow::DataType>(handle->DataType());
    // Variant/resource gradients need the full tensor, not just shape/dtype.
    if (DTypeNeedsHandleData(dtype)) {
      return PyTapeTensor(id, dtype, tensor);
    }

    tensorflow::TensorShape tensor_shape;
    int num_dims;
    tensorflow::Status status = handle->NumDims(&num_dims);
    if (status.ok()) {
      for (int i = 0; i < num_dims; ++i) {
        int64_t dim_size;
        status = handle->Dim(i, &dim_size);
        if (!status.ok()) break;
        tensor_shape.AddDim(dim_size);
      }
    }

    if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
      // Shape lookup failed; return a placeholder with the exception set.
      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                          tensorflow::TensorShape({}));
    } else {
      return PyTapeTensor(id, dtype, tensor_shape);
    }
  }
  // Slow path: a generic tensor-like Python object; pull id, dtype, and shape
  // through Python attribute access.
  int64_t id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype" );
  PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum" );
  Py_DECREF(dtype_object);
  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
  Py_DECREF(dtype_enum);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  static char _shape_tuple[] = "_shape_tuple" ;
  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_CallMethod(tensor, _shape_tuple, nullptr));
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }

  // Unknown dimensions (None) or handle-data dtypes force keeping the tensor.
  if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
    return PyTapeTensor(id, dtype, tensor);
  }

  auto l = MakeIntList(shape_tuple.get());
  // Replace -1, which represents accidental Nones which can occur in graph mode
  // and can cause errors in shape construction with 0s.
  for (auto& c : l) {
    if (c < 0) {
      c = 0;
    }
  }
  tensorflow::TensorShape shape(l);
  return PyTapeTensor(id, dtype, shape);
}
2236 | |
// Populates output_info from output_seq, which must come from PySequence_Fast.
//
// Does not take ownership of output_seq. Returns true on success and false if a
// Python exception has been set.
bool TapeTensorsFromTensorSequence(PyObject* output_seq,
                                   std::vector<PyTapeTensor>* output_info) {
  Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
  PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
  output_info->reserve(output_len);
  for (Py_ssize_t i = 0; i < output_len; ++i) {
    output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
    // TapeTensorFromTensor leaves the exception set on failure; stop early.
    if (PyErr_Occurred() != nullptr) {
      return false;
    }
  }
  return true;
}
2254 | |
2255 | std::vector<int64_t> MakeTensorIDList(PyObject* tensors) { |
2256 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence" ); |
2257 | if (seq == nullptr) { |
2258 | return {}; |
2259 | } |
2260 | int len = PySequence_Fast_GET_SIZE(seq); |
2261 | PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
2262 | std::vector<int64_t> list; |
2263 | list.reserve(len); |
2264 | for (int i = 0; i < len; ++i) { |
2265 | PyObject* tensor = seq_array[i]; |
2266 | list.push_back(FastTensorId(tensor)); |
2267 | if (PyErr_Occurred()) { |
2268 | Py_DECREF(seq); |
2269 | return list; |
2270 | } |
2271 | } |
2272 | Py_DECREF(seq); |
2273 | return list; |
2274 | } |
2275 | |
// Notifies every active tape (via a ref-counted snapshot) that `variable` was
// accessed. No-op when backprop recording is impossible.
void TFE_Py_TapeVariableAccessed(PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->VariableAccessed(variable);
  }
}
2284 | |
// Marks `variable` as watched by `tape`; no-op when backprop recording is
// impossible on this thread.
void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}

// Returns a new Python tuple of the variables `tape` has watched.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
2295 | |
2296 | PyObject* TFE_Py_VariableWatcherNew() { |
2297 | TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew; |
2298 | if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr; |
2299 | TFE_Py_VariableWatcher* variable_watcher = |
2300 | PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type); |
2301 | variable_watcher->variable_watcher = new VariableWatcher(); |
2302 | Py_INCREF(variable_watcher); |
2303 | GetVariableWatcherSet()->insert(variable_watcher); |
2304 | return reinterpret_cast<PyObject*>(variable_watcher); |
2305 | } |
2306 | |
2307 | void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) { |
2308 | auto* stack = GetVariableWatcherSet(); |
2309 | stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)); |
2310 | // We kept a reference to the variable watcher in the set to ensure it |
2311 | // wouldn't get deleted under us; cleaning it up here. |
2312 | Py_DECREF(variable_watcher); |
2313 | } |
2314 | |
// Notifies every active variable watcher (via a ref-counted snapshot) that
// `variable` was accessed.
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
  for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
    variable_watcher->variable_watcher->WatchVariable(variable);
  }
}

// Returns a new Python tuple of the variables `variable_watcher` observed.
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
  return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
      ->variable_watcher->GetVariablesAsPyTuple();
}
2325 | |
2326 | namespace { |
// Collects the dtype of each tensor in `tensors`. Returns an empty vector
// (with a Python exception set) when `tensors` is not a sequence.
// NOTE(review): errors from PyTensor_DataType are not checked per-element
// here, unlike MakeTensorIDList — presumably callers validate afterwards;
// confirm before relying on the exception state.
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence" );
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<tensorflow::DataType> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = seq_array[i];
    list.push_back(tensorflow::PyTensor_DataType(tensor));
  }
  Py_DECREF(seq);
  return list;
}
2343 | |
2344 | PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id, |
2345 | PyObject* weak_tensor_ref) { |
2346 | auto* accumulator_set = GetAccumulatorSet(); |
2347 | if (accumulator_set != nullptr) { |
2348 | int64_t parsed_tensor_id = MakeInt(tensor_id); |
2349 | for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) { |
2350 | accumulator->accumulator->DeleteGradient(parsed_tensor_id); |
2351 | } |
2352 | } |
2353 | Py_DECREF(weak_tensor_ref); |
2354 | Py_DECREF(tensor_id); |
2355 | Py_INCREF(Py_None); |
2356 | return Py_None; |
2357 | } |
2358 | |
// Method descriptor used to wrap ForwardAccumulatorDeleteGradient as a
// Python callable (the weakref callback installed below).
static PyMethodDef forward_accumulator_delete_gradient_method_def = {
    "ForwardAccumulatorDeleteGradient" , ForwardAccumulatorDeleteGradient,
    METH_O, "ForwardAccumulatorDeleteGradient" };
2362 | |
2363 | void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) { |
2364 | tensorflow::Safe_PyObjectPtr callback( |
2365 | PyCFunction_New(&forward_accumulator_delete_gradient_method_def, |
2366 | PyLong_FromLong(tensor_id))); |
2367 | // We need to keep a reference to the weakref active if we want our callback |
2368 | // called. The callback itself now owns the weakref object and the tensor ID |
2369 | // object. |
2370 | PyWeakref_NewRef(tensor, callback.get()); |
2371 | } |
2372 | |
// Records `op_type` on every active gradient tape whose nesting id is below
// `max_gradient_tape_id`. The id bound is how outer tapes are prevented from
// watching jvps produced by inner forward accumulators (see
// TapeSetRecordForwardprop, which computes it). No-op when backprop recording
// is impossible on this thread.
void TapeSetRecordBackprop(
    const string& op_type, const std::vector<PyTapeTensor>& output_info,
    const std::vector<int64_t>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    tensorflow::uint64 max_gradient_tape_id) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    if (tape->nesting_id < max_gradient_tape_id) {
      tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
                                  backward_function_getter,
                                  backward_function_killer);
    }
  }
}
2391 | |
2392 | bool TapeSetRecordForwardprop( |
2393 | const string& op_type, PyObject* output_seq, |
2394 | const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors, |
2395 | const std::vector<int64_t>& input_ids, |
2396 | const std::vector<tensorflow::DataType>& input_dtypes, |
2397 | const std::function<PyBackwardFunction*()>& backward_function_getter, |
2398 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, |
2399 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function, |
2400 | PyObject* forwardprop_output_indices, |
2401 | tensorflow::uint64* max_gradient_tape_id) { |
2402 | *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max(); |
2403 | if (!CouldForwardprop()) { |
2404 | return true; |
2405 | } |
2406 | auto accumulator_set = SafeAccumulatorSet(); |
2407 | tensorflow::Safe_PyObjectPtr input_seq( |
2408 | PySequence_Fast(input_tensors, "expected a sequence of tensors" )); |
2409 | if (input_seq == nullptr || PyErr_Occurred()) return false; |
2410 | Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get()); |
2411 | PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq); |
2412 | for (int i = 0; i < output_info.size(); ++i) { |
2413 | RegisterForwardAccumulatorCleanup(output_seq_array[i], |
2414 | output_info[i].GetID()); |
2415 | } |
2416 | if (forwardprop_output_indices != nullptr && |
2417 | forwardprop_output_indices != Py_None) { |
2418 | tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast( |
2419 | forwardprop_output_indices, "Expected a sequence of indices" )); |
2420 | if (indices_fast == nullptr || PyErr_Occurred()) { |
2421 | return false; |
2422 | } |
2423 | if (PySequence_Fast_GET_SIZE(indices_fast.get()) != |
2424 | accumulator_set.size()) { |
2425 | MaybeRaiseExceptionFromStatus( |
2426 | tensorflow::errors::Internal( |
2427 | "Accumulators were added or removed from the active set " |
2428 | "between packing and unpacking." ), |
2429 | nullptr); |
2430 | } |
2431 | PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get()); |
2432 | Py_ssize_t accumulator_index = 0; |
2433 | for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin(); |
2434 | it != accumulator_set.rend(); ++it, ++accumulator_index) { |
2435 | tensorflow::Safe_PyObjectPtr jvp_index_seq( |
2436 | PySequence_Fast(indices_fast_array[accumulator_index], |
2437 | "Expected a sequence of jvp indices." )); |
2438 | if (jvp_index_seq == nullptr || PyErr_Occurred()) { |
2439 | return false; |
2440 | } |
2441 | Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get()); |
2442 | PyObject** jvp_index_seq_array = |
2443 | PySequence_Fast_ITEMS(jvp_index_seq.get()); |
2444 | for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) { |
2445 | PyObject* tuple = jvp_index_seq_array[jvp_index]; |
2446 | int64_t primal_tensor_id = |
2447 | output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID(); |
2448 | (*it)->accumulator->Watch( |
2449 | primal_tensor_id, |
2450 | output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]); |
2451 | } |
2452 | } |
2453 | } else { |
2454 | std::vector<PyTapeTensor> input_info; |
2455 | input_info.reserve(input_len); |
2456 | PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get()); |
2457 | for (Py_ssize_t i = 0; i < input_len; ++i) { |
2458 | input_info.push_back(TapeTensorFromTensor(input_seq_array[i])); |
2459 | } |
2460 | for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) { |
2461 | tensorflow::Status status = accumulator->accumulator->Accumulate( |
2462 | op_type, input_info, output_info, input_ids, input_dtypes, |
2463 | forward_function, backward_function_getter, backward_function_killer); |
2464 | if (PyErr_Occurred()) return false; // Don't swallow Python exceptions. |
2465 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
2466 | return false; |
2467 | } |
2468 | if (accumulator->accumulator->BusyAccumulating()) { |
2469 | // Ensure inner accumulators don't see outer accumulators' jvps. This |
2470 | // mostly happens on its own, with some potentially surprising |
2471 | // exceptions, so the blanket policy is for consistency. |
2472 | *max_gradient_tape_id = accumulator->nesting_id; |
2473 | break; |
2474 | } |
2475 | } |
2476 | } |
2477 | return true; |
2478 | } |
2479 | |
2480 | PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) { |
2481 | PyObject* py_input_tangents = PyTuple_New(input_tangents.size()); |
2482 | for (int i = 0; i < input_tangents.size(); ++i) { |
2483 | PyObject* element; |
2484 | if (input_tangents[i] == nullptr) { |
2485 | element = Py_None; |
2486 | } else { |
2487 | element = input_tangents[i]; |
2488 | } |
2489 | Py_INCREF(element); |
2490 | PyTuple_SET_ITEM(py_input_tangents, i, element); |
2491 | } |
2492 | return py_input_tangents; |
2493 | } |
2494 | |
2495 | tensorflow::Status ParseTangentOutputs( |
2496 | PyObject* user_output, std::vector<PyObject*>* output_tangents) { |
2497 | if (user_output == Py_None) { |
2498 | // No connected gradients. |
2499 | return ::tensorflow::OkStatus(); |
2500 | } |
2501 | tensorflow::Safe_PyObjectPtr fast_result( |
2502 | PySequence_Fast(user_output, "expected a sequence of forward gradients" )); |
2503 | if (fast_result == nullptr) { |
2504 | return tensorflow::errors::InvalidArgument( |
2505 | "forward gradient function did not return a sequence." ); |
2506 | } |
2507 | int len = PySequence_Fast_GET_SIZE(fast_result.get()); |
2508 | PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get()); |
2509 | output_tangents->reserve(len); |
2510 | for (int i = 0; i < len; ++i) { |
2511 | PyObject* item = fast_result_array[i]; |
2512 | if (item == Py_None) { |
2513 | output_tangents->push_back(nullptr); |
2514 | } else { |
2515 | Py_INCREF(item); |
2516 | output_tangents->push_back(item); |
2517 | } |
2518 | } |
2519 | return ::tensorflow::OkStatus(); |
2520 | } |
2521 | |
2522 | // Calls the registered forward_gradient_function, computing `output_tangents` |
2523 | // from `input_tangents`. `output_tangents` must not be null. |
2524 | // |
2525 | // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which |
2526 | // the forward function is being called. |
2527 | tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, |
2528 | PyObject* inputs, PyObject* results, |
2529 | const std::vector<PyObject*>& input_tangents, |
2530 | std::vector<PyObject*>* output_tangents, |
2531 | bool use_batch) { |
2532 | if (forward_gradient_function == nullptr) { |
2533 | return tensorflow::errors::Internal( |
2534 | "No forward gradient function registered." ); |
2535 | } |
2536 | tensorflow::Safe_PyObjectPtr py_input_tangents( |
2537 | TangentsAsPyTuple(input_tangents)); |
2538 | |
2539 | // Normalize the input sequence to a tuple so it works with function |
2540 | // caching; otherwise it may be an opaque _InputList object. |
2541 | tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs)); |
2542 | PyObject* to_batch = (use_batch) ? Py_True : Py_False; |
2543 | tensorflow::Safe_PyObjectPtr callback_args( |
2544 | Py_BuildValue("OOOOOO" , op_name, attrs, input_tuple.get(), results, |
2545 | py_input_tangents.get(), to_batch)); |
2546 | tensorflow::Safe_PyObjectPtr py_result( |
2547 | PyObject_CallObject(forward_gradient_function, callback_args.get())); |
2548 | if (py_result == nullptr || PyErr_Occurred()) { |
2549 | return tensorflow::errors::Internal( |
2550 | "forward gradient function threw exceptions" ); |
2551 | } |
2552 | return ParseTangentOutputs(py_result.get(), output_tangents); |
2553 | } |
2554 | |
2555 | // Like CallJVPFunction, but calls a pre-bound forward function. |
2556 | // These are passed in from a record_gradient argument. |
2557 | tensorflow::Status CallOpSpecificJVPFunction( |
2558 | PyObject* op_specific_forward_function, |
2559 | const std::vector<PyObject*>& input_tangents, |
2560 | std::vector<PyObject*>* output_tangents) { |
2561 | tensorflow::Safe_PyObjectPtr py_input_tangents( |
2562 | TangentsAsPyTuple(input_tangents)); |
2563 | |
2564 | tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject( |
2565 | op_specific_forward_function, py_input_tangents.get())); |
2566 | if (py_result == nullptr || PyErr_Occurred()) { |
2567 | return tensorflow::errors::Internal( |
2568 | "forward gradient function threw exceptions" ); |
2569 | } |
2570 | return ParseTangentOutputs(py_result.get(), output_tangents); |
2571 | } |
2572 | |
2573 | bool ParseOpTypeString(PyObject* op_type, string* op_type_string) { |
2574 | if (PyBytes_Check(op_type)) { |
2575 | *op_type_string = PyBytes_AsString(op_type); |
2576 | } else if (PyUnicode_Check(op_type)) { |
2577 | #if PY_MAJOR_VERSION >= 3 |
2578 | *op_type_string = PyUnicode_AsUTF8(op_type); |
2579 | #else |
2580 | PyObject* py_str = PyUnicode_AsUTF8String(op_type); |
2581 | if (py_str == nullptr) { |
2582 | return false; |
2583 | } |
2584 | *op_type_string = PyBytes_AS_STRING(py_str); |
2585 | Py_DECREF(py_str); |
2586 | #endif |
2587 | } else { |
2588 | PyErr_SetString(PyExc_RuntimeError, "op_type should be a string." ); |
2589 | return false; |
2590 | } |
2591 | return true; |
2592 | } |
2593 | |
2594 | bool TapeSetRecordOperation( |
2595 | PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors, |
2596 | const std::vector<int64_t>& input_ids, |
2597 | const std::vector<tensorflow::DataType>& input_dtypes, |
2598 | const std::function<PyBackwardFunction*()>& backward_function_getter, |
2599 | const std::function<void(PyBackwardFunction*)>& backward_function_killer, |
2600 | const tensorflow::eager::ForwardFunction<PyObject>* forward_function) { |
2601 | std::vector<PyTapeTensor> output_info; |
2602 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
2603 | output_tensors, "expected a sequence of integer tensor ids" )); |
2604 | if (PyErr_Occurred() || |
2605 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
2606 | return false; |
2607 | } |
2608 | string op_type_str; |
2609 | if (!ParseOpTypeString(op_type, &op_type_str)) { |
2610 | return false; |
2611 | } |
2612 | tensorflow::uint64 max_gradient_tape_id; |
2613 | if (!TapeSetRecordForwardprop( |
2614 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, |
2615 | input_dtypes, backward_function_getter, backward_function_killer, |
2616 | forward_function, nullptr /* No special-cased jvps. */, |
2617 | &max_gradient_tape_id)) { |
2618 | return false; |
2619 | } |
2620 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, |
2621 | backward_function_getter, backward_function_killer, |
2622 | max_gradient_tape_id); |
2623 | return true; |
2624 | } |
2625 | } // namespace |
2626 | |
2627 | PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type, |
2628 | PyObject* output_tensors, |
2629 | PyObject* input_tensors, |
2630 | PyObject* backward_function, |
2631 | PyObject* forward_function) { |
2632 | if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) { |
2633 | Py_RETURN_NONE; |
2634 | } |
2635 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); |
2636 | if (PyErr_Occurred()) return nullptr; |
2637 | |
2638 | std::vector<tensorflow::DataType> input_dtypes = |
2639 | MakeTensorDtypeList(input_tensors); |
2640 | if (PyErr_Occurred()) return nullptr; |
2641 | |
2642 | std::function<PyBackwardFunction*()> backward_function_getter( |
2643 | [backward_function]() { |
2644 | Py_INCREF(backward_function); |
2645 | PyBackwardFunction* function = new PyBackwardFunction( |
2646 | [backward_function](PyObject* out_grads, |
2647 | const std::vector<int64_t>& unused) { |
2648 | return PyObject_CallObject(backward_function, out_grads); |
2649 | }); |
2650 | return function; |
2651 | }); |
2652 | std::function<void(PyBackwardFunction*)> backward_function_killer( |
2653 | [backward_function](PyBackwardFunction* py_backward_function) { |
2654 | Py_DECREF(backward_function); |
2655 | delete py_backward_function; |
2656 | }); |
2657 | |
2658 | if (forward_function == Py_None) { |
2659 | if (!TapeSetRecordOperation( |
2660 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, |
2661 | backward_function_getter, backward_function_killer, |
2662 | nullptr /* No special-cased forward function */)) { |
2663 | return nullptr; |
2664 | } |
2665 | } else { |
2666 | tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function( |
2667 | [forward_function](const std::vector<PyObject*>& input_tangents, |
2668 | std::vector<PyObject*>* output_tangents, |
2669 | bool use_batch = false) { |
2670 | return CallOpSpecificJVPFunction(forward_function, input_tangents, |
2671 | output_tangents); |
2672 | }); |
2673 | if (!TapeSetRecordOperation( |
2674 | op_type, input_tensors, output_tensors, input_ids, input_dtypes, |
2675 | backward_function_getter, backward_function_killer, |
2676 | &wrapped_forward_function)) { |
2677 | return nullptr; |
2678 | } |
2679 | } |
2680 | Py_RETURN_NONE; |
2681 | } |
2682 | |
2683 | PyObject* TFE_Py_TapeSetRecordOperationForwardprop( |
2684 | PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, |
2685 | PyObject* backward_function, PyObject* forwardprop_output_indices) { |
2686 | if (!HasAccumulator() || *ThreadTapeIsStopped()) { |
2687 | Py_RETURN_NONE; |
2688 | } |
2689 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); |
2690 | if (PyErr_Occurred()) return nullptr; |
2691 | |
2692 | std::vector<tensorflow::DataType> input_dtypes = |
2693 | MakeTensorDtypeList(input_tensors); |
2694 | if (PyErr_Occurred()) return nullptr; |
2695 | |
2696 | std::function<PyBackwardFunction*()> backward_function_getter( |
2697 | [backward_function]() { |
2698 | Py_INCREF(backward_function); |
2699 | PyBackwardFunction* function = new PyBackwardFunction( |
2700 | [backward_function](PyObject* out_grads, |
2701 | const std::vector<int64_t>& unused) { |
2702 | return PyObject_CallObject(backward_function, out_grads); |
2703 | }); |
2704 | return function; |
2705 | }); |
2706 | std::function<void(PyBackwardFunction*)> backward_function_killer( |
2707 | [backward_function](PyBackwardFunction* py_backward_function) { |
2708 | Py_DECREF(backward_function); |
2709 | delete py_backward_function; |
2710 | }); |
2711 | std::vector<PyTapeTensor> output_info; |
2712 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
2713 | output_tensors, "expected a sequence of integer tensor ids" )); |
2714 | if (PyErr_Occurred() || |
2715 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
2716 | return nullptr; |
2717 | } |
2718 | string op_type_str; |
2719 | if (!ParseOpTypeString(op_type, &op_type_str)) { |
2720 | return nullptr; |
2721 | } |
2722 | tensorflow::uint64 max_gradient_tape_id; |
2723 | if (!TapeSetRecordForwardprop( |
2724 | op_type_str, output_seq.get(), output_info, input_tensors, input_ids, |
2725 | input_dtypes, backward_function_getter, backward_function_killer, |
2726 | nullptr /* no special-cased forward function */, |
2727 | forwardprop_output_indices, &max_gradient_tape_id)) { |
2728 | return nullptr; |
2729 | } |
2730 | Py_RETURN_NONE; |
2731 | } |
2732 | |
2733 | PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type, |
2734 | PyObject* output_tensors, |
2735 | PyObject* input_tensors, |
2736 | PyObject* backward_function) { |
2737 | if (!CouldBackprop()) { |
2738 | Py_RETURN_NONE; |
2739 | } |
2740 | std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors); |
2741 | if (PyErr_Occurred()) return nullptr; |
2742 | |
2743 | std::vector<tensorflow::DataType> input_dtypes = |
2744 | MakeTensorDtypeList(input_tensors); |
2745 | if (PyErr_Occurred()) return nullptr; |
2746 | |
2747 | std::function<PyBackwardFunction*()> backward_function_getter( |
2748 | [backward_function]() { |
2749 | Py_INCREF(backward_function); |
2750 | PyBackwardFunction* function = new PyBackwardFunction( |
2751 | [backward_function](PyObject* out_grads, |
2752 | const std::vector<int64_t>& unused) { |
2753 | return PyObject_CallObject(backward_function, out_grads); |
2754 | }); |
2755 | return function; |
2756 | }); |
2757 | std::function<void(PyBackwardFunction*)> backward_function_killer( |
2758 | [backward_function](PyBackwardFunction* py_backward_function) { |
2759 | Py_DECREF(backward_function); |
2760 | delete py_backward_function; |
2761 | }); |
2762 | std::vector<PyTapeTensor> output_info; |
2763 | tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
2764 | output_tensors, "expected a sequence of integer tensor ids" )); |
2765 | if (PyErr_Occurred() || |
2766 | !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
2767 | return nullptr; |
2768 | } |
2769 | string op_type_str; |
2770 | if (!ParseOpTypeString(op_type, &op_type_str)) { |
2771 | return nullptr; |
2772 | } |
2773 | TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, |
2774 | backward_function_getter, backward_function_killer, |
2775 | // No filtering based on relative ordering with forward |
2776 | // accumulators. |
2777 | std::numeric_limits<tensorflow::uint64>::max()); |
2778 | Py_RETURN_NONE; |
2779 | } |
2780 | |
2781 | void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) { |
2782 | auto* tape_set = GetTapeSet(); |
2783 | if (tape_set == nullptr) { |
2784 | // Current thread is being destructed, and the tape set has already |
2785 | // been cleared. |
2786 | return; |
2787 | } |
2788 | for (TFE_Py_Tape* tape : *tape_set) { |
2789 | tape->tape->DeleteTrace(tensor_id); |
2790 | } |
2791 | } |
2792 | |
2793 | std::vector<PyObject*> MakeTensorList(PyObject* tensors) { |
2794 | PyObject* seq = PySequence_Fast(tensors, "expected a sequence" ); |
2795 | if (seq == nullptr) { |
2796 | return {}; |
2797 | } |
2798 | int len = PySequence_Fast_GET_SIZE(seq); |
2799 | PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
2800 | std::vector<PyObject*> list(seq_array, seq_array + len); |
2801 | Py_DECREF(seq); |
2802 | return list; |
2803 | } |
2804 | |
// Computes gradients of the `target` tensors with respect to `sources` using
// the trace recorded on `tape`. `output_gradients` (or Py_None) seeds the
// backward pass; `sources_raw` carries the original source objects so zeros
// can be constructed when `unconnected_gradients` is "zero". Returns a new
// Python list of gradients (entries may be None), or nullptr with a Python
// exception set on error.
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
                              PyObject* sources, PyObject* output_gradients,
                              PyObject* sources_raw,
                              PyObject* unconnected_gradients,
                              TF_Status* status) {
  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
  if (!tape_obj->tape->IsPersistent()) {
    // A non-persistent tape may not be queried while it is still in the
    // active set, i.e. still recording.
    auto* tape_set = GetTapeSet();
    if (tape_set->find(tape_obj) != tape_set->end()) {
      PyErr_SetString(PyExc_RuntimeError,
                      "gradient() cannot be invoked within the "
                      "GradientTape context (i.e., while operations are being "
                      "recorded). Either move the call to gradient() to be "
                      "outside the 'with tf.GradientTape' block, or "
                      "use a persistent tape: "
                      "'with tf.GradientTape(persistent=true)'");
      return nullptr;
    }
  }

  std::vector<int64_t> target_vec = MakeTensorIDList(target);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
                                                sources_vec.end());

  // Collect targets that are also sources; ComputeGradient receives them
  // separately so such tensors can be handled specially.
  tensorflow::Safe_PyObjectPtr seq =
      tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
  int len = PySequence_Fast_GET_SIZE(seq.get());
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
  for (int i = 0; i < len; ++i) {
    int64_t target_id = target_vec[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = seq_array[i];
      source_tensors_that_are_targets.insert(
          std::make_pair(target_id, TapeTensorFromTensor(tensor)));
    }
    if (PyErr_Occurred()) {
      return nullptr;
    }
  }
  if (PyErr_Occurred()) {
    return nullptr;
  }

  std::vector<PyObject*> outgrad_vec;
  if (output_gradients != Py_None) {
    outgrad_vec = MakeTensorList(output_gradients);
    if (PyErr_Occurred()) {
      return nullptr;
    }
    for (PyObject* tensor : outgrad_vec) {
      // Calling the backward function will eat a reference to the tensors in
      // outgrad_vec, so we need to increase their reference count.
      Py_INCREF(tensor);
    }
  }
  // One result slot per source; slots left null by ComputeGradient mean the
  // source was not connected to any target.
  std::vector<PyObject*> result(sources_vec.size());
  status->status = tape_obj->tape->ComputeGradient(
      *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
      outgrad_vec, absl::MakeSpan(result));
  if (!status->status.ok()) {
    if (PyErr_Occurred()) {
      // Do not propagate the erroneous status as that would swallow the
      // exception which caused the problem.
      status->status = ::tensorflow::OkStatus();
    }
    return nullptr;
  }

  bool unconnected_gradients_zero =
      strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
  std::vector<PyObject*> sources_obj;
  if (unconnected_gradients_zero) {
    // Uses the "raw" sources here so it can properly make a zeros tensor even
    // if there are resource variables as sources.
    sources_obj = MakeTensorList(sources_raw);
  }

  if (!result.empty()) {
    PyObject* py_result = PyList_New(result.size());
    tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
    for (int i = 0; i < result.size(); ++i) {
      if (result[i] == nullptr) {
        if (unconnected_gradients_zero) {
          // generate a zeros tensor in the shape of sources[i]
          tensorflow::DataType dtype =
              tensorflow::PyTensor_DataType(sources_obj[i]);
          PyTapeTensor tensor =
              PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
          result[i] = tensor.ZerosLike();
        } else {
          Py_INCREF(Py_None);
          result[i] = Py_None;
        }
      } else if (seen_results.find(result[i]) != seen_results.end()) {
        // The same gradient object appears more than once in `result`; the
        // list stores each occurrence, so repeats need an extra reference.
        Py_INCREF(result[i]);
      }
      seen_results.insert(result[i]);
      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
    }
    return py_result;
  }
  return PyList_New(0);
}
2916 | |
2917 | PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) { |
2918 | TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew; |
2919 | if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr; |
2920 | TFE_Py_ForwardAccumulator* accumulator = |
2921 | PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type); |
2922 | if (py_vspace == nullptr) { |
2923 | MaybeRaiseExceptionFromStatus( |
2924 | tensorflow::errors::Internal( |
2925 | "ForwardAccumulator requires a PyVSpace to be registered." ), |
2926 | nullptr); |
2927 | } |
2928 | accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch); |
2929 | return reinterpret_cast<PyObject*>(accumulator); |
2930 | } |
2931 | |
2932 | PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) { |
2933 | TFE_Py_ForwardAccumulator* c_accumulator( |
2934 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); |
2935 | c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1); |
2936 | if (GetAccumulatorSet()->insert(c_accumulator)) { |
2937 | Py_INCREF(accumulator); |
2938 | Py_RETURN_NONE; |
2939 | } else { |
2940 | MaybeRaiseExceptionFromStatus( |
2941 | tensorflow::errors::Internal( |
2942 | "A ForwardAccumulator was added to the active set twice." ), |
2943 | nullptr); |
2944 | return nullptr; |
2945 | } |
2946 | } |
2947 | |
2948 | void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) { |
2949 | auto* accumulator_set = GetAccumulatorSet(); |
2950 | if (accumulator_set != nullptr) { |
2951 | accumulator_set->erase( |
2952 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); |
2953 | } |
2954 | Py_DECREF(accumulator); |
2955 | } |
2956 | |
2957 | void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor, |
2958 | PyObject* tangent) { |
2959 | int64_t tensor_id = FastTensorId(tensor); |
2960 | reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) |
2961 | ->accumulator->Watch(tensor_id, tangent); |
2962 | RegisterForwardAccumulatorCleanup(tensor, tensor_id); |
2963 | } |
2964 | |
2965 | // Returns a new reference to the JVP Tensor. |
2966 | PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, |
2967 | PyObject* tensor) { |
2968 | PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) |
2969 | ->accumulator->FetchJVP(FastTensorId(tensor)); |
2970 | if (jvp == nullptr) { |
2971 | jvp = Py_None; |
2972 | } |
2973 | Py_INCREF(jvp); |
2974 | return jvp; |
2975 | } |
2976 | |
// Fetches JVPs ("tangents") for `tensors` from every active forward
// accumulator, innermost first. Returns a 2-tuple: a tuple with one entry per
// accumulator of (primal index, jvp index) pairs, and a list of the fetched
// JVP tensors (new references). JVPs fetched by inner accumulators are
// appended to the id list so outer accumulators can find higher-order JVPs.
PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors)) {
    // Nothing is recording: return empty indices and an empty tensor list.
    tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
    tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
    return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
  }
  auto& accumulators = *GetAccumulatorSet();
  tensorflow::Safe_PyObjectPtr tensors_fast(
      PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
  if (tensors_fast == nullptr || PyErr_Occurred()) {
    return nullptr;
  }
  // Ids of the non-None inputs; ids of fetched JVPs are appended below.
  std::vector<int64_t> augmented_input_ids;
  int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
  PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
  for (Py_ssize_t position = 0; position < len; ++position) {
    PyObject* input = tensors_fast_array[position];
    if (input == Py_None) {
      continue;
    }
    tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
    if (input_dtype == tensorflow::DT_INVALID) {
      return nullptr;
    }
    augmented_input_ids.push_back(FastTensorId(input));
  }
  if (PyErr_Occurred()) {
    return nullptr;
  }
  // Find the innermost accumulator such that all outer accumulators are
  // recording. Any more deeply nested accumulators will not have their JVPs
  // saved.
  AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
  for (; innermost_all_recording != accumulators.end();
       ++innermost_all_recording) {
    if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
      break;
    }
  }
  AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
      innermost_all_recording);

  bool saving_jvps = false;
  tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
  std::vector<PyObject*> new_tensors;
  Py_ssize_t accumulator_index = 0;
  // Start with the innermost accumulators to give outer accumulators a chance
  // to find their higher-order JVPs.
  for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
       it != accumulators.rend(); ++it, ++accumulator_index) {
    std::vector<int64_t> new_input_ids;
    std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
    if (it == reverse_innermost_all_recording) {
      // From this accumulator outward, JVPs are saved.
      saving_jvps = true;
    }
    if (saving_jvps) {
      for (int input_index = 0; input_index < augmented_input_ids.size();
           ++input_index) {
        int64_t existing_input = augmented_input_ids[input_index];
        PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
        if (jvp != nullptr) {
          new_tensors.push_back(jvp);
          new_input_ids.push_back(FastTensorId(jvp));
          // Pair (index of primal, index of its jvp); the jvp index counts
          // past the end of the current augmented input list.
          accumulator_indices.emplace_back(
              input_index,
              augmented_input_ids.size() + new_input_ids.size() - 1);
        }
      }
    }
    // Convert this accumulator's (from, to) pairs into a Python tuple and
    // store it in the result slot for this accumulator.
    tensorflow::Safe_PyObjectPtr accumulator_indices_py(
        PyTuple_New(accumulator_indices.size()));
    for (int i = 0; i < accumulator_indices.size(); ++i) {
      tensorflow::Safe_PyObjectPtr from_index(
          GetPythonObjectFromInt(accumulator_indices[i].first));
      tensorflow::Safe_PyObjectPtr to_index(
          GetPythonObjectFromInt(accumulator_indices[i].second));
      PyTuple_SetItem(accumulator_indices_py.get(), i,
                      PyTuple_Pack(2, from_index.get(), to_index.get()));
    }
    PyTuple_SetItem(all_indices.get(), accumulator_index,
                    accumulator_indices_py.release());
    // JVPs fetched here become candidate inputs for outer accumulators.
    augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
                               new_input_ids.end());
  }

  // Build the Python list of fetched JVPs (one new reference per element).
  tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
  for (int i = 0; i < new_tensors.size(); ++i) {
    PyObject* jvp = new_tensors[i];
    Py_INCREF(jvp);
    PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
  }
  return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
}
3070 | |
3071 | namespace { |
3072 | |
// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,      // args[0]: eager context
  FAST_PATH_EXECUTE_ARG_OP_NAME = 1,      // args[1]: op name
  FAST_PATH_EXECUTE_ARG_NAME = 2,         // args[2]: name argument
  FAST_PATH_EXECUTE_ARG_INPUT_START = 3   // args[3:]: op inputs start here
};
3080 | |
3081 | PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) { |
3082 | #if PY_MAJOR_VERSION >= 3 |
3083 | return PyUnicode_FromStringAndSize(s.data(), s.size()); |
3084 | #else |
3085 | return PyBytes_FromStringAndSize(s.data(), s.size()); |
3086 | #endif |
3087 | } |
3088 | |
3089 | bool CheckResourceVariable(PyObject* item) { |
3090 | if (tensorflow::swig::IsResourceVariable(item)) { |
3091 | tensorflow::Safe_PyObjectPtr handle( |
3092 | PyObject_GetAttrString(item, "_handle" )); |
3093 | return EagerTensor_CheckExact(handle.get()); |
3094 | } |
3095 | |
3096 | return false; |
3097 | } |
3098 | |
3099 | bool IsNumberType(PyObject* item) { |
3100 | #if PY_MAJOR_VERSION >= 3 |
3101 | return PyFloat_Check(item) || PyLong_Check(item); |
3102 | #else |
3103 | return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item); |
3104 | #endif |
3105 | } |
3106 | |
3107 | bool CheckOneInput(PyObject* item) { |
3108 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) || |
3109 | PyArray_Check(item) || IsNumberType(item)) { |
3110 | return true; |
3111 | } |
3112 | |
3113 | // Sequences are not properly handled. Sequences with purely python numeric |
3114 | // types work, but sequences with mixes of EagerTensors and python numeric |
3115 | // types don't work. |
3116 | // TODO(nareshmodi): fix |
3117 | return false; |
3118 | } |
3119 | |
3120 | bool CheckInputsOk(PyObject* seq, int start_index, |
3121 | const tensorflow::OpDef& op_def) { |
3122 | for (int i = 0; i < op_def.input_arg_size(); i++) { |
3123 | PyObject* item = PyTuple_GET_ITEM(seq, i + start_index); |
3124 | if (!op_def.input_arg(i).number_attr().empty() || |
3125 | !op_def.input_arg(i).type_list_attr().empty()) { |
3126 | // This item should be a seq input. |
3127 | if (!PySequence_Check(item)) { |
3128 | VLOG(1) << "Falling back to slow path for Op \"" << op_def.name() |
3129 | << "\", Input \"" << op_def.input_arg(i).name() |
3130 | << "\" since we expected a sequence, but got " |
3131 | << item->ob_type->tp_name; |
3132 | return false; |
3133 | } |
3134 | tensorflow::Safe_PyObjectPtr fast_item( |
3135 | PySequence_Fast(item, "Could not parse sequence." )); |
3136 | if (fast_item.get() == nullptr) { |
3137 | return false; |
3138 | } |
3139 | int len = PySequence_Fast_GET_SIZE(fast_item.get()); |
3140 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get()); |
3141 | for (Py_ssize_t j = 0; j < len; j++) { |
3142 | PyObject* inner_item = fast_item_array[j]; |
3143 | if (!CheckOneInput(inner_item)) { |
3144 | VLOG(1) << "Falling back to slow path for Op \"" << op_def.name() |
3145 | << "\", Input \"" << op_def.input_arg(i).name() |
3146 | << "\", Index " << j |
3147 | << " since we expected an EagerTensor/ResourceVariable, " |
3148 | "but got " |
3149 | << inner_item->ob_type->tp_name; |
3150 | return false; |
3151 | } |
3152 | } |
3153 | } else if (!CheckOneInput(item)) { |
3154 | VLOG(1) |
3155 | << "Falling back to slow path for Op \"" << op_def.name() |
3156 | << "\", Input \"" << op_def.input_arg(i).name() |
3157 | << "\" since we expected an EagerTensor/ResourceVariable, but got " |
3158 | << item->ob_type->tp_name; |
3159 | return false; |
3160 | } |
3161 | } |
3162 | |
3163 | return true; |
3164 | } |
3165 | |
3166 | tensorflow::DataType MaybeGetDType(PyObject* item) { |
3167 | if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) { |
3168 | return tensorflow::PyTensor_DataType(item); |
3169 | } |
3170 | |
3171 | return tensorflow::DT_INVALID; |
3172 | } |
3173 | |
3174 | tensorflow::DataType MaybeGetDTypeForAttr(const string& attr, |
3175 | FastPathOpExecInfo* op_exec_info) { |
3176 | auto cached_it = op_exec_info->cached_dtypes.find(attr); |
3177 | if (cached_it != op_exec_info->cached_dtypes.end()) { |
3178 | return cached_it->second; |
3179 | } |
3180 | |
3181 | auto it = op_exec_info->attr_to_inputs_map->find(attr); |
3182 | if (it == op_exec_info->attr_to_inputs_map->end()) { |
3183 | // No other inputs - this should never happen. |
3184 | return tensorflow::DT_INVALID; |
3185 | } |
3186 | |
3187 | for (const auto& input_info : it->second) { |
3188 | PyObject* item = PyTuple_GET_ITEM( |
3189 | op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i); |
3190 | if (input_info.is_list) { |
3191 | tensorflow::Safe_PyObjectPtr fast_item( |
3192 | PySequence_Fast(item, "Unable to allocate" )); |
3193 | int len = PySequence_Fast_GET_SIZE(fast_item.get()); |
3194 | PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get()); |
3195 | for (int i = 0; i < len; i++) { |
3196 | auto dtype = MaybeGetDType(fast_item_array[i]); |
3197 | if (dtype != tensorflow::DT_INVALID) return dtype; |
3198 | } |
3199 | } else { |
3200 | auto dtype = MaybeGetDType(item); |
3201 | if (dtype != tensorflow::DT_INVALID) return dtype; |
3202 | } |
3203 | } |
3204 | |
3205 | auto default_it = op_exec_info->default_dtypes->find(attr); |
3206 | if (default_it != op_exec_info->default_dtypes->end()) { |
3207 | return default_it->second; |
3208 | } |
3209 | |
3210 | return tensorflow::DT_INVALID; |
3211 | } |
3212 | |
3213 | PyObject* CopySequenceSettingIndicesToNull( |
3214 | PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) { |
3215 | tensorflow::Safe_PyObjectPtr fast_seq( |
3216 | PySequence_Fast(seq, "unable to allocate" )); |
3217 | int len = PySequence_Fast_GET_SIZE(fast_seq.get()); |
3218 | PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get()); |
3219 | PyObject* result = PyTuple_New(len); |
3220 | for (int i = 0; i < len; i++) { |
3221 | PyObject* item; |
3222 | if (indices.find(i) != indices.end()) { |
3223 | item = Py_None; |
3224 | } else { |
3225 | item = fast_seq_array[i]; |
3226 | } |
3227 | Py_INCREF(item); |
3228 | PyTuple_SET_ITEM(result, i, item); |
3229 | } |
3230 | return result; |
3231 | } |
3232 | |
// Records the execution of `op_name` (with `inputs`, `attrs`, `results`) on
// every active tape and forward accumulator that wants it, so the backward /
// JVP function can be invoked later. Returns None (new reference) on success
// or when nothing needs to record; returns nullptr with a Python error set on
// failure.
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                         PyObject* results,
                         PyObject* forward_pass_name_scope = nullptr) {
  std::vector<int64_t> input_ids = MakeTensorIDList(inputs);
  if (PyErr_Occurred()) return nullptr;
  std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
  if (PyErr_Occurred()) return nullptr;

  // Fast exit: only build the (relatively expensive) recording closures below
  // if at least one tape or forward accumulator watches one of the inputs.
  bool should_record = false;
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
      should_record = true;
      break;
    }
  }
  if (!should_record) {
    for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
      if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
        should_record = true;
        break;
      }
    }
  }
  if (!should_record) Py_RETURN_NONE;

  string c_op_name = TFE_GetPythonString(op_name);

  // `op_outputs`/`op_inputs` are what the backward function will receive.
  // When the registered gradient is known not to use some (or all) outputs /
  // inputs, the unused slots are replaced with None so those tensors can be
  // freed early. The *_tuple_created flags track whether we own a fresh tuple
  // (must DECREF below) or merely borrow `results`/`inputs`.
  PyObject* op_outputs;
  bool op_outputs_tuple_created = false;

  if (const auto unused_output_indices =
          OpGradientUnusedOutputIndices(c_op_name)) {
    // An empty set is the sentinel for "no outputs are used at all".
    if (unused_output_indices->empty()) {
      op_outputs = Py_None;
    } else {
      op_outputs_tuple_created = true;
      op_outputs =
          CopySequenceSettingIndicesToNull(results, *unused_output_indices);
    }
  } else {
    op_outputs = results;
  }

  PyObject* op_inputs;
  bool op_inputs_tuple_created = false;

  if (const auto unused_input_indices =
          OpGradientUnusedInputIndices(c_op_name)) {
    if (unused_input_indices->empty()) {
      op_inputs = Py_None;
    } else {
      op_inputs_tuple_created = true;
      op_inputs =
          CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
    }
  } else {
    op_inputs = inputs;
  }

  // JVP ("forward-mode") function used by forward accumulators.
  tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
      [op_name, attrs, inputs, results](
          const std::vector<PyObject*>& input_tangents,
          std::vector<PyObject*>* output_tangents, bool use_batch) {
        return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
                               output_tangents, use_batch);
      });
  tensorflow::eager::ForwardFunction<PyObject>* forward_function;
  if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
      c_op_name == "If" || c_op_name == "StatelessIf" ) {
    // Control flow contains non-hashable attributes. Handling them in Python is
    // a headache, so instead we'll stay as close to GradientTape's handling as
    // possible (a null forward function means the accumulator forwards to a
    // tape).
    //
    // This is safe to do since we'll only see control flow when graph building,
    // in which case we can rely on pruning.
    forward_function = nullptr;
  } else {
    forward_function = &py_forward_function;
  }

  PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));

  if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;

  // The second lambda is a backward-function *factory*: each call mints a new
  // PyBackwardFunction and takes one extra reference on every captured Python
  // object. The third lambda is the matching deleter, releasing exactly those
  // references — the INCREF/DECREF pairs must stay balanced.
  TapeSetRecordOperation(
      op_name, inputs, results, input_ids, input_dtypes,
      [op_name, attrs, num_inputs, op_inputs, op_outputs,
       forward_pass_name_scope]() {
        Py_INCREF(op_name);
        Py_INCREF(attrs);
        Py_INCREF(num_inputs);
        Py_INCREF(op_inputs);
        Py_INCREF(op_outputs);
        Py_INCREF(forward_pass_name_scope);
        PyBackwardFunction* function = new PyBackwardFunction(
            [op_name, attrs, num_inputs, op_inputs, op_outputs,
             forward_pass_name_scope](
                PyObject* output_grads,
                const std::vector<int64_t>& unneeded_gradients) {
              // Don't invoke the Python gradient function on top of a
              // pending exception.
              if (PyErr_Occurred()) {
                return static_cast<PyObject*>(nullptr);
              }
              // Tuple of input indices whose gradients the caller doesn't
              // need; None when all gradients are needed.
              tensorflow::Safe_PyObjectPtr skip_input_indices;
              if (!unneeded_gradients.empty()) {
                skip_input_indices.reset(
                    PyTuple_New(unneeded_gradients.size()));
                for (int i = 0; i < unneeded_gradients.size(); i++) {
                  PyTuple_SET_ITEM(
                      skip_input_indices.get(), i,
                      GetPythonObjectFromInt(unneeded_gradients[i]));
                }
              } else {
                Py_INCREF(Py_None);
                skip_input_indices.reset(Py_None);
              }
              tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
                  "OOOOOOOO" , op_name, attrs, num_inputs, op_inputs, op_outputs,
                  output_grads, skip_input_indices.get(),
                  forward_pass_name_scope));

              // `gradient_function` is the registered Python-side dispatcher.
              tensorflow::Safe_PyObjectPtr result(
                  PyObject_CallObject(gradient_function, callback_args.get()));

              if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);

              return tensorflow::swig::Flatten(result.get());
            });
        return function;
      },
      [op_name, attrs, num_inputs, op_inputs, op_outputs,
       forward_pass_name_scope](PyBackwardFunction* backward_function) {
        // Release the references taken by the factory lambda above.
        Py_DECREF(op_name);
        Py_DECREF(attrs);
        Py_DECREF(num_inputs);
        Py_DECREF(op_inputs);
        Py_DECREF(op_outputs);
        Py_DECREF(forward_pass_name_scope);

        delete backward_function;
      },
      forward_function);

  Py_DECREF(num_inputs);
  // Only drop references to tuples we created; `results`/`inputs` themselves
  // are borrowed.
  if (op_outputs_tuple_created) Py_DECREF(op_outputs);
  if (op_inputs_tuple_created) Py_DECREF(op_inputs);

  if (PyErr_Occurred()) {
    return nullptr;
  }

  Py_RETURN_NONE;
}
3386 | |
3387 | void MaybeNotifyVariableAccessed(PyObject* input) { |
3388 | DCHECK(CheckResourceVariable(input)); |
3389 | DCHECK(PyObject_HasAttrString(input, "_trainable" )); |
3390 | |
3391 | tensorflow::Safe_PyObjectPtr trainable( |
3392 | PyObject_GetAttrString(input, "_trainable" )); |
3393 | if (trainable.get() == Py_False) return; |
3394 | TFE_Py_TapeVariableAccessed(input); |
3395 | TFE_Py_VariableWatcherVariableAccessed(input); |
3396 | } |
3397 | |
// Executes a ReadVariableOp to obtain the current value of the resource
// variable `input`, storing the resulting EagerTensor in `output`. Records
// the read on any active tape/accumulator when gradient callbacks are
// enabled. Returns false (with a Python exception set where applicable) on
// any failure.
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                    PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                    TF_Status* status) {
  MaybeNotifyVariableAccessed(input);

  TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp" , status);
  // Ensure the op is deleted on every exit path below.
  auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
    return false;

  TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
    return false;

  // Set dtype
  DCHECK(PyObject_HasAttrString(input, "_dtype" ));
  tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype" ));
  int value;
  if (!ParseTypeValue("_dtype" , dtype.get(), status, &value)) {
    return false;
  }
  TFE_OpSetAttrType(op, "dtype" , static_cast<TF_DataType>(value));

  // Get handle
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle" ));
  if (!EagerTensor_CheckExact(handle.get())) return false;
  TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
    return false;

  int num_retvals = 1;
  TFE_TensorHandle* output_handle;
  TFE_Execute(op, &output_handle, &num_retvals, status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
    return false;

  // Always create the py object (and correctly DECREF it) from the returned
  // value, else the data will leak.
  output->reset(EagerTensorFromHandle(output_handle));

  // TODO(nareshmodi): Should we run post exec callbacks here?
  if (parent_op_exec_info.run_gradient_callback) {
    tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
    // release(): the tuple takes over our reference to the variable handle.
    PyTuple_SET_ITEM(inputs.get(), 0, handle.release());

    tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // stay alive after since tuple steals.
    PyTuple_SET_ITEM(outputs.get(), 0, output->get());

    tensorflow::Safe_PyObjectPtr op_string(
        GetPythonObjectFromString("ReadVariableOp" ));
    if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
                        outputs.get())) {
      return false;
    }
  }

  return true;
}
3457 | |
// Converts `input` into an EagerTensor stored in `output_handle`.
//
// Supports 3 cases at the moment:
//  i) input is an EagerTensor.
//  ii) input is a ResourceVariable - in this case, the is_variable param is
//  set to true.
//  iii) input is an arbitrary python list/tuple (note, this handling doesn't
//  support packing).
//
//  NOTE: dtype_hint_getter must *always* return a PyObject that can be
//  decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
//  increfs Py_None).
//
//  NOTE: This function sets a python error directly, and returns false.
//  TF_Status is only passed since we don't want to have to reallocate it.
bool ConvertToTensor(
    const FastPathOpExecInfo& op_exec_info, PyObject* input,
    tensorflow::Safe_PyObjectPtr* output_handle,
    // This gets a hint for this particular input.
    const std::function<tensorflow::DataType()>& dtype_hint_getter,
    // This sets the dtype after conversion is complete.
    const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
    TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    // Already an EagerTensor: just take an extra reference.
    Py_INCREF(input);
    output_handle->reset(input);
    return true;
  } else if (CheckResourceVariable(input)) {
    return ReadVariableOp(op_exec_info, input, output_handle, status);
  }

  // The hint comes from a supposedly similarly typed tensor.
  tensorflow::DataType dtype_hint = dtype_hint_getter();

  TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
      op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
  if (handle == nullptr) {
    // NOTE(review): this forwards MaybeRaiseExceptionFromTFStatus's return
    // value; if `status` held an error, that call raises — verify the bool it
    // returns matches the "false means failure" contract of this function.
    return tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
  }

  output_handle->reset(EagerTensorFromHandle(handle));
  // Report the dtype actually produced so callers can cache it for siblings
  // that share the same type attr.
  dtype_setter(
      static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));

  return true;
}
3502 | |
// Adds input and type attr to the op, and to the list of flattened
// inputs/attrs.
//
// `input` is converted to an EagerTensor (reading resource variables if
// necessary). When `add_type_attr` is set and the arg declares a type attr,
// that attr is also set on `op` (and mirrored into `flattened_attrs` for the
// gradient/callback machinery). Returns false on failure with a Python error
// set where applicable.
bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
                  const bool add_type_attr,
                  const tensorflow::OpDef::ArgDef& input_arg,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
                  TFE_Op* op, TF_Status* status) {
  // py_eager_tensor's ownership is transferred to flattened_inputs if it is
  // required, else the object is destroyed and DECREF'd when the object goes
  // out of scope in this function.
  tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;

  if (!ConvertToTensor(
          *op_exec_info, input, &py_eager_tensor,
          [&]() {
            // Prefer the OpDef's fixed type; otherwise fall back to a dtype
            // cached/derived from other inputs sharing the same type attr.
            if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
              return input_arg.type();
            }
            return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
          },
          [&](const tensorflow::DataType dtype) {
            // Cache the resolved dtype for later inputs with this type attr.
            op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
          },
          status)) {
    return false;
  }

  TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());

  if (add_type_attr && !input_arg.type_attr().empty()) {
    auto dtype = TFE_TensorHandleDataType(input_handle);
    TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
    if (flattened_attrs != nullptr) {
      // Record the (attr name, dtype) pair for gradient recording.
      flattened_attrs->emplace_back(
          GetPythonObjectFromString(input_arg.type_attr()));
      flattened_attrs->emplace_back(PyLong_FromLong(dtype));
    }
  }

  if (flattened_inputs != nullptr) {
    flattened_inputs->emplace_back(std::move(py_eager_tensor));
  }

  TFE_OpAddInput(op, input_handle, status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return false;
  }

  return true;
}
3554 | |
3555 | const char* GetDeviceName(PyObject* py_device_name) { |
3556 | if (py_device_name != Py_None) { |
3557 | return TFE_GetPythonString(py_device_name); |
3558 | } |
3559 | return nullptr; |
3560 | } |
3561 | |
3562 | bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) { |
3563 | if (!PySequence_Check(seq)) { |
3564 | PyErr_SetString(PyExc_TypeError, |
3565 | Printf("expected a sequence for attr %s, got %s instead" , |
3566 | attr_name.data(), seq->ob_type->tp_name) |
3567 | .data()); |
3568 | |
3569 | return false; |
3570 | } |
3571 | if (PyArray_Check(seq) && |
3572 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) { |
3573 | PyErr_SetString(PyExc_ValueError, |
3574 | Printf("expected a sequence for attr %s, got an ndarray " |
3575 | "with rank %d instead" , |
3576 | attr_name.data(), |
3577 | PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq))) |
3578 | .data()); |
3579 | return false; |
3580 | } |
3581 | return true; |
3582 | } |
3583 | |
// Runs the gradient-recording callback and any registered post-execution
// callbacks for an op that just executed on the fast path.
//
// Builds Python tuples for the op's inputs and attrs: attrs are the
// non-inferred attrs taken straight from `args` (everything after the first
// `num_inferred_attrs` elements) followed by the inferred attrs collected in
// `flattened_attrs`. Returns false if any callback fails (Python error set).
bool RunCallbacks(
    const FastPathOpExecInfo& op_exec_info, PyObject* args,
    int num_inferred_attrs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
    PyObject* flattened_result) {
  DCHECK(op_exec_info.run_callbacks);

  tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
  for (int i = 0; i < flattened_inputs.size(); i++) {
    PyObject* input = flattened_inputs[i].get();
    Py_INCREF(input);  // PyTuple_SET_ITEM steals a reference.
    PyTuple_SET_ITEM(inputs.get(), i, input);
  }

  int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
  int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
  tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));

  // First the attrs passed explicitly by the caller...
  for (int i = 0; i < num_non_inferred_attrs; i++) {
    auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
    Py_INCREF(attr);
    PyTuple_SET_ITEM(attrs.get(), i, attr);
  }

  // ...then the attrs inferred from the inputs during fast-path setup.
  for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
    PyObject* attr_or_name =
        flattened_attrs.at(i - num_non_inferred_attrs).get();
    Py_INCREF(attr_or_name);
    PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
  }

  if (op_exec_info.run_gradient_callback) {
    if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
                        flattened_result)) {
      return false;
    }
  }

  if (op_exec_info.run_post_exec_callbacks) {
    tensorflow::Safe_PyObjectPtr callback_args(
        Py_BuildValue("OOOOO" , op_exec_info.op_name, inputs.get(), attrs.get(),
                      flattened_result, op_exec_info.name));
    for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
      PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
      if (!PyCallable_Check(callback_fn)) {
        PyErr_SetString(
            PyExc_TypeError,
            Printf("expected a function for "
                   "post execution callback in index %ld, got %s instead" ,
                   i, callback_fn->ob_type->tp_name)
                .c_str());
        return false;
      }
      PyObject* callback_result =
          PyObject_CallObject(callback_fn, callback_args.get());
      if (!callback_result) {
        return false;
      }
      Py_DECREF(callback_result);
    }
  }

  return true;
}
3649 | |
3650 | } // namespace |
3651 | |
// Fast-path eager op execution, bypassing most Python-level overhead.
//
// `args` is a tuple laid out as: [context, op_name, name, input_1, ...,
// input_N, attr_name_1, attr_value_1, ...]. On unsupported cases this raises
// a fallback exception so the caller retries via the slow (pure Python)
// path; on error it sets a Python exception and returns nullptr; otherwise
// it returns the op's results, unflattened to match the OpDef's output args.
PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_FastPathExecute_C" , tensorflow::profiler::TraceMeLevel::kInfo);
  Py_ssize_t args_size = PyTuple_GET_SIZE(args);
  if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("There must be at least %d items in the input tuple." ,
               FAST_PATH_EXECUTE_ARG_INPUT_START)
            .c_str());
    return nullptr;
  }

  FastPathOpExecInfo op_exec_info;

  PyObject* py_eager_context =
      PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);

  // TODO(edoper): Use interned string here
  // NOTE(review): PyObject_GetAttrString returns a new reference that is
  // never released here — verify whether this per-call reference leak on the
  // capsule is intentional (the context outlives it anyway).
  PyObject* eager_context_handle =
      PyObject_GetAttrString(py_eager_context, "_context_handle" );

  TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(eager_context_handle, nullptr));
  op_exec_info.ctx = ctx;
  op_exec_info.args = args;

  if (ctx == nullptr) {
    // The context hasn't been initialized. It will be in the slow path.
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors." );
    return nullptr;
  }

  auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
  if (tld == nullptr) {
    return nullptr;
  }
  op_exec_info.device_name = GetDeviceName(tld->device_name.get());
  op_exec_info.callbacks = tld->op_callbacks.get();

  op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
  op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);

  // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
  // (similar to benchmark_tf_gradient_function_*). Also consider using an
  // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
  // point out problems with heap allocs.
  op_exec_info.run_gradient_callback =
      !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
  op_exec_info.run_post_exec_callbacks =
      op_exec_info.callbacks != Py_None &&
      PyList_Size(op_exec_info.callbacks) > 0;
  op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
                               op_exec_info.run_post_exec_callbacks;

  TF_Status* status = GetStatus();
  const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
  if (op_name == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a string for op_name, got %s instead" ,
                           op_exec_info.op_name->ob_type->tp_name)
                        .c_str());
    return nullptr;
  }

  TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);

  // Return the pooled status/op objects on every exit path below.
  auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
    ReturnStatus(status);
    ReturnOp(ctx, op);
  });

  if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return nullptr;
  }

  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
  if (op_def == nullptr) return nullptr;

  // There must be one tuple slot per declared input arg before the attr
  // name/value pairs begin.
  if (args_size <
      FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Tuple size smaller than intended. Expected to be at least %d, "
               "was %ld" ,
               FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
               args_size)
            .c_str());
    return nullptr;
  }

  if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors." );
    return nullptr;
  }

  op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
  op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);

  // Mapping of attr name to size - used to calculate the number of values
  // to be expected by the TFE_Execute run.
  tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;

  // Set non-inferred attrs, including setting defaults if the attr is passed in
  // as None.
  for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
       i < args_size; i += 2) {
    PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
    const char* attr_name = TFE_GetPythonString(py_attr_name);
    PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);

    // Not creating an index since most of the time there are not more than a
    // few attrs.
    // TODO(nareshmodi): Maybe include the index as part of the
    // OpRegistrationData.
    for (const auto& attr : op_def->attr()) {
      if (tensorflow::StringPiece(attr_name) == attr.name()) {
        SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
                              &attr_list_sizes, status);

        if (!status->status.ok()) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
                  << "\" since we are unable to set the value for attr \""
                  << attr.name() << "\" due to: " << TF_Message(status);
          RaiseFallbackException(TF_Message(status));
          return nullptr;
        }

        break;
      }
    }
  }

  // Flat attrs and inputs as required by the record_gradient call. The attrs
  // here only contain inferred attrs (non-inferred attrs are added directly
  // from the input args).
  // All items in flattened_attrs and flattened_inputs contain
  // Safe_PyObjectPtr - any time something steals a reference to this, it must
  // INCREF.
  // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
  // directly.
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
      nullptr;
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
      nullptr;

  // TODO(nareshmodi): Encapsulate callbacks information into a struct.
  if (op_exec_info.run_callbacks) {
    flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
    flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
  }

  // Add inferred attrs and inputs.
  // The following code might set duplicate type attrs. This will result in
  // the CacheKey for the generated AttrBuilder possibly differing from
  // those where the type attrs are correctly set. Inconsistent CacheKeys
  // for ops means that there might be unnecessarily duplicated kernels.
  // TODO(nareshmodi): Fix this.
  for (int i = 0; i < op_def->input_arg_size(); i++) {
    const auto& input_arg = op_def->input_arg(i);

    PyObject* input =
        PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
    if (!input_arg.number_attr().empty()) {
      // The item is a homogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence." ));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());

      TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(
            GetPythonObjectFromString(input_arg.number_attr()));
        flattened_attrs->emplace_back(PyLong_FromLong(len));
      }
      attr_list_sizes[input_arg.number_attr()] = len;

      if (len > 0) {
        // First item adds the type attr.
        if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
                          flattened_attrs.get(), flattened_inputs.get(), op,
                          status)) {
          return nullptr;
        }

        for (Py_ssize_t j = 1; j < len; j++) {
          // Since the list is homogeneous, we don't need to re-add the attr.
          if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
                            input_arg, nullptr /* flattened_attrs */,
                            flattened_inputs.get(), op, status)) {
            return nullptr;
          }
        }
      }
    } else if (!input_arg.type_list_attr().empty()) {
      // The item is a heterogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
        return nullptr;
      }
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence." ));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      const string& attr_name = input_arg.type_list_attr();
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
      tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
      PyObject* py_attr_value = nullptr;
      if (op_exec_info.run_callbacks) {
        py_attr_value = PyTuple_New(len);
      }
      for (Py_ssize_t j = 0; j < len; j++) {
        PyObject* py_input = fast_input_array[j];
        tensorflow::Safe_PyObjectPtr py_eager_tensor;
        // No dtype hint for heterogeneous lists: each element's dtype is
        // taken from the converted tensor below.
        if (!ConvertToTensor(
                op_exec_info, py_input, &py_eager_tensor,
                []() { return tensorflow::DT_INVALID; },
                [](const tensorflow::DataType dtype) {}, status)) {
          return nullptr;
        }

        TFE_TensorHandle* input_handle =
            EagerTensor_Handle(py_eager_tensor.get());

        attr_value[j] = TFE_TensorHandleDataType(input_handle);

        TFE_OpAddInput(op, input_handle, status);
        if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
          return nullptr;
        }

        if (op_exec_info.run_callbacks) {
          flattened_inputs->emplace_back(std::move(py_eager_tensor));

          PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
        }
      }
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
        flattened_attrs->emplace_back(py_attr_value);
      }
      TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
                            attr_value.size());
      attr_list_sizes[attr_name] = len;
    } else {
      // The item is a single item.
      if (!AddInputToOp(&op_exec_info, input, true, input_arg,
                        flattened_attrs.get(), flattened_inputs.get(), op,
                        status)) {
        return nullptr;
      }
    }
  }

  // Compute how many return values TFE_Execute should produce: one per
  // scalar output arg, plus the recorded list size for list-valued outputs.
  int64_t num_outputs = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    const auto& output_arg = op_def->output_arg(i);
    int64_t delta = 1;
    if (!output_arg.number_attr().empty()) {
      delta = attr_list_sizes[output_arg.number_attr()];
    } else if (!output_arg.type_list_attr().empty()) {
      delta = attr_list_sizes[output_arg.type_list_attr()];
    }
    if (delta < 0) {
      RaiseFallbackException(
          "Attributes suggest that the size of an output list is less than 0" );
      return nullptr;
    }
    num_outputs += delta;
  }

  // If number of retvals is larger than int32, we error out.
  if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Number of outputs is too big: %ld" , num_outputs).c_str());
    return nullptr;
  }
  int num_retvals = num_outputs;

  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);

  // Release the GIL during the (potentially long-running) kernel execution.
  Py_BEGIN_ALLOW_THREADS;
  TFE_Execute(op, retvals.data(), &num_retvals, status);
  Py_END_ALLOW_THREADS;

  if (!status->status.ok()) {
    // Augment the status with the op_name for easier debugging similar to
    // TFE_Py_Execute.
    status->status = tensorflow::errors::CreateWithUpdatedMessage(
        status->status, tensorflow::strings::StrCat(
                            TF_Message(status), " [Op:" ,
                            TFE_GetPythonString(op_exec_info.op_name), "]" ));
    tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
    return nullptr;
  }

  tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
  for (int i = 0; i < num_retvals; ++i) {
    PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
  }

  if (op_exec_info.run_callbacks) {
    if (!RunCallbacks(
            op_exec_info, args,
            FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
            *flattened_inputs, *flattened_attrs, flat_result.get())) {
      return nullptr;
    }
  }

  // Unflatten results.
  if (op_def->output_arg_size() == 0) {
    Py_RETURN_NONE;
  }

  if (op_def->output_arg_size() == 1) {
    if (!op_def->output_arg(0).number_attr().empty() ||
        !op_def->output_arg(0).type_list_attr().empty()) {
      // Single list-valued output: return the whole flat list.
      return flat_result.release();
    } else {
      // Single tensor output: return the tensor itself, not a list.
      auto* result = PyList_GET_ITEM(flat_result.get(), 0);
      Py_INCREF(result);
      return result;
    }
  }

  // Correctly output the results that are made into a namedtuple.
  // Multiple outputs: group the flat results back into one entry per output
  // arg, with list-valued outputs becoming nested lists.
  PyObject* result = PyList_New(op_def->output_arg_size());
  int flat_result_index = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    if (!op_def->output_arg(i).number_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else if (!op_def->output_arg(i).type_list_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else {
      PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
      Py_INCREF(obj);
      PyList_SET_ITEM(result, i, obj);
    }
  }
  return result;
}
4022 | |
4023 | PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, |
4024 | PyObject* attrs, PyObject* results, |
4025 | PyObject* forward_pass_name_scope) { |
4026 | if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { |
4027 | Py_RETURN_NONE; |
4028 | } |
4029 | |
4030 | return RecordGradient(op_name, inputs, attrs, results, |
4031 | forward_pass_name_scope); |
4032 | } |
4033 | |
4034 | // A method prints incoming messages directly to Python's |
4035 | // stdout using Python's C API. This is necessary in Jupyter notebooks |
4036 | // and colabs where messages to the C stdout don't go to the notebook |
4037 | // cell outputs, but calls to Python's stdout do. |
4038 | void PrintToPythonStdout(const char* msg) { |
4039 | if (Py_IsInitialized()) { |
4040 | PyGILState_STATE py_threadstate; |
4041 | py_threadstate = PyGILState_Ensure(); |
4042 | |
4043 | string string_msg = msg; |
4044 | // PySys_WriteStdout truncates strings over 1000 bytes, so |
4045 | // we write the message in chunks small enough to not be truncated. |
4046 | int CHUNK_SIZE = 900; |
4047 | auto len = string_msg.length(); |
4048 | for (int i = 0; i < len; i += CHUNK_SIZE) { |
4049 | PySys_WriteStdout("%s" , string_msg.substr(i, CHUNK_SIZE).c_str()); |
4050 | } |
4051 | |
4052 | // Force flushing to make sure print newlines aren't interleaved in |
4053 | // some colab environments |
4054 | PyRun_SimpleString("import sys; sys.stdout.flush()" ); |
4055 | |
4056 | PyGILState_Release(py_threadstate); |
4057 | } |
4058 | } |
4059 | |
4060 | // Register PrintToPythonStdout as a log listener, to allow |
4061 | // printing in colabs and jupyter notebooks to work. |
4062 | void TFE_Py_EnableInteractivePythonLogging() { |
4063 | static bool enabled_interactive_logging = false; |
4064 | if (!enabled_interactive_logging) { |
4065 | enabled_interactive_logging = true; |
4066 | TF_RegisterLogListener(PrintToPythonStdout); |
4067 | } |
4068 | } |
4069 | |
namespace {
// TODO(mdan): Clean this. Maybe by decoupling context lifetime from Python GC?
// Weak reference to the Python Context (see tensorflow/python/eager/context.py)
// object currently active. This object is opaque and wrapped inside a Python
// Capsule. However, the EagerContext object it holds is tracked by the
// global_c_eager_context object.
// Also see common_runtime/eager/context.cc.
//
// This holds an owned reference to the weakref object itself (not to the
// Context). It is replaced by TFE_Py_SetEagerContext and resolved by
// GetPyEagerContext below.
PyObject* global_py_eager_context = nullptr;
}  // namespace
4079 | |
4080 | PyObject* TFE_Py_SetEagerContext(PyObject* py_context) { |
4081 | Py_XDECREF(global_py_eager_context); |
4082 | global_py_eager_context = PyWeakref_NewRef(py_context, nullptr); |
4083 | if (global_py_eager_context == nullptr) { |
4084 | return nullptr; |
4085 | } |
4086 | Py_RETURN_NONE; |
4087 | } |
4088 | |
4089 | PyObject* GetPyEagerContext() { |
4090 | if (global_py_eager_context == nullptr) { |
4091 | PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set" ); |
4092 | return nullptr; |
4093 | } |
4094 | PyObject* py_context = PyWeakref_GET_OBJECT(global_py_eager_context); |
4095 | if (py_context == Py_None) { |
4096 | PyErr_SetString(PyExc_RuntimeError, |
4097 | "Python eager context has been destroyed" ); |
4098 | return nullptr; |
4099 | } |
4100 | Py_INCREF(py_context); |
4101 | return py_context; |
4102 | } |
4103 | |
namespace {

// Default values for thread_local_data fields.
// `is_eager` is a callable (it is invoked in GetEagerContextThreadLocalData
// to compute each new thread's initial eager flag); `device_spec` is used
// directly as the per-thread default device spec.
struct EagerContextThreadLocalDataDefaults {
  tensorflow::Safe_PyObjectPtr is_eager;
  tensorflow::Safe_PyObjectPtr device_spec;
};

// Maps each py_eager_context object to its thread_local_data.
//
// Note: we need to use the python Context object as the key here (and not
// its handle object), because the handle object isn't created until the
// context is initialized; but thread_local_data is potentially accessed
// before then.
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
// One map per thread (thread_local); lazily allocated and never freed.
thread_local EagerContextThreadLocalDataMap*
    eager_context_thread_local_data_map = nullptr;

// Maps each py_eager_context object to default values.
// Process-wide (shared across threads); lazily allocated and never freed.
using EagerContextThreadLocalDataDefaultsMap =
    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
EagerContextThreadLocalDataDefaultsMap*
    eager_context_thread_local_data_defaults = nullptr;

}  // namespace
4130 | |
4131 | namespace tensorflow { |
4132 | |
4133 | void MakeEagerContextThreadLocalData(PyObject* py_eager_context, |
4134 | PyObject* is_eager, |
4135 | PyObject* device_spec) { |
4136 | DCheckPyGilState(); |
4137 | if (eager_context_thread_local_data_defaults == nullptr) { |
4138 | absl::LeakCheckDisabler disabler; |
4139 | eager_context_thread_local_data_defaults = |
4140 | new EagerContextThreadLocalDataDefaultsMap(); |
4141 | } |
4142 | if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) { |
4143 | PyErr_SetString(PyExc_AssertionError, |
4144 | "MakeEagerContextThreadLocalData may not be called " |
4145 | "twice on the same eager Context object." ); |
4146 | } |
4147 | |
4148 | auto& defaults = |
4149 | (*eager_context_thread_local_data_defaults)[py_eager_context]; |
4150 | Py_INCREF(is_eager); |
4151 | defaults.is_eager.reset(is_eager); |
4152 | Py_INCREF(device_spec); |
4153 | defaults.device_spec.reset(device_spec); |
4154 | } |
4155 | |
// Returns the calling thread's thread-local state for `py_eager_context`,
// creating and initializing it from the registered defaults on first use.
// Returns nullptr with a Python AssertionError set if
// MakeEagerContextThreadLocalData was never called for this context, or
// nullptr with the callable's error set if evaluating `is_eager` fails.
//
// NOTE(review): if the `is_eager` call below fails, a partially-initialized
// entry is left in the map and will be returned as-is on the next call —
// presumably callers treat a nullptr return as fatal for this context;
// verify before relying on re-entry after failure.
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context) {
  if (eager_context_thread_local_data_defaults == nullptr) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData." );
    return nullptr;
  }
  auto defaults =
      eager_context_thread_local_data_defaults->find(py_eager_context);
  if (defaults == eager_context_thread_local_data_defaults->end()) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData." );
    return nullptr;
  }

  // Lazily allocate this thread's map; intentionally leaked (never freed).
  if (eager_context_thread_local_data_map == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
  }
  auto& thread_local_data =
      (*eager_context_thread_local_data_map)[py_eager_context];

  if (!thread_local_data) {
    thread_local_data.reset(new EagerContextThreadLocalData());

    // The registered is_eager default is a callable; invoke it to compute
    // this thread's initial eager-mode flag.
    Safe_PyObjectPtr is_eager(
        PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
    if (!is_eager) return nullptr;
    thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());

    // scope_name starts as an empty string.
#if PY_MAJOR_VERSION >= 3
    PyObject* scope_name = PyUnicode_FromString("" );
#else
    PyObject* scope_name = PyString_FromString("" );
#endif
    thread_local_data->scope_name.reset(scope_name);

    // device_name starts as an empty string.
#if PY_MAJOR_VERSION >= 3
    PyObject* device_name = PyUnicode_FromString("" );
#else
    PyObject* device_name = PyString_FromString("" );
#endif
    thread_local_data->device_name.reset(device_name);

    // Take a new strong reference to the shared default device_spec.
    Py_INCREF(defaults->second.device_spec.get());
    thread_local_data->device_spec.reset(defaults->second.device_spec.get());

    // function_call_options and executor default to None (i.e. unset).
    Py_INCREF(Py_None);
    thread_local_data->function_call_options.reset(Py_None);

    Py_INCREF(Py_None);
    thread_local_data->executor.reset(Py_None);

    // op_callbacks starts as a fresh empty list.
    thread_local_data->op_callbacks.reset(PyList_New(0));
  }
  return thread_local_data.get();
}
4215 | |
4216 | void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) { |
4217 | DCheckPyGilState(); |
4218 | if (eager_context_thread_local_data_defaults) { |
4219 | eager_context_thread_local_data_defaults->erase(py_eager_context); |
4220 | } |
4221 | if (eager_context_thread_local_data_map) { |
4222 | eager_context_thread_local_data_map->erase(py_eager_context); |
4223 | } |
4224 | } |
4225 | |
4226 | } // namespace tensorflow |
4227 | |