1#include <c10/util/flat_hash_map.h>
2#include <torch/csrc/Exceptions.h>
3#include <torch/csrc/python_dimname.h>
4#include <torch/csrc/utils/python_strings.h>
5
6namespace torch {
7
8struct InternedStringsTable {
9 InternedStringsTable() = default;
10 ~InternedStringsTable();
11 InternedStringsTable(const InternedStringsTable&) = delete;
12 InternedStringsTable& operator=(InternedStringsTable const&) = delete;
13 InternedStringsTable(InternedStringsTable&&) = delete;
14 InternedStringsTable& operator=(InternedStringsTable&&) = delete;
15
16 at::optional<at::Dimname> lookup(PyObject* obj);
17 // Precondition: obj is an interned python string.
18 void addMapping(PyObject* obj, at::Dimname dimname);
19
20 private:
21 ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_;
22};
23
24InternedStringsTable kPyInternedStringToDimname;
25
26InternedStringsTable::~InternedStringsTable() {
27 for (auto it = py_interned_string_to_dimname_.begin();
28 it != py_interned_string_to_dimname_.end();
29 ++it) {
30 // See Note [References to python interned strings]
31 Py_DECREF(it->first);
32 }
33}
34
35at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
36 auto it = py_interned_string_to_dimname_.find(obj);
37 if (it == py_interned_string_to_dimname_.end()) {
38 return at::nullopt;
39 }
40 return it->second;
41}
42
43void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
44 // Note [References to python interned strings]
45 // If a Python interned string has no references to it, then it gets
46 // deallocated, invalidating this mapping. Let's immortalize the string by
47 // holding a refcount to it and releasing it in the destructor
48 Py_INCREF(obj);
49 py_interned_string_to_dimname_.emplace(obj, dimname);
50}
51
52} // namespace torch
53
54bool THPUtils_checkDimname(PyObject* obj) {
55 return obj == Py_None || THPUtils_checkString(obj);
56}
57
58// To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if
59// it is a list or tuple and its first elt is a Dimname
60bool THPUtils_checkDimnameList(PyObject* obj) {
61 auto tuple = PyTuple_Check(obj);
62 if (!tuple && !PyList_Check(obj)) {
63 return false;
64 }
65 // NOLINTNEXTLINE(bugprone-branch-clone)
66 const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
67 if (size == 0) {
68 return true;
69 }
70 PyObject* first_elt =
71 tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
72 return THPUtils_checkDimname(first_elt);
73}
74
75at::Dimname THPDimname_parse(PyObject* obj) {
76 if (obj == Py_None) {
77 return at::Dimname::wildcard();
78 }
79
80 if (!THPUtils_checkString(obj)) {
81 throw torch::TypeError(
82 "expected None or string for Dimname but got %s",
83 Py_TYPE(obj)->tp_name);
84 }
85
86 if (!THPUtils_isInterned(obj)) {
87 // internStringInPlace decrefs obj and increfs the result. Because we're
88 // not actually returning the result to the user, we need to undo these.
89 // See
90 // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
91 Py_INCREF(obj);
92 THPUtils_internStringInPlace(&obj);
93 Py_DECREF(obj);
94 }
95
96 auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
97 if (maybeDimname) {
98 return *maybeDimname;
99 }
100
101 const auto name = THPUtils_unpackString(obj);
102 auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
103 torch::kPyInternedStringToDimname.addMapping(obj, dimname);
104 return dimname;
105}
106