1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
17#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
18
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
#include <locale>

// The empty line above is on purpose as otherwise clang-format will
// automatically move <Python.h> before <locale>.
#include <Python.h>

#include <tuple>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
32
33namespace tensorflow {
34
35// Wrapper-class allowing to use Python hashing/comparison functions
36// for PyObject*.
37//
38// Note that unlike Safe_PyObjectPtr this class does not steal a
39// reference to a Python object. The caller is responsible for doing
40// Py_INCREF/Py_DECREF.
41struct PyObjectPtr {
42 template <typename H>
43 friend H AbslHashValue(H h, const PyObjectPtr& obj) {
44 return H::combine(std::move(h), PyObject_Hash(obj.ptr));
45 }
46
47 explicit PyObjectPtr(PyObject* ptr) : ptr(ptr) {}
48
49 explicit inline operator PyObject*() const { return ptr; }
50
51 inline bool operator==(const PyObjectPtr& other) const {
52 // We require exact type equality to account for 0 == 0.0 == False.
53 if (Py_TYPE(ptr) != Py_TYPE(other.ptr)) {
54 return false;
55 }
56
57 bool result = PyObject_RichCompareBool(ptr, other.ptr, Py_EQ) > 0;
58 CHECK(!PyErr_Occurred());
59 return result;
60 }
61
62 private:
63 PyObject* ptr;
64};
65
66// Cache mapping PyObject* to the corresponding on-device TFE_TensorHandles.
67// Used to speed up ConvertToEagerTensor for scalars.
68// TODO(slebedev): move ConvertToEagerTensor here.
69struct TFE_TensorHandleCache {
70 static TFE_TensorHandleCache* Get();
71
72 TFE_TensorHandleCache() { cache.reserve(64); }
73 ~TFE_TensorHandleCache() { DecrefUnrefAll(); }
74
75 TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
76 TFE_Context* ctx,
77 absl::string_view device_name) const;
78
79 void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
80 absl::string_view device_name, TFE_TensorHandle* h);
81
82 void Clear();
83
84 private:
85 // TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object
86 // should have TFE_TensorHandleCache instance. Migrate once we Python context
87 // object is backed by C++ data structure. b/169790439
88 using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*,
89 absl::string_view>;
90
91 void DecrefUnrefAll() {
92 for (const auto& p : cache) {
93 Py_DECREF(static_cast<PyObject*>(std::get<0>(p.first)));
94 TFE_DeleteTensorHandle(p.second);
95 }
96 }
97
98 // Not guarded by a mutex because the code is only used while the
99 // GIL is held.
100 absl::flat_hash_map<Key, TFE_TensorHandle*> cache;
101};
102
103} // namespace tensorflow
104
105#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
106