1 | #include <c10/core/SymFloat.h> |
2 | #include <c10/core/SymNodeImpl.h> |
3 | #include <array> |
4 | #include <cmath> |
5 | #include <utility> |
6 | |
7 | namespace c10 { |
8 | |
9 | SymNode SymFloat::toSymNodeImpl() const { |
10 | TORCH_CHECK(is_symbolic()); |
11 | return SymNode::reclaim_copy(toSymNodeImplUnowned()); |
12 | } |
13 | |
14 | static std::array<SymNode, 2> normalize_symfloats( |
15 | const SymFloat& a_, |
16 | const SymFloat& b_) { |
17 | SymNode a, b; |
18 | if (a_.is_symbolic()) |
19 | a = a_.toSymNodeImpl(); |
20 | if (b_.is_symbolic()) |
21 | b = b_.toSymNodeImpl(); |
22 | |
23 | SymNodeImpl* common = a ? a.get() : b.get(); |
24 | if (!a) { |
25 | a = common->wrap_float(a_.as_float_unchecked()); |
26 | } |
27 | if (!b) { |
28 | b = common->wrap_float(b_.as_float_unchecked()); |
29 | } |
30 | return {std::move(a), std::move(b)}; |
31 | } |
32 | |
33 | SymFloat SymFloat::operator+(const SymFloat& sci) const { |
34 | if (!is_symbolic() && !sci.is_symbolic()) { |
35 | return SymFloat(data_ + sci.data_); |
36 | } |
37 | auto res = normalize_symfloats(*this, sci); |
38 | return SymFloat(res[0]->add(res[1])); |
39 | } |
40 | |
41 | SymFloat SymFloat::operator-(const SymFloat& sci) const { |
42 | if (!is_symbolic() && !sci.is_symbolic()) { |
43 | return SymFloat(data_ - sci.data_); |
44 | } |
45 | auto res = normalize_symfloats(*this, sci); |
46 | return SymFloat(res[0]->sub(res[1])); |
47 | } |
48 | |
49 | SymFloat SymFloat::operator*(const SymFloat& sci) const { |
50 | if (!is_symbolic() && !sci.is_symbolic()) { |
51 | return SymFloat(data_ * sci.data_); |
52 | } |
53 | auto res = normalize_symfloats(*this, sci); |
54 | return SymFloat(res[0]->mul(res[1])); |
55 | } |
56 | |
57 | SymFloat SymFloat::operator/(const SymFloat& sci) const { |
58 | if (!is_symbolic() && !sci.is_symbolic()) { |
59 | return SymFloat(data_ / sci.data_); |
60 | } |
61 | auto res = normalize_symfloats(*this, sci); |
62 | return SymFloat(res[0]->truediv(res[1])); |
63 | } |
64 | |
65 | std::ostream& operator<<(std::ostream& os, const SymFloat& s) { |
66 | if (s.is_symbolic()) { |
67 | os << s.toSymNodeImpl()->str(); |
68 | } else { |
69 | os << s.as_float_unchecked(); |
70 | } |
71 | return os; |
72 | } |
73 | |
74 | SymFloat SymFloat::sqrt() const { |
75 | if (!is_symbolic()) { |
76 | return SymFloat(std::sqrt(data_)); |
77 | } |
78 | auto other = SymFloat(-0.5); |
79 | auto res = normalize_symfloats(*this, other); |
80 | return SymFloat(res[0]->pow(res[1])); |
81 | } |
82 | |
83 | double SymFloat::guard_float(const char* file, int64_t line) const { |
84 | if (!is_symbolic()) { |
85 | return data_; |
86 | } |
87 | SymNode a = toSymNodeImpl(); |
88 | return a->guard_float(file, line); |
89 | } |
90 | |
91 | } // namespace c10 |
92 | |