1#include <c10/core/SymFloat.h>
2#include <c10/core/SymNodeImpl.h>
3#include <array>
4#include <cmath>
5#include <utility>
6
7namespace c10 {
8
9SymNode SymFloat::toSymNodeImpl() const {
10 TORCH_CHECK(is_symbolic());
11 return SymNode::reclaim_copy(toSymNodeImplUnowned());
12}
13
14static std::array<SymNode, 2> normalize_symfloats(
15 const SymFloat& a_,
16 const SymFloat& b_) {
17 SymNode a, b;
18 if (a_.is_symbolic())
19 a = a_.toSymNodeImpl();
20 if (b_.is_symbolic())
21 b = b_.toSymNodeImpl();
22
23 SymNodeImpl* common = a ? a.get() : b.get();
24 if (!a) {
25 a = common->wrap_float(a_.as_float_unchecked());
26 }
27 if (!b) {
28 b = common->wrap_float(b_.as_float_unchecked());
29 }
30 return {std::move(a), std::move(b)};
31}
32
33SymFloat SymFloat::operator+(const SymFloat& sci) const {
34 if (!is_symbolic() && !sci.is_symbolic()) {
35 return SymFloat(data_ + sci.data_);
36 }
37 auto res = normalize_symfloats(*this, sci);
38 return SymFloat(res[0]->add(res[1]));
39}
40
41SymFloat SymFloat::operator-(const SymFloat& sci) const {
42 if (!is_symbolic() && !sci.is_symbolic()) {
43 return SymFloat(data_ - sci.data_);
44 }
45 auto res = normalize_symfloats(*this, sci);
46 return SymFloat(res[0]->sub(res[1]));
47}
48
49SymFloat SymFloat::operator*(const SymFloat& sci) const {
50 if (!is_symbolic() && !sci.is_symbolic()) {
51 return SymFloat(data_ * sci.data_);
52 }
53 auto res = normalize_symfloats(*this, sci);
54 return SymFloat(res[0]->mul(res[1]));
55}
56
57SymFloat SymFloat::operator/(const SymFloat& sci) const {
58 if (!is_symbolic() && !sci.is_symbolic()) {
59 return SymFloat(data_ / sci.data_);
60 }
61 auto res = normalize_symfloats(*this, sci);
62 return SymFloat(res[0]->truediv(res[1]));
63}
64
65std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
66 if (s.is_symbolic()) {
67 os << s.toSymNodeImpl()->str();
68 } else {
69 os << s.as_float_unchecked();
70 }
71 return os;
72}
73
74SymFloat SymFloat::sqrt() const {
75 if (!is_symbolic()) {
76 return SymFloat(std::sqrt(data_));
77 }
78 auto other = SymFloat(-0.5);
79 auto res = normalize_symfloats(*this, other);
80 return SymFloat(res[0]->pow(res[1]));
81}
82
83double SymFloat::guard_float(const char* file, int64_t line) const {
84 if (!is_symbolic()) {
85 return data_;
86 }
87 SymNode a = toSymNodeImpl();
88 return a->guard_float(file, line);
89}
90
91} // namespace c10
92