1 | #pragma once |
2 | |
3 | #include <c10/core/SymBool.h> |
4 | #include <c10/core/SymNodeImpl.h> |
5 | #include <c10/macros/Macros.h> |
6 | #include <c10/util/Exception.h> |
7 | |
8 | #include <numeric> |
9 | |
10 | namespace c10 { |
11 | |
12 | class SymFloat; |
13 | |
14 | // SymInt represents either a regular int64_t, or a symbolic integer |
15 | // (represented in a type erased way as SymNode). The intention is for SymInt |
16 | // to represent symbolic sizes that arise when doing shape computation in |
17 | // operator kernels. This allows for tracing through programs without baking in |
18 | // concrete sizes into kernel calls. |
19 | // |
20 | // SymInt has an API equivalent to int64_t. In particular, it is a value type. |
21 | // Internally, SymInt is represented in a clever packed way, so that it only |
22 | // occupies one word of space; but morally, it is a union between an int64_t |
23 | // and an intrusive pointer to SymNodeImpl. |
24 | // |
25 | // Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where |
26 | // is_int() returns true |
27 | |
28 | class C10_API SymInt { |
29 | public: |
30 | enum Unchecked { |
31 | UNCHECKED, |
32 | }; |
33 | |
34 | /*implicit*/ SymInt(int64_t d) : data_(d) { |
35 | // NB: this relies on exception in constructor inhibiting |
36 | // destructor; otherwise we would attempt to deallocate |
37 | // the garbage data! |
38 | TORCH_CHECK(!is_symbolic()); |
39 | }; |
40 | SymInt() : data_(0) {} |
41 | SymInt(SymNode n); |
42 | |
43 | // unchecked c-tor accepting raw `data_` |
44 | // One appropriate use for this is when you are constructing a symint |
45 | // in a situation where you know it is non-negative (or, if it is negative, |
46 | // the negative value is -1; i.e., not user controlled) |
47 | SymInt(Unchecked, int64_t d) : data_(d) {} |
48 | |
49 | // TODO: these implementations are not optimal because they allocate a |
50 | // temporary and then use the move constructor/assignment |
51 | SymInt(const SymInt& s) : data_(0) { |
52 | if (s.is_symbolic()) { |
53 | *this = SymInt(s.toSymNodeImpl()); |
54 | } else { |
55 | data_ = s.data_; |
56 | } |
57 | } |
58 | SymInt(SymInt&& s) noexcept : data_(s.data_) { |
59 | s.data_ = 0; |
60 | } |
61 | |
62 | SymInt& operator=(const SymInt& s) { |
63 | if (this != &s) { |
64 | if (s.is_symbolic()) { |
65 | *this = SymInt(s.toSymNodeImpl()); |
66 | } else { |
67 | data_ = s.data_; |
68 | } |
69 | } |
70 | return *this; |
71 | } |
72 | SymInt& operator=(SymInt&& s) noexcept { |
73 | if (this != &s) { |
74 | release_(); // release the current SymNode if any |
75 | data_ = s.data_; |
76 | if (s.is_symbolic()) |
77 | s.data_ = 0; |
78 | }; |
79 | return *this; |
80 | } |
81 | |
82 | SymInt clone() const { |
83 | if (is_symbolic()) { |
84 | return SymInt(toSymNodeImplUnowned()->clone()); |
85 | } |
86 | return *this; |
87 | } |
88 | |
89 | SymNodeImpl* toSymNodeImplUnowned() const { |
90 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic()); |
91 | uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK; |
92 | uint64_t sign_bit_mask = 1ULL << (62 - 1); |
93 | // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c |
94 | uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; |
95 | return static_cast<SymNodeImpl*>( |
96 | reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits))); |
97 | } |
98 | |
99 | void release_() { |
100 | if (is_symbolic()) { |
101 | SymNode::reclaim(toSymNodeImplUnowned()); // steal |
102 | } |
103 | } |
104 | |
105 | SymNodeImpl* release() && { |
106 | #ifndef C10_MOBILE |
107 | TORCH_INTERNAL_ASSERT(is_symbolic()); |
108 | auto* r = toSymNodeImplUnowned(); |
109 | data_ = 0; // transfer ownership |
110 | return r; |
111 | #else |
112 | TORCH_INTERNAL_ASSERT(false); |
113 | #endif |
114 | } |
115 | |
116 | SymNode toSymNodeImpl() const; |
117 | |
118 | ~SymInt() { |
119 | release_(); |
120 | } |
121 | |
122 | // Require the int to be non-symbolic, and if it is symbolic raise an |
123 | // error. This is safe to use for C++ code that doesn't work for symbolic |
124 | // shapes, and you don't have time to fix it immediately, as if we |
125 | // try to trigger the path in C++ you'll appropriately get an error |
126 | int64_t expect_int() const { |
127 | TORCH_CHECK(!is_symbolic()); |
128 | return data_; |
129 | } |
130 | |
131 | // Insert a guard for the int to be its concrete value, and then return |
132 | // that value. This operation always works, even if the int is symbolic, |
133 | // so long as we know what the underlying value is (e.g., this won't work |
134 | // if you call it on the size of nonzero output). Don't blindly put this |
135 | // everywhere; you can cause overspecialization of PyTorch programs with |
136 | // this method. |
137 | // |
138 | // It should be called as guard_int(__FILE__, __LINE__). The file and line |
139 | // number can be used to diagnose overspecialization. |
140 | int64_t guard_int(const char* file, int64_t line) const; |
141 | |
142 | // N.B. It's important to keep this definition in the header |
143 | // as we expect if checks to be folded for mobile builds |
144 | // where `is_symbolic` is always false and optimize dead code paths |
145 | C10_ALWAYS_INLINE bool is_symbolic() const { |
146 | #ifdef C10_MOBILE |
147 | return false; |
148 | #else |
149 | return !check_range(data_); |
150 | #endif |
151 | } |
152 | |
153 | SymInt operator+(const SymInt& sci) const; |
154 | SymInt operator-(const SymInt& sci) const; |
155 | SymInt operator*(const SymInt& sci) const; |
156 | SymInt operator/(const SymInt& sci) const; |
157 | SymInt operator%(const SymInt& sci) const; |
158 | void operator*=(const SymInt& sci); |
159 | void operator+=(const SymInt& sci); |
160 | void operator/=(const SymInt& sci); |
161 | |
162 | SymBool sym_eq(const SymInt&) const; |
163 | SymBool sym_ne(const SymInt&) const; |
164 | SymBool sym_lt(const SymInt&) const; |
165 | SymBool sym_le(const SymInt&) const; |
166 | SymBool sym_gt(const SymInt&) const; |
167 | SymBool sym_ge(const SymInt&) const; |
168 | |
169 | bool operator==(const SymInt& o) const { |
170 | return sym_eq(o).guard_bool(__FILE__, __LINE__); |
171 | } |
172 | bool operator!=(const SymInt& o) const { |
173 | return sym_ne(o).guard_bool(__FILE__, __LINE__); |
174 | } |
175 | bool operator<(const SymInt& o) const { |
176 | return sym_lt(o).guard_bool(__FILE__, __LINE__); |
177 | } |
178 | bool operator<=(const SymInt& o) const { |
179 | return sym_le(o).guard_bool(__FILE__, __LINE__); |
180 | } |
181 | bool operator>(const SymInt& o) const { |
182 | return sym_gt(o).guard_bool(__FILE__, __LINE__); |
183 | } |
184 | bool operator>=(const SymInt& o) const { |
185 | return sym_ge(o).guard_bool(__FILE__, __LINE__); |
186 | } |
187 | |
188 | SymInt min(const SymInt& sci) const; |
189 | SymInt max(const SymInt& sci) const; |
190 | |
191 | SymInt operator*(int64_t sci) const; |
192 | bool operator<(int64_t sci) const; |
193 | bool operator==(int64_t sci) const; |
194 | bool operator!=(int64_t sci) const; |
195 | bool operator<=(int64_t sci) const; |
196 | bool operator>(int64_t sci) const; |
197 | bool operator>=(int64_t sci) const; |
198 | |
199 | operator SymFloat() const; |
200 | |
201 | int64_t as_int_unchecked() const { |
202 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_symbolic()); |
203 | return data_; |
204 | } |
205 | |
206 | // Return whether the integer is representable as a SymInt. |
207 | static bool check_range(int64_t i) { |
208 | return i > MAX_UNREPRESENTABLE_INT; |
209 | } |
210 | |
211 | // Return the min represetable integer as a SymInt |
212 | static constexpr int64_t min_representable_int() { |
213 | return MAX_UNREPRESENTABLE_INT + 1; |
214 | } |
215 | |
216 | private: |
217 | // Constraints on the internal representation: |
218 | // |
219 | // - Should represent positive and small negative ints |
220 | // - No conversion necessary for operations on ints |
221 | // - Must represent valid 64-bit pointers |
222 | // - Is symbolic test should be FAST (two arithmetic instructions is too |
223 | // much). |
224 | // This code being a hotpath is based on Strobelight profiles of |
225 | // is_symbolic(). FB only: https://fburl.com/strobelight/5l50ncxd |
226 | // (you will need to change the time window). |
227 | // |
228 | // So, the scheme is to reserve large negative numbers (asssuming |
229 | // two's complement): |
230 | // |
231 | // - 0b0.... means we are a positive int |
232 | // - 0b11... means we are a small negative int |
233 | // - 0b10... means we are are a pointer. This means that |
234 | // [-2^63, -2^62-1] are not representable as ints. |
235 | // We don't actually need all of this space as on x86_64 |
236 | // as the top 16bits aren't used for anything |
237 | static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61; |
238 | static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61; |
239 | // We must manually translate the bit pattern test into a greater |
240 | // than test because compiler doesn't figure it out: |
241 | // https://godbolt.org/z/356aferaW |
242 | static constexpr int64_t MAX_UNREPRESENTABLE_INT = |
243 | -1LL & static_cast<int64_t>(~(1ULL << 62)); |
244 | int64_t data_; |
245 | }; |
246 | |
247 | /// Sum of a list of SymInt; accumulates into the c10::SymInt expression |
248 | template < |
249 | typename C, |
250 | typename std::enable_if< |
251 | std::is_same<typename C::value_type, c10::SymInt>::value, |
252 | int>::type = 0> |
253 | inline c10::SymInt multiply_integers(const C& container) { |
254 | return std::accumulate( |
255 | container.begin(), |
256 | container.end(), |
257 | c10::SymInt(1), |
258 | [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); |
259 | } |
260 | |
261 | template < |
262 | typename Iter, |
263 | typename = std::enable_if_t<std::is_same< |
264 | typename std::iterator_traits<Iter>::value_type, |
265 | c10::SymInt>::value>> |
266 | inline c10::SymInt multiply_integers(Iter begin, Iter end) { |
267 | return std::accumulate( |
268 | begin, |
269 | end, |
270 | c10::SymInt(1), |
271 | [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); |
272 | } |
273 | |
274 | inline SymInt operator+(int64_t a, const SymInt& b) { |
275 | return c10::SymInt(a) + b; |
276 | } |
277 | inline SymInt operator-(int64_t a, const SymInt& b) { |
278 | return c10::SymInt(a) - b; |
279 | } |
280 | inline SymInt operator*(int64_t a, const SymInt& b) { |
281 | return c10::SymInt(a) * b; |
282 | } |
283 | inline SymInt operator/(int64_t a, const SymInt& b) { |
284 | return c10::SymInt(a) / b; |
285 | } |
286 | inline SymInt operator%(int64_t a, const SymInt& b) { |
287 | return c10::SymInt(a) % b; |
288 | } |
289 | inline bool operator==(int64_t a, const SymInt& b) { |
290 | return c10::SymInt(a) == b; |
291 | } |
292 | inline bool operator!=(int64_t a, const SymInt& b) { |
293 | return c10::SymInt(a) != b; |
294 | } |
295 | inline bool operator<(int64_t a, const SymInt& b) { |
296 | return c10::SymInt(a) < b; |
297 | } |
298 | inline bool operator<=(int64_t a, const SymInt& b) { |
299 | return c10::SymInt(a) <= b; |
300 | } |
301 | inline bool operator>(int64_t a, const SymInt& b) { |
302 | return c10::SymInt(a) > b; |
303 | } |
304 | inline bool operator>=(int64_t a, const SymInt& b) { |
305 | return c10::SymInt(a) >= b; |
306 | } |
307 | |
308 | C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); |
309 | C10_API SymInt operator-(const SymInt& s); |
310 | } // namespace c10 |
311 | |