1 | #pragma once |
2 | |
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <utility>
10 | |
11 | namespace c10 { |
12 | |
13 | // NB: this is actually double precision; we're using the Python naming here |
14 | class C10_API SymFloat { |
15 | public: |
16 | /*implicit*/ SymFloat(double d) : data_(d){}; |
17 | SymFloat(SymNode ptr) |
18 | : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) { |
19 | TORCH_CHECK(ptr_->is_float()); |
20 | }; |
21 | SymFloat() : data_(0.0) {} |
22 | |
23 | SymNodeImpl* toSymNodeImplUnowned() const { |
24 | return ptr_.get(); |
25 | } |
26 | |
27 | SymNodeImpl* release() && { |
28 | return std::move(ptr_).release(); |
29 | } |
30 | |
31 | SymNode toSymNodeImpl() const; |
32 | |
33 | double expect_float() const { |
34 | TORCH_CHECK(!is_symbolic()); |
35 | return data_; |
36 | } |
37 | |
38 | SymFloat operator+(const SymFloat&) const; |
39 | SymFloat operator-(const SymFloat&) const; |
40 | SymFloat operator*(const SymFloat&) const; |
41 | SymFloat operator/(const SymFloat&) const; |
42 | |
43 | // Need guidance on where to put this code |
44 | SymFloat sqrt() const; |
45 | |
46 | // Insert a guard for the float to be its concrete value, and then return |
47 | // that value. This operation always works, even if the float is symbolic, |
48 | // so long as we know what the underlying value is. Don't blindly put this |
49 | // everywhere; you can cause overspecialization of PyTorch programs with |
50 | // this method. |
51 | // |
52 | // It should be called as guard_float(__FILE__, __LINE__). The file and line |
53 | // number can be used to diagnose overspecialization. |
54 | double guard_float(const char* file, int64_t line) const; |
55 | |
56 | // N.B. It's important to keep this definition in the header |
57 | // as we expect if checks to be folded for mobile builds |
58 | // where `is_symbolic` is always false |
59 | C10_ALWAYS_INLINE bool is_symbolic() const { |
60 | return ptr_; |
61 | } |
62 | |
63 | double as_float_unchecked() const { |
64 | return data_; |
65 | } |
66 | |
67 | private: |
68 | // TODO: optimize to union |
69 | double data_; |
70 | SymNode ptr_; |
71 | }; |
72 | |
// Stream-insertion operator for SymFloat. Only declared here; the output
// format is determined by the out-of-line definition.
C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s);
74 | } // namespace c10 |
75 | |