1 | // aten_interned_strings.h includes the names of all operators |
2 | #undef TORCH_ASSERT_ONLY_METHOD_OPERATORS |
3 | |
4 | #include <ATen/core/interned_strings.h> |
5 | #include <cstdint> |
6 | #include <cstring> |
7 | #include <iostream> |
8 | #include <mutex> |
9 | #include <sstream> |
10 | #include <string> |
11 | #include <unordered_map> |
12 | #include <vector> |
13 | #include <c10/util/Exception.h> |
14 | #include <ATen/core/interned_strings_class.h> |
15 | #include <c10/util/Exception.h> |
16 | #include <c10/util/Optional.h> |
17 | |
18 | namespace c10 { |
19 | |
// Returns the prefix used to turn a symbol namespace into a domain string
// (e.g. "aten" -> "org.pytorch.aten"). The string is a function-local static,
// built on first use and shared by every caller.
const std::string& domain_prefix() {
  static const std::string kDomainPrefix{"org.pytorch."};
  return kDomainPrefix;
}
24 | |
25 | Symbol InternedStrings::symbol(const std::string& s) { |
26 | std::lock_guard<std::mutex> guard(mutex_); |
27 | return _symbol(s); |
28 | } |
29 | |
30 | std::pair<const char*, const char*> InternedStrings::string(Symbol sym) { |
31 | // Builtin Symbols are also in the maps, but |
32 | // we can bypass the need to acquire a lock |
33 | // to read the map for Builtins because we already |
34 | // know their string value |
35 | #if defined C10_MOBILE |
36 | return customString(sym); |
37 | #else |
38 | switch (sym) { |
39 | #define DEFINE_CASE(ns, s) \ |
40 | case static_cast<unique_t>(ns::s): \ |
41 | return {#ns "::" #s, #s}; |
42 | FORALL_NS_SYMBOLS(DEFINE_CASE) |
43 | #undef DEFINE_CASE |
44 | default: |
45 | return customString(sym); |
46 | } |
47 | #endif |
48 | } |
49 | |
50 | Symbol InternedStrings::ns(Symbol sym) { |
51 | #if defined C10_MOBILE |
52 | std::lock_guard<std::mutex> guard(mutex_); |
53 | return sym_to_info_.at(sym).ns; |
54 | #else |
55 | switch (sym) { |
56 | #define DEFINE_CASE(ns, s) \ |
57 | case static_cast<unique_t>(ns::s): \ |
58 | return namespaces::ns; |
59 | // NOLINTNEXTLINE(bugprone-branch-clone) |
60 | FORALL_NS_SYMBOLS(DEFINE_CASE) |
61 | #undef DEFINE_CASE |
62 | default: { |
63 | std::lock_guard<std::mutex> guard(mutex_); |
64 | return sym_to_info_.at(sym).ns; |
65 | } |
66 | } |
67 | #endif |
68 | } |
69 | |
70 | Symbol InternedStrings::_symbol(const std::string& s) { |
71 | auto it = string_to_sym_.find(s); |
72 | if (it != string_to_sym_.end()) |
73 | return it->second; |
74 | |
75 | auto pos = s.find("::" ); |
76 | if (pos == std::string::npos) { |
77 | std::stringstream ss; |
78 | ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s; |
79 | throw std::runtime_error(ss.str()); |
80 | } |
81 | Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); |
82 | |
83 | Symbol sym(sym_to_info_.size()); |
84 | string_to_sym_[s] = sym; |
85 | sym_to_info_.push_back({ns, s, s.substr(pos + strlen("::" ))}); |
86 | return sym; |
87 | } |
88 | |
89 | std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) { |
90 | std::lock_guard<std::mutex> guard(mutex_); |
91 | SymbolInfo& s = sym_to_info_.at(sym); |
92 | return {s.qual_name.c_str(), s.unqual_name.c_str()}; |
93 | } |
94 | |
95 | static InternedStrings & globalStrings() { |
96 | static InternedStrings s; |
97 | return s; |
98 | } |
99 | |
100 | Symbol Symbol::fromQualString(const std::string & s) { |
101 | return globalStrings().symbol(s); |
102 | } |
103 | |
104 | const char * Symbol::toUnqualString() const { |
105 | return globalStrings().string(*this).second; |
106 | } |
107 | |
108 | const char * Symbol::toQualString() const { |
109 | return globalStrings().string(*this).first; |
110 | } |
111 | |
112 | const char * Symbol::toDisplayString() const { |
113 | // TODO: Make this actually return something that's "user friendly". |
114 | // The trouble is that, for this to be usable in printf-style assert |
115 | // statements, this has to return a const char* (whose lifetime is |
116 | // global), so we can't actually assemble a string on the fly. |
117 | return toQualString(); |
118 | } |
119 | |
120 | Symbol Symbol::ns() const { |
121 | return globalStrings().ns(*this); |
122 | } |
123 | |
124 | std::string Symbol::domainString() const { |
125 | return domain_prefix() + ns().toUnqualString(); |
126 | } |
127 | |
128 | Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) { |
129 | if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) { |
130 | std::ostringstream ss; |
131 | ss << "Symbol: domain string is expected to be prefixed with '" |
132 | << domain_prefix() << "', e.g. 'org.pytorch.aten'" ; |
133 | throw std::runtime_error(ss.str()); |
134 | } |
135 | std::string qualString = d.substr(domain_prefix().size()) + "::" + s; |
136 | return fromQualString(qualString); |
137 | } |
138 | |
139 | bool Symbol::is_attr() const { return ns() == namespaces::attr; } |
140 | bool Symbol::is_aten() const { return ns() == namespaces::aten; } |
141 | bool Symbol::is_cuda() const { return ns() == namespaces::cuda; } |
142 | bool Symbol::is_prim() const { return ns() == namespaces::prim; } |
143 | bool Symbol::is_prims() const { return ns() == namespaces::prims; } |
144 | bool Symbol::is_nvprims() const { return ns() == namespaces::nvprims; } |
145 | bool Symbol::is_onnx() const { return ns() == namespaces::onnx; } |
146 | bool Symbol::is_user() const { return ns() == namespaces::user; } |
147 | bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; } |
148 | bool Symbol::is_dimname() const { return ns() == namespaces::dimname; } |
149 | |
150 | } // namespace c10 |
151 | |