1#pragma once
2
3#include <c10/core/SymNodeImpl.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/Exception.h>
6#include <c10/util/intrusive_ptr.h>
7
8#include <limits>
9#include <memory>
10
11namespace c10 {
12
13// NB: this is actually double precision; we're using the Python naming here
14class C10_API SymFloat {
15 public:
16 /*implicit*/ SymFloat(double d) : data_(d){};
17 SymFloat(SymNode ptr)
18 : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
19 TORCH_CHECK(ptr_->is_float());
20 };
21 SymFloat() : data_(0.0) {}
22
23 SymNodeImpl* toSymNodeImplUnowned() const {
24 return ptr_.get();
25 }
26
27 SymNodeImpl* release() && {
28 return std::move(ptr_).release();
29 }
30
31 SymNode toSymNodeImpl() const;
32
33 double expect_float() const {
34 TORCH_CHECK(!is_symbolic());
35 return data_;
36 }
37
38 SymFloat operator+(const SymFloat&) const;
39 SymFloat operator-(const SymFloat&) const;
40 SymFloat operator*(const SymFloat&) const;
41 SymFloat operator/(const SymFloat&) const;
42
43 // Need guidance on where to put this code
44 SymFloat sqrt() const;
45
46 // Insert a guard for the float to be its concrete value, and then return
47 // that value. This operation always works, even if the float is symbolic,
48 // so long as we know what the underlying value is. Don't blindly put this
49 // everywhere; you can cause overspecialization of PyTorch programs with
50 // this method.
51 //
52 // It should be called as guard_float(__FILE__, __LINE__). The file and line
53 // number can be used to diagnose overspecialization.
54 double guard_float(const char* file, int64_t line) const;
55
56 // N.B. It's important to keep this definition in the header
57 // as we expect if checks to be folded for mobile builds
58 // where `is_symbolic` is always false
59 C10_ALWAYS_INLINE bool is_symbolic() const {
60 return ptr_;
61 }
62
63 double as_float_unchecked() const {
64 return data_;
65 }
66
67 private:
68 // TODO: optimize to union
69 double data_;
70 SymNode ptr_;
71};
72
// Stream-insertion operator for SymFloat. Declared here and defined out of
// line, so the printed format is not visible from this header.
C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s);
74} // namespace c10
75