1#pragma once
2
3#include <assert.h>
4#include <stdint.h>
5#include <stdexcept>
6#include <string>
7#include <type_traits>
8#include <utility>
9
10#include <c10/core/OptionalRef.h>
11#include <c10/core/ScalarType.h>
12#include <c10/core/SymFloat.h>
13#include <c10/core/SymInt.h>
14#include <c10/macros/Macros.h>
15#include <c10/util/Exception.h>
16#include <c10/util/Half.h>
17#include <c10/util/TypeCast.h>
18#include <c10/util/intrusive_ptr.h>
19
20C10_CLANG_DIAGNOSTIC_PUSH()
21#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
22C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
23#endif
24
25namespace c10 {
26
27/**
28 * Scalar represents a 0-dimensional tensor which contains a single element.
29 * Unlike a tensor, numeric literals (in C++) are implicitly convertible to
30 * Scalar (which is why, for example, we provide both add(Tensor) and
31 * add(Scalar) overloads for many operations). It may also be used in
32 * circumstances where you statically know a tensor is 0-dim and single size,
33 * but don't know its type.
34 */
35class C10_API Scalar {
36 public:
37 Scalar() : Scalar(int64_t(0)) {}
38
39 void destroy() {
40 if (Tag::HAS_si == tag || Tag::HAS_sd == tag) {
41 raw::intrusive_ptr::decref(v.p);
42 v.p = nullptr;
43 }
44 }
45
46 ~Scalar() {
47 destroy();
48 }
49
50#define DEFINE_IMPLICIT_CTOR(type, name) \
51 Scalar(type vv) : Scalar(vv, true) {}
52
53 AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
54 AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
55
56#undef DEFINE_IMPLICIT_CTOR
57
58 // Value* is both implicitly convertible to SymbolicVariable and bool which
59 // causes ambiguity error. Specialized constructor for bool resolves this
60 // problem.
61 template <
62 typename T,
63 typename std::enable_if<std::is_same<T, bool>::value, bool>::type* =
64 nullptr>
65 Scalar(T vv) : tag(Tag::HAS_b) {
66 v.i = convert<int64_t, bool>(vv);
67 }
68
69#define DEFINE_ACCESSOR(type, name) \
70 type to##name() const { \
71 if (Tag::HAS_d == tag) { \
72 return checked_convert<type, double>(v.d, #type); \
73 } else if (Tag::HAS_z == tag) { \
74 return checked_convert<type, c10::complex<double>>(v.z, #type); \
75 } \
76 if (Tag::HAS_b == tag) { \
77 return checked_convert<type, bool>(v.i, #type); \
78 } else if (Tag::HAS_i == tag) { \
79 return checked_convert<type, int64_t>(v.i, #type); \
80 } else if (Tag::HAS_si == tag) { \
81 TORCH_CHECK(false, "tried to get " #name " out of SymInt") \
82 } else if (Tag::HAS_sd == tag) { \
83 TORCH_CHECK(false, "tried to get " #name " out of SymFloat") \
84 } \
85 TORCH_CHECK(false) \
86 }
87
88 // TODO: Support ComplexHalf accessor
89 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
90
91#undef DEFINE_ACCESSOR
92
93 SymInt toSymInt() const {
94 if (Tag::HAS_si == tag) {
95 return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
96 static_cast<SymNodeImpl*>(v.p)));
97 } else {
98 return toLong();
99 }
100 }
101
102 SymFloat toSymFloat() const {
103 if (Tag::HAS_sd == tag) {
104 return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
105 static_cast<SymNodeImpl*>(v.p)));
106 } else {
107 return toDouble();
108 }
109 }
110
111 // also support scalar.to<int64_t>();
112 // Deleted for unsupported types, but specialized below for supported types
113 template <typename T>
114 T to() const = delete;
115
116 // audit uses of data_ptr
117 const void* data_ptr() const {
118 TORCH_INTERNAL_ASSERT(!isSymbolic());
119 return static_cast<const void*>(&v);
120 }
121
122 bool isFloatingPoint() const {
123 return Tag::HAS_d == tag || Tag::HAS_sd == tag;
124 }
125
126 C10_DEPRECATED_MESSAGE(
127 "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
128 bool isIntegral() const {
129 return Tag::HAS_i == tag || Tag::HAS_si == tag;
130 }
131 bool isIntegral(bool includeBool) const {
132 return Tag::HAS_i == tag || Tag::HAS_si == tag ||
133 (includeBool && isBoolean());
134 }
135
136 bool isComplex() const {
137 return Tag::HAS_z == tag;
138 }
139 bool isBoolean() const {
140 return Tag::HAS_b == tag;
141 }
142
143 // you probably don't actually want these; they're mostly for testing
144 bool isSymInt() const {
145 return Tag::HAS_si == tag;
146 }
147 bool isSymFloat() const {
148 return Tag::HAS_sd == tag;
149 }
150
151 bool isSymbolic() const {
152 return Tag::HAS_si == tag || Tag::HAS_sd == tag;
153 }
154
155 C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept {
156 if (&other == this) {
157 return *this;
158 }
159
160 destroy();
161 moveFrom(std::move(other));
162 return *this;
163 }
164
165 C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
166 if (&other == this) {
167 return *this;
168 }
169
170 *this = Scalar(other);
171 return *this;
172 }
173
174 Scalar operator-() const;
175 Scalar conj() const;
176 Scalar log() const;
177
178 template <
179 typename T,
180 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
181 bool equal(T num) const {
182 if (isComplex()) {
183 TORCH_INTERNAL_ASSERT(!isSymbolic());
184 auto val = v.z;
185 return (val.real() == num) && (val.imag() == T());
186 } else if (isFloatingPoint()) {
187 TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
188 return v.d == num;
189 } else if (isIntegral(/*includeBool=*/false)) {
190 TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
191 return v.i == num;
192 } else if (isBoolean()) {
193 // boolean scalar does not equal to a non boolean value
194 TORCH_INTERNAL_ASSERT(!isSymbolic());
195 return false;
196 } else {
197 TORCH_INTERNAL_ASSERT(false);
198 }
199 }
200
201 template <
202 typename T,
203 typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
204 bool equal(T num) const {
205 if (isComplex()) {
206 TORCH_INTERNAL_ASSERT(!isSymbolic());
207 return v.z == num;
208 } else if (isFloatingPoint()) {
209 TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
210 return (v.d == num.real()) && (num.imag() == T());
211 } else if (isIntegral(/*includeBool=*/false)) {
212 TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
213 return (v.i == num.real()) && (num.imag() == T());
214 } else if (isBoolean()) {
215 // boolean scalar does not equal to a non boolean value
216 TORCH_INTERNAL_ASSERT(!isSymbolic());
217 return false;
218 } else {
219 TORCH_INTERNAL_ASSERT(false);
220 }
221 }
222
223 bool equal(bool num) const {
224 if (isBoolean()) {
225 TORCH_INTERNAL_ASSERT(!isSymbolic());
226 return static_cast<bool>(v.i) == num;
227 } else {
228 return false;
229 }
230 }
231
232 ScalarType type() const {
233 if (isComplex()) {
234 return ScalarType::ComplexDouble;
235 } else if (isFloatingPoint()) {
236 return ScalarType::Double;
237 } else if (isIntegral(/*includeBool=*/false)) {
238 return ScalarType::Long;
239 } else if (isBoolean()) {
240 return ScalarType::Bool;
241 } else {
242 throw std::runtime_error("Unknown scalar type.");
243 }
244 }
245
246 Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
247 moveFrom(std::move(rhs));
248 }
249
250 Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
251 if (isSymbolic()) {
252 c10::raw::intrusive_ptr::incref(v.p);
253 }
254 }
255
256 Scalar(c10::SymInt si) {
257 if (si.is_symbolic()) {
258 tag = Tag::HAS_si;
259 v.p = std::move(si).release();
260 } else {
261 tag = Tag::HAS_i;
262 v.i = si.as_int_unchecked();
263 }
264 }
265
266 Scalar(c10::SymFloat sd) {
267 if (sd.is_symbolic()) {
268 tag = Tag::HAS_sd;
269 v.p = std::move(sd).release();
270 } else {
271 tag = Tag::HAS_d;
272 v.d = sd.as_float_unchecked();
273 }
274 }
275
276 // We can't set v in the initializer list using the
277 // syntax v{ .member = ... } because it doesn't work on MSVC
278 private:
279 enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b, HAS_sd, HAS_si };
280
281 // NB: assumes that self has already been cleared
282 C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
283 v = rhs.v;
284 tag = rhs.tag;
285 if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd) {
286 // Move out of scalar
287 rhs.tag = Tag::HAS_i;
288 rhs.v.i = 0;
289 }
290 }
291
292 Tag tag;
293
294 union v_t {
295 double d{};
296 int64_t i;
297 c10::complex<double> z;
298 c10::intrusive_ptr_target* p;
299 v_t() {} // default constructor
300 } v;
301
302 template <
303 typename T,
304 typename std::enable_if<
305 std::is_integral<T>::value && !std::is_same<T, bool>::value,
306 bool>::type* = nullptr>
307 Scalar(T vv, bool) : tag(Tag::HAS_i) {
308 v.i = convert<decltype(v.i), T>(vv);
309 }
310
311 template <
312 typename T,
313 typename std::enable_if<
314 !std::is_integral<T>::value && !c10::is_complex<T>::value,
315 bool>::type* = nullptr>
316 Scalar(T vv, bool) : tag(Tag::HAS_d) {
317 v.d = convert<decltype(v.d), T>(vv);
318 }
319
320 template <
321 typename T,
322 typename std::enable_if<c10::is_complex<T>::value, bool>::type* = nullptr>
323 Scalar(T vv, bool) : tag(Tag::HAS_z) {
324 v.z = convert<decltype(v.z), T>(vv);
325 }
326};
327
328using OptionalScalarRef = c10::OptionalRef<Scalar>;
329
330// define the scalar.to<int64_t>() specializations
331#define DEFINE_TO(T, name) \
332 template <> \
333 inline T Scalar::to<T>() const { \
334 return to##name(); \
335 }
336AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO)
337#undef DEFINE_TO
338
339} // namespace c10
340
341C10_CLANG_DIAGNOSTIC_POP()
342