1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_ |
17 | #define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_ |
18 | |
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
#include <locale>

// The empty line above is on purpose as otherwise clang-format will
// automatically move <Python.h> before <locale>.
#include <Python.h>

#include <tuple>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
32 | |
33 | namespace tensorflow { |
34 | |
35 | // Wrapper-class allowing to use Python hashing/comparison functions |
36 | // for PyObject*. |
37 | // |
38 | // Note that unlike Safe_PyObjectPtr this class does not steal a |
39 | // reference to a Python object. The caller is responsible for doing |
40 | // Py_INCREF/Py_DECREF. |
41 | struct PyObjectPtr { |
42 | template <typename H> |
43 | friend H AbslHashValue(H h, const PyObjectPtr& obj) { |
44 | return H::combine(std::move(h), PyObject_Hash(obj.ptr)); |
45 | } |
46 | |
47 | explicit PyObjectPtr(PyObject* ptr) : ptr(ptr) {} |
48 | |
49 | explicit inline operator PyObject*() const { return ptr; } |
50 | |
51 | inline bool operator==(const PyObjectPtr& other) const { |
52 | // We require exact type equality to account for 0 == 0.0 == False. |
53 | if (Py_TYPE(ptr) != Py_TYPE(other.ptr)) { |
54 | return false; |
55 | } |
56 | |
57 | bool result = PyObject_RichCompareBool(ptr, other.ptr, Py_EQ) > 0; |
58 | CHECK(!PyErr_Occurred()); |
59 | return result; |
60 | } |
61 | |
62 | private: |
63 | PyObject* ptr; |
64 | }; |
65 | |
66 | // Cache mapping PyObject* to the corresponding on-device TFE_TensorHandles. |
67 | // Used to speed up ConvertToEagerTensor for scalars. |
68 | // TODO(slebedev): move ConvertToEagerTensor here. |
69 | struct TFE_TensorHandleCache { |
70 | static TFE_TensorHandleCache* Get(); |
71 | |
72 | TFE_TensorHandleCache() { cache.reserve(64); } |
73 | ~TFE_TensorHandleCache() { DecrefUnrefAll(); } |
74 | |
75 | TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype, |
76 | TFE_Context* ctx, |
77 | absl::string_view device_name) const; |
78 | |
79 | void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, |
80 | absl::string_view device_name, TFE_TensorHandle* h); |
81 | |
82 | void Clear(); |
83 | |
84 | private: |
85 | // TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object |
86 | // should have TFE_TensorHandleCache instance. Migrate once we Python context |
87 | // object is backed by C++ data structure. b/169790439 |
88 | using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*, |
89 | absl::string_view>; |
90 | |
91 | void DecrefUnrefAll() { |
92 | for (const auto& p : cache) { |
93 | Py_DECREF(static_cast<PyObject*>(std::get<0>(p.first))); |
94 | TFE_DeleteTensorHandle(p.second); |
95 | } |
96 | } |
97 | |
98 | // Not guarded by a mutex because the code is only used while the |
99 | // GIL is held. |
100 | absl::flat_hash_map<Key, TFE_TensorHandle*> cache; |
101 | }; |
102 | |
103 | } // namespace tensorflow |
104 | |
105 | #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_ |
106 | |