1 | #include <c10/core/SymFloat.h> |
2 | #include <c10/core/SymInt.h> |
3 | #include <c10/core/SymNodeImpl.h> |
4 | #include <array> |
5 | #include <utility> |
6 | |
7 | namespace c10 { |
8 | |
9 | static std::array<SymNode, 2> normalize_symints( |
10 | const SymInt& a_, |
11 | const SymInt& b_) { |
12 | SymNode a, b; |
13 | if (a_.is_symbolic()) |
14 | a = a_.toSymNodeImpl(); |
15 | if (b_.is_symbolic()) |
16 | b = b_.toSymNodeImpl(); |
17 | |
18 | SymNodeImpl* common = a ? a.get() : b.get(); |
19 | // TODO: technically we need to check that the classes match |
20 | if (!a) { |
21 | a = common->wrap_int(a_.as_int_unchecked()); |
22 | } |
23 | if (!b) { |
24 | b = common->wrap_int(b_.as_int_unchecked()); |
25 | } |
26 | return {std::move(a), std::move(b)}; |
27 | } |
28 | |
29 | SymNode SymInt::toSymNodeImpl() const { |
30 | TORCH_CHECK(is_symbolic()); |
31 | return SymNode::reclaim_copy(toSymNodeImplUnowned()); |
32 | } |
33 | |
34 | SymInt::SymInt(SymNode sin_sp) { |
35 | TORCH_CHECK(sin_sp->is_int()); |
36 | auto ptr = static_cast<uint64_t>( |
37 | reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release()))); |
38 | auto rep = (ptr & ~MASK) | IS_SYM; |
39 | data_ = static_cast<int64_t>(rep); |
40 | } |
41 | |
42 | int64_t SymInt::guard_int(const char* file, int64_t line) const { |
43 | if (!is_symbolic()) { |
44 | return data_; |
45 | } |
46 | SymNode a = toSymNodeImpl(); |
47 | return a->guard_int(file, line); |
48 | } |
49 | |
50 | SymInt::operator SymFloat() const { |
51 | if (!is_symbolic()) { |
52 | return SymFloat(double(data_)); |
53 | } |
54 | return SymFloat(toSymNodeImpl()->sym_float()); |
55 | } |
56 | |
57 | SymInt SymInt::operator+(const SymInt& sci) const { |
58 | if (!is_symbolic() && !sci.is_symbolic()) { |
59 | return SymInt(data_ + sci.data_); |
60 | } |
61 | auto res = normalize_symints(*this, sci); |
62 | return SymInt(res[0]->add(res[1])); |
63 | } |
64 | |
65 | SymInt SymInt::operator-(const SymInt& sci) const { |
66 | if (!is_symbolic() && !sci.is_symbolic()) { |
67 | return SymInt(data_ - sci.data_); |
68 | } |
69 | auto res = normalize_symints(*this, sci); |
70 | return SymInt(res[0]->sub(res[1])); |
71 | } |
72 | |
73 | SymInt SymInt::operator*(const SymInt& sci) const { |
74 | if (!is_symbolic() && !sci.is_symbolic()) { |
75 | return SymInt(data_ * sci.data_); |
76 | } |
77 | auto res = normalize_symints(*this, sci); |
78 | return SymInt(res[0]->mul(res[1])); |
79 | } |
80 | |
81 | SymInt SymInt::operator/(const SymInt& sci) const { |
82 | if (!is_symbolic() && !sci.is_symbolic()) { |
83 | return SymInt(data_ / sci.data_); |
84 | } |
85 | auto res = normalize_symints(*this, sci); |
86 | return SymInt(res[0]->floordiv(res[1])); |
87 | } |
88 | |
89 | SymInt SymInt::operator%(const SymInt& sci) const { |
90 | if (!is_symbolic() && !sci.is_symbolic()) { |
91 | return SymInt(data_ % sci.data_); |
92 | } |
93 | auto res = normalize_symints(*this, sci); |
94 | return SymInt(res[0]->mod(res[1])); |
95 | } |
96 | |
97 | SymBool SymInt::sym_eq(const SymInt& sci) const { |
98 | if (!is_symbolic() && !sci.is_symbolic()) { |
99 | return data_ == sci.data_; |
100 | } |
101 | auto res = normalize_symints(*this, sci); |
102 | return res[0]->eq(res[1]); |
103 | } |
104 | |
105 | SymBool SymInt::sym_ne(const SymInt& sci) const { |
106 | if (!is_symbolic() && !sci.is_symbolic()) { |
107 | return data_ != sci.data_; |
108 | } |
109 | auto res = normalize_symints(*this, sci); |
110 | return res[0]->ne(res[1]); |
111 | } |
112 | |
113 | SymBool SymInt::sym_lt(const SymInt& sci) const { |
114 | if (!is_symbolic() && !sci.is_symbolic()) { |
115 | return data_ < sci.data_; |
116 | } |
117 | auto res = normalize_symints(*this, sci); |
118 | return res[0]->lt(res[1]); |
119 | } |
120 | |
121 | SymBool SymInt::sym_le(const SymInt& sci) const { |
122 | if (!is_symbolic() && !sci.is_symbolic()) { |
123 | return data_ <= sci.data_; |
124 | } |
125 | auto res = normalize_symints(*this, sci); |
126 | return res[0]->le(res[1]); |
127 | } |
128 | |
129 | SymBool SymInt::sym_gt(const SymInt& sci) const { |
130 | if (!is_symbolic() && !sci.is_symbolic()) { |
131 | return data_ > sci.data_; |
132 | } |
133 | auto res = normalize_symints(*this, sci); |
134 | return res[0]->gt(res[1]); |
135 | } |
136 | |
137 | SymBool SymInt::sym_ge(const SymInt& sci) const { |
138 | if (!is_symbolic() && !sci.is_symbolic()) { |
139 | return data_ >= sci.data_; |
140 | } |
141 | auto res = normalize_symints(*this, sci); |
142 | return res[0]->ge(res[1]); |
143 | } |
144 | |
145 | SymInt SymInt::min(const SymInt& sci) const { |
146 | if (!is_symbolic() && !sci.is_symbolic()) { |
147 | return std::min(data_, sci.data_); |
148 | } |
149 | auto res = normalize_symints(*this, sci); |
150 | return SymInt(res[0]->sym_min(res[1])); |
151 | } |
152 | SymInt SymInt::max(const SymInt& sci) const { |
153 | if (!is_symbolic() && !sci.is_symbolic()) { |
154 | return std::max(data_, sci.data_); |
155 | } |
156 | auto res = normalize_symints(*this, sci); |
157 | return SymInt(res[0]->sym_max(res[1])); |
158 | } |
159 | |
160 | void SymInt::operator*=(const SymInt& sci) { |
161 | *this = *this * sci; |
162 | } |
163 | |
164 | void SymInt::operator/=(const SymInt& sci) { |
165 | *this = *this / sci; |
166 | } |
167 | |
168 | void SymInt::operator+=(const SymInt& sci) { |
169 | *this = *this + sci; |
170 | } |
171 | |
172 | bool SymInt::operator<(int64_t sci) const { |
173 | return *this < c10::SymInt(sci); |
174 | } |
175 | |
176 | bool SymInt::operator<=(int64_t sci) const { |
177 | return *this <= c10::SymInt(sci); |
178 | } |
179 | |
180 | bool SymInt::operator>(int64_t sci) const { |
181 | return *this > c10::SymInt(sci); |
182 | } |
183 | |
184 | bool SymInt::operator>=(int64_t sci) const { |
185 | return *this >= c10::SymInt(sci); |
186 | } |
187 | |
188 | bool SymInt::operator==(int64_t sci) const { |
189 | return *this == c10::SymInt(sci); |
190 | } |
191 | |
192 | bool SymInt::operator!=(int64_t sci) const { |
193 | return *this != c10::SymInt(sci); |
194 | } |
195 | |
196 | SymInt SymInt::operator*(int64_t sci) const { |
197 | return *this * c10::SymInt(sci); |
198 | } |
199 | |
200 | std::ostream& operator<<(std::ostream& os, const SymInt& s) { |
201 | if (s.is_symbolic()) { |
202 | os << s.toSymNodeImpl()->str(); |
203 | } else { |
204 | os << s.as_int_unchecked(); |
205 | } |
206 | return os; |
207 | } |
208 | |
209 | SymInt operator-(const SymInt& s) { |
210 | if (s.is_symbolic()) { |
211 | return SymInt(s.toSymNodeImpl()->neg()); |
212 | } else { |
213 | return SymInt(-s.as_int_unchecked()); |
214 | } |
215 | } |
216 | |
217 | } // namespace c10 |
218 | |