1#include <c10/core/TensorImpl.h>
2#include <ATen/ThreadLocalPythonObjects.h>
3#include <c10/util/Exception.h>
4
5#include <utility>
6
7namespace at {
8namespace impl {
9
10static thread_local ThreadLocalPythonObjects py_objects;
11
12
13void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) {
14 py_objects.obj_dict_[key] = std::move(value);
15}
16
17const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) {
18 TORCH_CHECK(py_objects.obj_dict_.count(key));
19 return py_objects.obj_dict_[key];
20}
21
22bool ThreadLocalPythonObjects::contains(const std::string& key) {
23 return py_objects.obj_dict_.count(key);
24}
25
26void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) {
27 py_objects = std::move(state);
28}
29
30const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() {
31 return py_objects;
32}
33
34
35}
36}
37