1#pragma once
2
3#include <c10/core/SymBool.h>
4#include <c10/core/SymNodeImpl.h>
5#include <c10/macros/Macros.h>
6#include <c10/util/Exception.h>
7
8#include <numeric>
9
10namespace c10 {
11
12class SymFloat;
13
14// SymInt represents either a regular int64_t, or a symbolic integer
15// (represented in a type erased way as SymNode). The intention is for SymInt
16// to represent symbolic sizes that arise when doing shape computation in
17// operator kernels. This allows for tracing through programs without baking in
18// concrete sizes into kernel calls.
19//
20// SymInt has an API equivalent to int64_t. In particular, it is a value type.
21// Internally, SymInt is represented in a clever packed way, so that it only
22// occupies one word of space; but morally, it is a union between an int64_t
23// and an intrusive pointer to SymNodeImpl.
24//
25// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
26// is_int() returns true
27
28class C10_API SymInt {
29 public:
30 enum Unchecked {
31 UNCHECKED,
32 };
33
34 /*implicit*/ SymInt(int64_t d) : data_(d) {
35 // NB: this relies on exception in constructor inhibiting
36 // destructor; otherwise we would attempt to deallocate
37 // the garbage data!
38 TORCH_CHECK(!is_symbolic());
39 };
40 SymInt() : data_(0) {}
41 SymInt(SymNode n);
42
43 // unchecked c-tor accepting raw `data_`
44 // One appropriate use for this is when you are constructing a symint
45 // in a situation where you know it is non-negative (or, if it is negative,
46 // the negative value is -1; i.e., not user controlled)
47 SymInt(Unchecked, int64_t d) : data_(d) {}
48
49 // TODO: these implementations are not optimal because they allocate a
50 // temporary and then use the move constructor/assignment
51 SymInt(const SymInt& s) : data_(0) {
52 if (s.is_symbolic()) {
53 *this = SymInt(s.toSymNodeImpl());
54 } else {
55 data_ = s.data_;
56 }
57 }
58 SymInt(SymInt&& s) noexcept : data_(s.data_) {
59 s.data_ = 0;
60 }
61
62 SymInt& operator=(const SymInt& s) {
63 if (this != &s) {
64 if (s.is_symbolic()) {
65 *this = SymInt(s.toSymNodeImpl());
66 } else {
67 data_ = s.data_;
68 }
69 }
70 return *this;
71 }
72 SymInt& operator=(SymInt&& s) noexcept {
73 if (this != &s) {
74 release_(); // release the current SymNode if any
75 data_ = s.data_;
76 if (s.is_symbolic())
77 s.data_ = 0;
78 };
79 return *this;
80 }
81
82 SymInt clone() const {
83 if (is_symbolic()) {
84 return SymInt(toSymNodeImplUnowned()->clone());
85 }
86 return *this;
87 }
88
89 SymNodeImpl* toSymNodeImplUnowned() const {
90 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
91 uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
92 uint64_t sign_bit_mask = 1ULL << (62 - 1);
93 // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
94 uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
95 return static_cast<SymNodeImpl*>(
96 reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
97 }
98
99 void release_() {
100 if (is_symbolic()) {
101 SymNode::reclaim(toSymNodeImplUnowned()); // steal
102 }
103 }
104
105 SymNodeImpl* release() && {
106#ifndef C10_MOBILE
107 TORCH_INTERNAL_ASSERT(is_symbolic());
108 auto* r = toSymNodeImplUnowned();
109 data_ = 0; // transfer ownership
110 return r;
111#else
112 TORCH_INTERNAL_ASSERT(false);
113#endif
114 }
115
116 SymNode toSymNodeImpl() const;
117
118 ~SymInt() {
119 release_();
120 }
121
122 // Require the int to be non-symbolic, and if it is symbolic raise an
123 // error. This is safe to use for C++ code that doesn't work for symbolic
124 // shapes, and you don't have time to fix it immediately, as if we
125 // try to trigger the path in C++ you'll appropriately get an error
126 int64_t expect_int() const {
127 TORCH_CHECK(!is_symbolic());
128 return data_;
129 }
130
131 // Insert a guard for the int to be its concrete value, and then return
132 // that value. This operation always works, even if the int is symbolic,
133 // so long as we know what the underlying value is (e.g., this won't work
134 // if you call it on the size of nonzero output). Don't blindly put this
135 // everywhere; you can cause overspecialization of PyTorch programs with
136 // this method.
137 //
138 // It should be called as guard_int(__FILE__, __LINE__). The file and line
139 // number can be used to diagnose overspecialization.
140 int64_t guard_int(const char* file, int64_t line) const;
141
142 // N.B. It's important to keep this definition in the header
143 // as we expect if checks to be folded for mobile builds
144 // where `is_symbolic` is always false and optimize dead code paths
145 C10_ALWAYS_INLINE bool is_symbolic() const {
146#ifdef C10_MOBILE
147 return false;
148#else
149 return !check_range(data_);
150#endif
151 }
152
153 SymInt operator+(const SymInt& sci) const;
154 SymInt operator-(const SymInt& sci) const;
155 SymInt operator*(const SymInt& sci) const;
156 SymInt operator/(const SymInt& sci) const;
157 SymInt operator%(const SymInt& sci) const;
158 void operator*=(const SymInt& sci);
159 void operator+=(const SymInt& sci);
160 void operator/=(const SymInt& sci);
161
162 SymBool sym_eq(const SymInt&) const;
163 SymBool sym_ne(const SymInt&) const;
164 SymBool sym_lt(const SymInt&) const;
165 SymBool sym_le(const SymInt&) const;
166 SymBool sym_gt(const SymInt&) const;
167 SymBool sym_ge(const SymInt&) const;
168
169 bool operator==(const SymInt& o) const {
170 return sym_eq(o).guard_bool(__FILE__, __LINE__);
171 }
172 bool operator!=(const SymInt& o) const {
173 return sym_ne(o).guard_bool(__FILE__, __LINE__);
174 }
175 bool operator<(const SymInt& o) const {
176 return sym_lt(o).guard_bool(__FILE__, __LINE__);
177 }
178 bool operator<=(const SymInt& o) const {
179 return sym_le(o).guard_bool(__FILE__, __LINE__);
180 }
181 bool operator>(const SymInt& o) const {
182 return sym_gt(o).guard_bool(__FILE__, __LINE__);
183 }
184 bool operator>=(const SymInt& o) const {
185 return sym_ge(o).guard_bool(__FILE__, __LINE__);
186 }
187
188 SymInt min(const SymInt& sci) const;
189 SymInt max(const SymInt& sci) const;
190
191 SymInt operator*(int64_t sci) const;
192 bool operator<(int64_t sci) const;
193 bool operator==(int64_t sci) const;
194 bool operator!=(int64_t sci) const;
195 bool operator<=(int64_t sci) const;
196 bool operator>(int64_t sci) const;
197 bool operator>=(int64_t sci) const;
198
199 operator SymFloat() const;
200
201 int64_t as_int_unchecked() const {
202 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_symbolic());
203 return data_;
204 }
205
206 // Return whether the integer is representable as a SymInt.
207 static bool check_range(int64_t i) {
208 return i > MAX_UNREPRESENTABLE_INT;
209 }
210
211 // Return the min represetable integer as a SymInt
212 static constexpr int64_t min_representable_int() {
213 return MAX_UNREPRESENTABLE_INT + 1;
214 }
215
216 private:
217 // Constraints on the internal representation:
218 //
219 // - Should represent positive and small negative ints
220 // - No conversion necessary for operations on ints
221 // - Must represent valid 64-bit pointers
222 // - Is symbolic test should be FAST (two arithmetic instructions is too
223 // much).
224 // This code being a hotpath is based on Strobelight profiles of
225 // is_symbolic(). FB only: https://fburl.com/strobelight/5l50ncxd
226 // (you will need to change the time window).
227 //
228 // So, the scheme is to reserve large negative numbers (asssuming
229 // two's complement):
230 //
231 // - 0b0.... means we are a positive int
232 // - 0b11... means we are a small negative int
233 // - 0b10... means we are are a pointer. This means that
234 // [-2^63, -2^62-1] are not representable as ints.
235 // We don't actually need all of this space as on x86_64
236 // as the top 16bits aren't used for anything
237 static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
238 static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
239 // We must manually translate the bit pattern test into a greater
240 // than test because compiler doesn't figure it out:
241 // https://godbolt.org/z/356aferaW
242 static constexpr int64_t MAX_UNREPRESENTABLE_INT =
243 -1LL & static_cast<int64_t>(~(1ULL << 62));
244 int64_t data_;
245};
246
247/// Sum of a list of SymInt; accumulates into the c10::SymInt expression
248template <
249 typename C,
250 typename std::enable_if<
251 std::is_same<typename C::value_type, c10::SymInt>::value,
252 int>::type = 0>
253inline c10::SymInt multiply_integers(const C& container) {
254 return std::accumulate(
255 container.begin(),
256 container.end(),
257 c10::SymInt(1),
258 [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
259}
260
261template <
262 typename Iter,
263 typename = std::enable_if_t<std::is_same<
264 typename std::iterator_traits<Iter>::value_type,
265 c10::SymInt>::value>>
266inline c10::SymInt multiply_integers(Iter begin, Iter end) {
267 return std::accumulate(
268 begin,
269 end,
270 c10::SymInt(1),
271 [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
272}
273
274inline SymInt operator+(int64_t a, const SymInt& b) {
275 return c10::SymInt(a) + b;
276}
277inline SymInt operator-(int64_t a, const SymInt& b) {
278 return c10::SymInt(a) - b;
279}
280inline SymInt operator*(int64_t a, const SymInt& b) {
281 return c10::SymInt(a) * b;
282}
283inline SymInt operator/(int64_t a, const SymInt& b) {
284 return c10::SymInt(a) / b;
285}
286inline SymInt operator%(int64_t a, const SymInt& b) {
287 return c10::SymInt(a) % b;
288}
289inline bool operator==(int64_t a, const SymInt& b) {
290 return c10::SymInt(a) == b;
291}
292inline bool operator!=(int64_t a, const SymInt& b) {
293 return c10::SymInt(a) != b;
294}
295inline bool operator<(int64_t a, const SymInt& b) {
296 return c10::SymInt(a) < b;
297}
298inline bool operator<=(int64_t a, const SymInt& b) {
299 return c10::SymInt(a) <= b;
300}
301inline bool operator>(int64_t a, const SymInt& b) {
302 return c10::SymInt(a) > b;
303}
304inline bool operator>=(int64_t a, const SymInt& b) {
305 return c10::SymInt(a) >= b;
306}
307
308C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s);
309C10_API SymInt operator-(const SymInt& s);
310} // namespace c10
311