1 | #include <c10/util/flat_hash_map.h> |
2 | #include <torch/csrc/Exceptions.h> |
3 | #include <torch/csrc/python_dimname.h> |
4 | #include <torch/csrc/utils/python_strings.h> |
5 | |
6 | namespace torch { |
7 | |
8 | struct InternedStringsTable { |
9 | InternedStringsTable() = default; |
10 | ~InternedStringsTable(); |
11 | InternedStringsTable(const InternedStringsTable&) = delete; |
12 | InternedStringsTable& operator=(InternedStringsTable const&) = delete; |
13 | InternedStringsTable(InternedStringsTable&&) = delete; |
14 | InternedStringsTable& operator=(InternedStringsTable&&) = delete; |
15 | |
16 | at::optional<at::Dimname> lookup(PyObject* obj); |
17 | // Precondition: obj is an interned python string. |
18 | void addMapping(PyObject* obj, at::Dimname dimname); |
19 | |
20 | private: |
21 | ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_; |
22 | }; |
23 | |
24 | InternedStringsTable kPyInternedStringToDimname; |
25 | |
26 | InternedStringsTable::~InternedStringsTable() { |
27 | for (auto it = py_interned_string_to_dimname_.begin(); |
28 | it != py_interned_string_to_dimname_.end(); |
29 | ++it) { |
30 | // See Note [References to python interned strings] |
31 | Py_DECREF(it->first); |
32 | } |
33 | } |
34 | |
35 | at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) { |
36 | auto it = py_interned_string_to_dimname_.find(obj); |
37 | if (it == py_interned_string_to_dimname_.end()) { |
38 | return at::nullopt; |
39 | } |
40 | return it->second; |
41 | } |
42 | |
43 | void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) { |
44 | // Note [References to python interned strings] |
45 | // If a Python interned string has no references to it, then it gets |
46 | // deallocated, invalidating this mapping. Let's immortalize the string by |
47 | // holding a refcount to it and releasing it in the destructor |
48 | Py_INCREF(obj); |
49 | py_interned_string_to_dimname_.emplace(obj, dimname); |
50 | } |
51 | |
52 | } // namespace torch |
53 | |
54 | bool THPUtils_checkDimname(PyObject* obj) { |
55 | return obj == Py_None || THPUtils_checkString(obj); |
56 | } |
57 | |
58 | // To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if |
59 | // it is a list or tuple and its first elt is a Dimname |
60 | bool THPUtils_checkDimnameList(PyObject* obj) { |
61 | auto tuple = PyTuple_Check(obj); |
62 | if (!tuple && !PyList_Check(obj)) { |
63 | return false; |
64 | } |
65 | // NOLINTNEXTLINE(bugprone-branch-clone) |
66 | const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); |
67 | if (size == 0) { |
68 | return true; |
69 | } |
70 | PyObject* first_elt = |
71 | tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); |
72 | return THPUtils_checkDimname(first_elt); |
73 | } |
74 | |
75 | at::Dimname THPDimname_parse(PyObject* obj) { |
76 | if (obj == Py_None) { |
77 | return at::Dimname::wildcard(); |
78 | } |
79 | |
80 | if (!THPUtils_checkString(obj)) { |
81 | throw torch::TypeError( |
82 | "expected None or string for Dimname but got %s" , |
83 | Py_TYPE(obj)->tp_name); |
84 | } |
85 | |
86 | if (!THPUtils_isInterned(obj)) { |
87 | // internStringInPlace decrefs obj and increfs the result. Because we're |
88 | // not actually returning the result to the user, we need to undo these. |
89 | // See |
90 | // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace |
91 | Py_INCREF(obj); |
92 | THPUtils_internStringInPlace(&obj); |
93 | Py_DECREF(obj); |
94 | } |
95 | |
96 | auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj); |
97 | if (maybeDimname) { |
98 | return *maybeDimname; |
99 | } |
100 | |
101 | const auto name = THPUtils_unpackString(obj); |
102 | auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name)); |
103 | torch::kPyInternedStringToDimname.addMapping(obj, dimname); |
104 | return dimname; |
105 | } |
106 | |