1#include <c10/core/SymBool.h>
2#include <c10/core/SymNodeImpl.h>
3#include <array>
4#include <utility>
5
6namespace c10 {
7
8SymNode SymBool::toSymNodeImpl() const {
9 TORCH_CHECK(is_symbolic());
10 return SymNode::reclaim_copy(toSymNodeImplUnowned());
11}
12
13static std::array<SymNode, 2> normalize_symbools(
14 const SymBool& a_,
15 const SymBool& b_) {
16 SymNode a, b;
17 if (a_.is_symbolic())
18 a = a_.toSymNodeImpl();
19 if (b_.is_symbolic())
20 b = b_.toSymNodeImpl();
21
22 SymNodeImpl* common = a ? a.get() : b.get();
23 if (!a) {
24 a = common->wrap_bool(a_.as_bool_unchecked());
25 }
26 if (!b) {
27 b = common->wrap_bool(b_.as_bool_unchecked());
28 }
29 return {std::move(a), std::move(b)};
30}
31
32SymBool SymBool::sym_and(const SymBool& sci) const {
33 if (!is_symbolic() && !sci.is_symbolic()) {
34 return SymBool(data_ && sci.data_);
35 }
36 auto res = normalize_symbools(*this, sci);
37 return SymBool(res[0]->sym_and(res[1]));
38}
39
40SymBool SymBool::sym_or(const SymBool& sci) const {
41 if (!is_symbolic() && !sci.is_symbolic()) {
42 return SymBool(data_ || sci.data_);
43 }
44 auto res = normalize_symbools(*this, sci);
45 return SymBool(res[0]->sym_or(res[1]));
46}
47
48SymBool SymBool::sym_not() const {
49 if (!is_symbolic()) {
50 return SymBool(!data_);
51 }
52 return SymBool(toSymNodeImpl()->sym_not());
53}
54
55std::ostream& operator<<(std::ostream& os, const SymBool& s) {
56 if (s.is_symbolic()) {
57 os << s.toSymNodeImpl()->str();
58 } else {
59 os << s.as_bool_unchecked();
60 }
61 return os;
62}
63
64bool SymBool::guard_bool(const char* file, int64_t line) const {
65 if (!is_symbolic()) {
66 return data_;
67 }
68 SymNode a = toSymNodeImpl();
69 return a->guard_bool(file, line);
70}
71
72} // namespace c10
73