1 | #include <c10/core/SymBool.h> |
---|---|
2 | #include <c10/core/SymNodeImpl.h> |
3 | #include <array> |
4 | #include <utility> |
5 | |
6 | namespace c10 { |
7 | |
8 | SymNode SymBool::toSymNodeImpl() const { |
9 | TORCH_CHECK(is_symbolic()); |
10 | return SymNode::reclaim_copy(toSymNodeImplUnowned()); |
11 | } |
12 | |
13 | static std::array<SymNode, 2> normalize_symbools( |
14 | const SymBool& a_, |
15 | const SymBool& b_) { |
16 | SymNode a, b; |
17 | if (a_.is_symbolic()) |
18 | a = a_.toSymNodeImpl(); |
19 | if (b_.is_symbolic()) |
20 | b = b_.toSymNodeImpl(); |
21 | |
22 | SymNodeImpl* common = a ? a.get() : b.get(); |
23 | if (!a) { |
24 | a = common->wrap_bool(a_.as_bool_unchecked()); |
25 | } |
26 | if (!b) { |
27 | b = common->wrap_bool(b_.as_bool_unchecked()); |
28 | } |
29 | return {std::move(a), std::move(b)}; |
30 | } |
31 | |
32 | SymBool SymBool::sym_and(const SymBool& sci) const { |
33 | if (!is_symbolic() && !sci.is_symbolic()) { |
34 | return SymBool(data_ && sci.data_); |
35 | } |
36 | auto res = normalize_symbools(*this, sci); |
37 | return SymBool(res[0]->sym_and(res[1])); |
38 | } |
39 | |
40 | SymBool SymBool::sym_or(const SymBool& sci) const { |
41 | if (!is_symbolic() && !sci.is_symbolic()) { |
42 | return SymBool(data_ || sci.data_); |
43 | } |
44 | auto res = normalize_symbools(*this, sci); |
45 | return SymBool(res[0]->sym_or(res[1])); |
46 | } |
47 | |
48 | SymBool SymBool::sym_not() const { |
49 | if (!is_symbolic()) { |
50 | return SymBool(!data_); |
51 | } |
52 | return SymBool(toSymNodeImpl()->sym_not()); |
53 | } |
54 | |
55 | std::ostream& operator<<(std::ostream& os, const SymBool& s) { |
56 | if (s.is_symbolic()) { |
57 | os << s.toSymNodeImpl()->str(); |
58 | } else { |
59 | os << s.as_bool_unchecked(); |
60 | } |
61 | return os; |
62 | } |
63 | |
64 | bool SymBool::guard_bool(const char* file, int64_t line) const { |
65 | if (!is_symbolic()) { |
66 | return data_; |
67 | } |
68 | SymNode a = toSymNodeImpl(); |
69 | return a->guard_bool(file, line); |
70 | } |
71 | |
72 | } // namespace c10 |
73 |