1 | #pragma once |
2 | |
3 | #include <assert.h> |
4 | #include <stdint.h> |
5 | #include <stdexcept> |
6 | #include <string> |
7 | #include <type_traits> |
8 | #include <utility> |
9 | |
10 | #include <c10/core/OptionalRef.h> |
11 | #include <c10/core/ScalarType.h> |
12 | #include <c10/core/SymFloat.h> |
13 | #include <c10/core/SymInt.h> |
14 | #include <c10/macros/Macros.h> |
15 | #include <c10/util/Exception.h> |
16 | #include <c10/util/Half.h> |
17 | #include <c10/util/TypeCast.h> |
18 | #include <c10/util/intrusive_ptr.h> |
19 | |
20 | C10_CLANG_DIAGNOSTIC_PUSH() |
21 | #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") |
22 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion" ) |
23 | #endif |
24 | |
25 | namespace c10 { |
26 | |
27 | /** |
28 | * Scalar represents a 0-dimensional tensor which contains a single element. |
29 | * Unlike a tensor, numeric literals (in C++) are implicitly convertible to |
30 | * Scalar (which is why, for example, we provide both add(Tensor) and |
31 | * add(Scalar) overloads for many operations). It may also be used in |
32 | * circumstances where you statically know a tensor is 0-dim and single size, |
33 | * but don't know its type. |
34 | */ |
35 | class C10_API Scalar { |
36 | public: |
37 | Scalar() : Scalar(int64_t(0)) {} |
38 | |
39 | void destroy() { |
40 | if (Tag::HAS_si == tag || Tag::HAS_sd == tag) { |
41 | raw::intrusive_ptr::decref(v.p); |
42 | v.p = nullptr; |
43 | } |
44 | } |
45 | |
46 | ~Scalar() { |
47 | destroy(); |
48 | } |
49 | |
50 | #define DEFINE_IMPLICIT_CTOR(type, name) \ |
51 | Scalar(type vv) : Scalar(vv, true) {} |
52 | |
53 | AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) |
54 | AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) |
55 | |
56 | #undef DEFINE_IMPLICIT_CTOR |
57 | |
58 | // Value* is both implicitly convertible to SymbolicVariable and bool which |
59 | // causes ambiguity error. Specialized constructor for bool resolves this |
60 | // problem. |
61 | template < |
62 | typename T, |
63 | typename std::enable_if<std::is_same<T, bool>::value, bool>::type* = |
64 | nullptr> |
65 | Scalar(T vv) : tag(Tag::HAS_b) { |
66 | v.i = convert<int64_t, bool>(vv); |
67 | } |
68 | |
69 | #define DEFINE_ACCESSOR(type, name) \ |
70 | type to##name() const { \ |
71 | if (Tag::HAS_d == tag) { \ |
72 | return checked_convert<type, double>(v.d, #type); \ |
73 | } else if (Tag::HAS_z == tag) { \ |
74 | return checked_convert<type, c10::complex<double>>(v.z, #type); \ |
75 | } \ |
76 | if (Tag::HAS_b == tag) { \ |
77 | return checked_convert<type, bool>(v.i, #type); \ |
78 | } else if (Tag::HAS_i == tag) { \ |
79 | return checked_convert<type, int64_t>(v.i, #type); \ |
80 | } else if (Tag::HAS_si == tag) { \ |
81 | TORCH_CHECK(false, "tried to get " #name " out of SymInt") \ |
82 | } else if (Tag::HAS_sd == tag) { \ |
83 | TORCH_CHECK(false, "tried to get " #name " out of SymFloat") \ |
84 | } \ |
85 | TORCH_CHECK(false) \ |
86 | } |
87 | |
88 | // TODO: Support ComplexHalf accessor |
89 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) |
90 | |
91 | #undef DEFINE_ACCESSOR |
92 | |
93 | SymInt toSymInt() const { |
94 | if (Tag::HAS_si == tag) { |
95 | return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy( |
96 | static_cast<SymNodeImpl*>(v.p))); |
97 | } else { |
98 | return toLong(); |
99 | } |
100 | } |
101 | |
102 | SymFloat toSymFloat() const { |
103 | if (Tag::HAS_sd == tag) { |
104 | return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy( |
105 | static_cast<SymNodeImpl*>(v.p))); |
106 | } else { |
107 | return toDouble(); |
108 | } |
109 | } |
110 | |
111 | // also support scalar.to<int64_t>(); |
112 | // Deleted for unsupported types, but specialized below for supported types |
113 | template <typename T> |
114 | T to() const = delete; |
115 | |
116 | // audit uses of data_ptr |
117 | const void* data_ptr() const { |
118 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
119 | return static_cast<const void*>(&v); |
120 | } |
121 | |
122 | bool isFloatingPoint() const { |
123 | return Tag::HAS_d == tag || Tag::HAS_sd == tag; |
124 | } |
125 | |
126 | C10_DEPRECATED_MESSAGE( |
127 | "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead." ) |
128 | bool isIntegral() const { |
129 | return Tag::HAS_i == tag || Tag::HAS_si == tag; |
130 | } |
131 | bool isIntegral(bool includeBool) const { |
132 | return Tag::HAS_i == tag || Tag::HAS_si == tag || |
133 | (includeBool && isBoolean()); |
134 | } |
135 | |
136 | bool isComplex() const { |
137 | return Tag::HAS_z == tag; |
138 | } |
139 | bool isBoolean() const { |
140 | return Tag::HAS_b == tag; |
141 | } |
142 | |
143 | // you probably don't actually want these; they're mostly for testing |
144 | bool isSymInt() const { |
145 | return Tag::HAS_si == tag; |
146 | } |
147 | bool isSymFloat() const { |
148 | return Tag::HAS_sd == tag; |
149 | } |
150 | |
151 | bool isSymbolic() const { |
152 | return Tag::HAS_si == tag || Tag::HAS_sd == tag; |
153 | } |
154 | |
155 | C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { |
156 | if (&other == this) { |
157 | return *this; |
158 | } |
159 | |
160 | destroy(); |
161 | moveFrom(std::move(other)); |
162 | return *this; |
163 | } |
164 | |
165 | C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { |
166 | if (&other == this) { |
167 | return *this; |
168 | } |
169 | |
170 | *this = Scalar(other); |
171 | return *this; |
172 | } |
173 | |
174 | Scalar operator-() const; |
175 | Scalar conj() const; |
176 | Scalar log() const; |
177 | |
178 | template < |
179 | typename T, |
180 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
181 | bool equal(T num) const { |
182 | if (isComplex()) { |
183 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
184 | auto val = v.z; |
185 | return (val.real() == num) && (val.imag() == T()); |
186 | } else if (isFloatingPoint()) { |
187 | TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality" ); |
188 | return v.d == num; |
189 | } else if (isIntegral(/*includeBool=*/false)) { |
190 | TORCH_CHECK(!isSymbolic(), "NYI SymInt equality" ); |
191 | return v.i == num; |
192 | } else if (isBoolean()) { |
193 | // boolean scalar does not equal to a non boolean value |
194 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
195 | return false; |
196 | } else { |
197 | TORCH_INTERNAL_ASSERT(false); |
198 | } |
199 | } |
200 | |
201 | template < |
202 | typename T, |
203 | typename std::enable_if<c10::is_complex<T>::value, int>::type = 0> |
204 | bool equal(T num) const { |
205 | if (isComplex()) { |
206 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
207 | return v.z == num; |
208 | } else if (isFloatingPoint()) { |
209 | TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality" ); |
210 | return (v.d == num.real()) && (num.imag() == T()); |
211 | } else if (isIntegral(/*includeBool=*/false)) { |
212 | TORCH_CHECK(!isSymbolic(), "NYI SymInt equality" ); |
213 | return (v.i == num.real()) && (num.imag() == T()); |
214 | } else if (isBoolean()) { |
215 | // boolean scalar does not equal to a non boolean value |
216 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
217 | return false; |
218 | } else { |
219 | TORCH_INTERNAL_ASSERT(false); |
220 | } |
221 | } |
222 | |
223 | bool equal(bool num) const { |
224 | if (isBoolean()) { |
225 | TORCH_INTERNAL_ASSERT(!isSymbolic()); |
226 | return static_cast<bool>(v.i) == num; |
227 | } else { |
228 | return false; |
229 | } |
230 | } |
231 | |
232 | ScalarType type() const { |
233 | if (isComplex()) { |
234 | return ScalarType::ComplexDouble; |
235 | } else if (isFloatingPoint()) { |
236 | return ScalarType::Double; |
237 | } else if (isIntegral(/*includeBool=*/false)) { |
238 | return ScalarType::Long; |
239 | } else if (isBoolean()) { |
240 | return ScalarType::Bool; |
241 | } else { |
242 | throw std::runtime_error("Unknown scalar type." ); |
243 | } |
244 | } |
245 | |
246 | Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { |
247 | moveFrom(std::move(rhs)); |
248 | } |
249 | |
250 | Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) { |
251 | if (isSymbolic()) { |
252 | c10::raw::intrusive_ptr::incref(v.p); |
253 | } |
254 | } |
255 | |
256 | Scalar(c10::SymInt si) { |
257 | if (si.is_symbolic()) { |
258 | tag = Tag::HAS_si; |
259 | v.p = std::move(si).release(); |
260 | } else { |
261 | tag = Tag::HAS_i; |
262 | v.i = si.as_int_unchecked(); |
263 | } |
264 | } |
265 | |
266 | Scalar(c10::SymFloat sd) { |
267 | if (sd.is_symbolic()) { |
268 | tag = Tag::HAS_sd; |
269 | v.p = std::move(sd).release(); |
270 | } else { |
271 | tag = Tag::HAS_d; |
272 | v.d = sd.as_float_unchecked(); |
273 | } |
274 | } |
275 | |
276 | // We can't set v in the initializer list using the |
277 | // syntax v{ .member = ... } because it doesn't work on MSVC |
278 | private: |
279 | enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b, HAS_sd, HAS_si }; |
280 | |
281 | // NB: assumes that self has already been cleared |
282 | C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { |
283 | v = rhs.v; |
284 | tag = rhs.tag; |
285 | if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd) { |
286 | // Move out of scalar |
287 | rhs.tag = Tag::HAS_i; |
288 | rhs.v.i = 0; |
289 | } |
290 | } |
291 | |
292 | Tag tag; |
293 | |
294 | union v_t { |
295 | double d{}; |
296 | int64_t i; |
297 | c10::complex<double> z; |
298 | c10::intrusive_ptr_target* p; |
299 | v_t() {} // default constructor |
300 | } v; |
301 | |
302 | template < |
303 | typename T, |
304 | typename std::enable_if< |
305 | std::is_integral<T>::value && !std::is_same<T, bool>::value, |
306 | bool>::type* = nullptr> |
307 | Scalar(T vv, bool) : tag(Tag::HAS_i) { |
308 | v.i = convert<decltype(v.i), T>(vv); |
309 | } |
310 | |
311 | template < |
312 | typename T, |
313 | typename std::enable_if< |
314 | !std::is_integral<T>::value && !c10::is_complex<T>::value, |
315 | bool>::type* = nullptr> |
316 | Scalar(T vv, bool) : tag(Tag::HAS_d) { |
317 | v.d = convert<decltype(v.d), T>(vv); |
318 | } |
319 | |
320 | template < |
321 | typename T, |
322 | typename std::enable_if<c10::is_complex<T>::value, bool>::type* = nullptr> |
323 | Scalar(T vv, bool) : tag(Tag::HAS_z) { |
324 | v.z = convert<decltype(v.z), T>(vv); |
325 | } |
326 | }; |
327 | |
328 | using OptionalScalarRef = c10::OptionalRef<Scalar>; |
329 | |
330 | // define the scalar.to<int64_t>() specializations |
331 | #define DEFINE_TO(T, name) \ |
332 | template <> \ |
333 | inline T Scalar::to<T>() const { \ |
334 | return to##name(); \ |
335 | } |
336 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) |
337 | #undef DEFINE_TO |
338 | |
339 | } // namespace c10 |
340 | |
341 | C10_CLANG_DIAGNOSTIC_POP() |
342 | |