1#include <c10/core/SymFloat.h>
2#include <c10/core/SymInt.h>
3#include <c10/core/SymNodeImpl.h>
4#include <array>
5#include <utility>
6
7namespace c10 {
8
9static std::array<SymNode, 2> normalize_symints(
10 const SymInt& a_,
11 const SymInt& b_) {
12 SymNode a, b;
13 if (a_.is_symbolic())
14 a = a_.toSymNodeImpl();
15 if (b_.is_symbolic())
16 b = b_.toSymNodeImpl();
17
18 SymNodeImpl* common = a ? a.get() : b.get();
19 // TODO: technically we need to check that the classes match
20 if (!a) {
21 a = common->wrap_int(a_.as_int_unchecked());
22 }
23 if (!b) {
24 b = common->wrap_int(b_.as_int_unchecked());
25 }
26 return {std::move(a), std::move(b)};
27}
28
29SymNode SymInt::toSymNodeImpl() const {
30 TORCH_CHECK(is_symbolic());
31 return SymNode::reclaim_copy(toSymNodeImplUnowned());
32}
33
34SymInt::SymInt(SymNode sin_sp) {
35 TORCH_CHECK(sin_sp->is_int());
36 auto ptr = static_cast<uint64_t>(
37 reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
38 auto rep = (ptr & ~MASK) | IS_SYM;
39 data_ = static_cast<int64_t>(rep);
40}
41
42int64_t SymInt::guard_int(const char* file, int64_t line) const {
43 if (!is_symbolic()) {
44 return data_;
45 }
46 SymNode a = toSymNodeImpl();
47 return a->guard_int(file, line);
48}
49
50SymInt::operator SymFloat() const {
51 if (!is_symbolic()) {
52 return SymFloat(double(data_));
53 }
54 return SymFloat(toSymNodeImpl()->sym_float());
55}
56
57SymInt SymInt::operator+(const SymInt& sci) const {
58 if (!is_symbolic() && !sci.is_symbolic()) {
59 return SymInt(data_ + sci.data_);
60 }
61 auto res = normalize_symints(*this, sci);
62 return SymInt(res[0]->add(res[1]));
63}
64
65SymInt SymInt::operator-(const SymInt& sci) const {
66 if (!is_symbolic() && !sci.is_symbolic()) {
67 return SymInt(data_ - sci.data_);
68 }
69 auto res = normalize_symints(*this, sci);
70 return SymInt(res[0]->sub(res[1]));
71}
72
73SymInt SymInt::operator*(const SymInt& sci) const {
74 if (!is_symbolic() && !sci.is_symbolic()) {
75 return SymInt(data_ * sci.data_);
76 }
77 auto res = normalize_symints(*this, sci);
78 return SymInt(res[0]->mul(res[1]));
79}
80
81SymInt SymInt::operator/(const SymInt& sci) const {
82 if (!is_symbolic() && !sci.is_symbolic()) {
83 return SymInt(data_ / sci.data_);
84 }
85 auto res = normalize_symints(*this, sci);
86 return SymInt(res[0]->floordiv(res[1]));
87}
88
89SymInt SymInt::operator%(const SymInt& sci) const {
90 if (!is_symbolic() && !sci.is_symbolic()) {
91 return SymInt(data_ % sci.data_);
92 }
93 auto res = normalize_symints(*this, sci);
94 return SymInt(res[0]->mod(res[1]));
95}
96
97SymBool SymInt::sym_eq(const SymInt& sci) const {
98 if (!is_symbolic() && !sci.is_symbolic()) {
99 return data_ == sci.data_;
100 }
101 auto res = normalize_symints(*this, sci);
102 return res[0]->eq(res[1]);
103}
104
105SymBool SymInt::sym_ne(const SymInt& sci) const {
106 if (!is_symbolic() && !sci.is_symbolic()) {
107 return data_ != sci.data_;
108 }
109 auto res = normalize_symints(*this, sci);
110 return res[0]->ne(res[1]);
111}
112
113SymBool SymInt::sym_lt(const SymInt& sci) const {
114 if (!is_symbolic() && !sci.is_symbolic()) {
115 return data_ < sci.data_;
116 }
117 auto res = normalize_symints(*this, sci);
118 return res[0]->lt(res[1]);
119}
120
121SymBool SymInt::sym_le(const SymInt& sci) const {
122 if (!is_symbolic() && !sci.is_symbolic()) {
123 return data_ <= sci.data_;
124 }
125 auto res = normalize_symints(*this, sci);
126 return res[0]->le(res[1]);
127}
128
129SymBool SymInt::sym_gt(const SymInt& sci) const {
130 if (!is_symbolic() && !sci.is_symbolic()) {
131 return data_ > sci.data_;
132 }
133 auto res = normalize_symints(*this, sci);
134 return res[0]->gt(res[1]);
135}
136
137SymBool SymInt::sym_ge(const SymInt& sci) const {
138 if (!is_symbolic() && !sci.is_symbolic()) {
139 return data_ >= sci.data_;
140 }
141 auto res = normalize_symints(*this, sci);
142 return res[0]->ge(res[1]);
143}
144
145SymInt SymInt::min(const SymInt& sci) const {
146 if (!is_symbolic() && !sci.is_symbolic()) {
147 return std::min(data_, sci.data_);
148 }
149 auto res = normalize_symints(*this, sci);
150 return SymInt(res[0]->sym_min(res[1]));
151}
152SymInt SymInt::max(const SymInt& sci) const {
153 if (!is_symbolic() && !sci.is_symbolic()) {
154 return std::max(data_, sci.data_);
155 }
156 auto res = normalize_symints(*this, sci);
157 return SymInt(res[0]->sym_max(res[1]));
158}
159
160void SymInt::operator*=(const SymInt& sci) {
161 *this = *this * sci;
162}
163
164void SymInt::operator/=(const SymInt& sci) {
165 *this = *this / sci;
166}
167
168void SymInt::operator+=(const SymInt& sci) {
169 *this = *this + sci;
170}
171
172bool SymInt::operator<(int64_t sci) const {
173 return *this < c10::SymInt(sci);
174}
175
176bool SymInt::operator<=(int64_t sci) const {
177 return *this <= c10::SymInt(sci);
178}
179
180bool SymInt::operator>(int64_t sci) const {
181 return *this > c10::SymInt(sci);
182}
183
184bool SymInt::operator>=(int64_t sci) const {
185 return *this >= c10::SymInt(sci);
186}
187
188bool SymInt::operator==(int64_t sci) const {
189 return *this == c10::SymInt(sci);
190}
191
192bool SymInt::operator!=(int64_t sci) const {
193 return *this != c10::SymInt(sci);
194}
195
196SymInt SymInt::operator*(int64_t sci) const {
197 return *this * c10::SymInt(sci);
198}
199
200std::ostream& operator<<(std::ostream& os, const SymInt& s) {
201 if (s.is_symbolic()) {
202 os << s.toSymNodeImpl()->str();
203 } else {
204 os << s.as_int_unchecked();
205 }
206 return os;
207}
208
209SymInt operator-(const SymInt& s) {
210 if (s.is_symbolic()) {
211 return SymInt(s.toSymNodeImpl()->neg());
212 } else {
213 return SymInt(-s.as_int_unchecked());
214 }
215}
216
217} // namespace c10
218