1 | #include <c10/core/TensorImpl.h> |
---|---|
2 | #include <ATen/ThreadLocalPythonObjects.h> |
3 | #include <c10/util/Exception.h> |
4 | |
5 | #include <utility> |
6 | |
7 | namespace at { |
8 | namespace impl { |
9 | |
10 | static thread_local ThreadLocalPythonObjects py_objects; |
11 | |
12 | |
13 | void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) { |
14 | py_objects.obj_dict_[key] = std::move(value); |
15 | } |
16 | |
17 | const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) { |
18 | TORCH_CHECK(py_objects.obj_dict_.count(key)); |
19 | return py_objects.obj_dict_[key]; |
20 | } |
21 | |
22 | bool ThreadLocalPythonObjects::contains(const std::string& key) { |
23 | return py_objects.obj_dict_.count(key); |
24 | } |
25 | |
26 | void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) { |
27 | py_objects = std::move(state); |
28 | } |
29 | |
30 | const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() { |
31 | return py_objects; |
32 | } |
33 | |
34 | |
35 | } |
36 | } |
37 |