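// aten_interned_strings.h includes the names of all operators, so this
// translation unit cannot be restricted to method operators; undefine
// TORCH_ASSERT_ONLY_METHOD_OPERATORS before any includes.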
2#undef TORCH_ASSERT_ONLY_METHOD_OPERATORS
3
4#include <ATen/core/interned_strings.h>
5#include <cstdint>
6#include <cstring>
7#include <iostream>
8#include <mutex>
9#include <sstream>
10#include <string>
11#include <unordered_map>
12#include <vector>
13#include <c10/util/Exception.h>
14#include <ATen/core/interned_strings_class.h>
15#include <c10/util/Exception.h>
16#include <c10/util/Optional.h>
17
18namespace c10 {
19
// Prefix joined with a namespace name to form its "domain" string
// (used by Symbol::domainString and Symbol::fromDomainAndUnqualString).
// The string is created once, on first call; every call returns a reference
// to that same object.
const std::string& domain_prefix() {
  static const std::string kPrefix{"org.pytorch."};
  return kPrefix;
}
24
25Symbol InternedStrings::symbol(const std::string& s) {
26 std::lock_guard<std::mutex> guard(mutex_);
27 return _symbol(s);
28}
29
// Returns {qualified name, unqualified name} for `sym`. For a builtin declared
// as (ns, s) in FORALL_NS_SYMBOLS the pair is {"ns::s", "s"}, built from
// string literals at compile time.
//
// Non-builtin symbols, and every symbol on C10_MOBILE builds, are resolved
// through customString(), which takes mutex_ and reads the interned table.
std::pair<const char*, const char*> InternedStrings::string(Symbol sym) {
  // Builtin Symbols are also in the maps, but
  // we can bypass the need to acquire a lock
  // to read the map for Builtins because we already
  // know their string value
#if defined C10_MOBILE
  // No builtin fast path on mobile: always consult the interned table.
  return customString(sym);
#else
  switch (sym) {
    // One case per builtin symbol; the names are stringized from the macro
    // arguments.
#define DEFINE_CASE(ns, s) \
  case static_cast<unique_t>(ns::s): \
    return {#ns "::" #s, #s};
    FORALL_NS_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      // Symbol was interned at runtime: look it up under the lock.
      return customString(sym);
  }
#endif
}
49
// Returns the namespace Symbol of `sym`. For a builtin declared as (ns, s) in
// FORALL_NS_SYMBOLS this is namespaces::ns, answered without locking; any
// other symbol's namespace is read from sym_to_info_ while holding mutex_.
Symbol InternedStrings::ns(Symbol sym) {
#if defined C10_MOBILE
  // No builtin fast path on mobile: always read the table under the lock.
  std::lock_guard<std::mutex> guard(mutex_);
  return sym_to_info_.at(sym).ns;
#else
  switch (sym) {
    // Builtins: the namespace is known at compile time.
#define DEFINE_CASE(ns, s) \
  case static_cast<unique_t>(ns::s): \
    return namespaces::ns;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    FORALL_NS_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
    default: {
      // Runtime-interned symbol: read its recorded namespace under the lock.
      // .at() is bounds-checked, so an id that was never interned is rejected
      // rather than read out of range.
      std::lock_guard<std::mutex> guard(mutex_);
      return sym_to_info_.at(sym).ns;
    }
  }
#endif
}
69
70Symbol InternedStrings::_symbol(const std::string& s) {
71 auto it = string_to_sym_.find(s);
72 if (it != string_to_sym_.end())
73 return it->second;
74
75 auto pos = s.find("::");
76 if (pos == std::string::npos) {
77 std::stringstream ss;
78 ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
79 throw std::runtime_error(ss.str());
80 }
81 Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
82
83 Symbol sym(sym_to_info_.size());
84 string_to_sym_[s] = sym;
85 sym_to_info_.push_back({ns, s, s.substr(pos + strlen("::"))});
86 return sym;
87}
88
89std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) {
90 std::lock_guard<std::mutex> guard(mutex_);
91 SymbolInfo& s = sym_to_info_.at(sym);
92 return {s.qual_name.c_str(), s.unqual_name.c_str()};
93}
94
95static InternedStrings & globalStrings() {
96 static InternedStrings s;
97 return s;
98}
99
100Symbol Symbol::fromQualString(const std::string & s) {
101 return globalStrings().symbol(s);
102}
103
104const char * Symbol::toUnqualString() const {
105 return globalStrings().string(*this).second;
106}
107
108const char * Symbol::toQualString() const {
109 return globalStrings().string(*this).first;
110}
111
112const char * Symbol::toDisplayString() const {
113 // TODO: Make this actually return something that's "user friendly".
114 // The trouble is that, for this to be usable in printf-style assert
115 // statements, this has to return a const char* (whose lifetime is
116 // global), so we can't actually assemble a string on the fly.
117 return toQualString();
118}
119
120Symbol Symbol::ns() const {
121 return globalStrings().ns(*this);
122}
123
124std::string Symbol::domainString() const {
125 return domain_prefix() + ns().toUnqualString();
126}
127
128Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
129 if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
130 std::ostringstream ss;
131 ss << "Symbol: domain string is expected to be prefixed with '"
132 << domain_prefix() << "', e.g. 'org.pytorch.aten'";
133 throw std::runtime_error(ss.str());
134 }
135 std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
136 return fromQualString(qualString);
137}
138
139bool Symbol::is_attr() const { return ns() == namespaces::attr; }
140bool Symbol::is_aten() const { return ns() == namespaces::aten; }
141bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
142bool Symbol::is_prim() const { return ns() == namespaces::prim; }
143bool Symbol::is_prims() const { return ns() == namespaces::prims; }
144bool Symbol::is_nvprims() const { return ns() == namespaces::nvprims; }
145bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
146bool Symbol::is_user() const { return ns() == namespaces::user; }
147bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
148bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }
149
150} // namespace c10
151