1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/eager/pywrap_tensor_conversion.h" |
17 | |
18 | #include "absl/container/flat_hash_map.h" |
19 | #include "absl/hash/hash.h" |
20 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
21 | #include "tensorflow/core/lib/monitoring/counter.h" |
22 | #include "tensorflow/core/platform/logging.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | auto* scalar_cache_hits = tensorflow::monitoring::Counter<0>::New( |
27 | "/tensorflow/eager/python/scalar_cache_hits" , |
28 | "Number of times a scalar TFE_TensorHandle was retrieved from cache" ); |
29 | auto* scalar_cache_misses = tensorflow::monitoring::Counter<0>::New( |
30 | "/tensorflow/eager/python/scalar_cache_misses" , |
31 | "Number of times a scalar TFE_TensorHandle was not available in cache" ); |
32 | |
33 | TFE_TensorHandleCache* TFE_TensorHandleCache::Get() { |
34 | // TODO(slebedev): link with Context (in context.py) instead of having |
35 | // a static global? |
36 | static auto* cache = new TFE_TensorHandleCache(); |
37 | return cache; |
38 | } |
39 | |
40 | TFE_TensorHandle* TFE_TensorHandleCache::Lookup( |
41 | PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, |
42 | absl::string_view device_name) const { |
43 | CHECK_NOTNULL(value); |
44 | const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name}); |
45 | if (it == cache.end()) { |
46 | scalar_cache_misses->GetCell()->IncrementBy(1); |
47 | return nullptr; |
48 | } |
49 | |
50 | scalar_cache_hits->GetCell()->IncrementBy(1); |
51 | auto* h = it->second; |
52 | return tensorflow::wrap(tensorflow::unwrap(h)->Copy()); |
53 | } |
54 | |
55 | void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, |
56 | TFE_Context* ctx, |
57 | absl::string_view device_name, |
58 | TFE_TensorHandle* h) { |
59 | Py_INCREF(value); |
60 | cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name}, |
61 | tensorflow::wrap(tensorflow::unwrap(h)->Copy())); |
62 | } |
63 | |
64 | void TFE_TensorHandleCache::Clear() { |
65 | DecrefUnrefAll(); |
66 | cache.clear(); |
67 | } |
68 | |
69 | } // namespace tensorflow |
70 | |