1 | #pragma once |
---|---|
2 | |
3 | #include <c10/core/SymNodeImpl.h> |
4 | #include <c10/macros/Macros.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/intrusive_ptr.h> |
7 | |
8 | namespace c10 { |
9 | |
10 | class C10_API SymBool { |
11 | public: |
12 | /*implicit*/ SymBool(bool b) : data_(b){}; |
13 | SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) { |
14 | TORCH_CHECK(ptr_->is_bool()); |
15 | }; |
16 | SymBool() : data_(false) {} |
17 | |
18 | SymNodeImpl* toSymNodeImplUnowned() const { |
19 | return ptr_.get(); |
20 | } |
21 | |
22 | SymNodeImpl* release() && { |
23 | return std::move(ptr_).release(); |
24 | } |
25 | |
26 | SymNode toSymNodeImpl() const; |
27 | |
28 | bool expect_bool() const { |
29 | TORCH_CHECK(!is_symbolic()); |
30 | return data_; |
31 | } |
32 | |
33 | SymBool sym_and(const SymBool&) const; |
34 | SymBool sym_or(const SymBool&) const; |
35 | SymBool sym_not() const; |
36 | |
37 | SymBool operator&(const SymBool& other) const { |
38 | return sym_and(other); |
39 | } |
40 | SymBool operator|(const SymBool& other) const { |
41 | return sym_or(other); |
42 | } |
43 | SymBool operator~() const { |
44 | return sym_not(); |
45 | } |
46 | |
47 | // Insert a guard for the bool to be its concrete value, and then return |
48 | // that value. Note that C++ comparison operations default to returning |
49 | // bool, so it's not so common to have to call this |
50 | bool guard_bool(const char* file, int64_t line) const; |
51 | |
52 | C10_ALWAYS_INLINE bool is_symbolic() const { |
53 | return ptr_; |
54 | } |
55 | |
56 | bool as_bool_unchecked() const { |
57 | return data_; |
58 | } |
59 | |
60 | private: |
61 | // TODO: optimize to union |
62 | bool data_; |
63 | SymNode ptr_; |
64 | }; |
65 | |
66 | C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s); |
67 | } // namespace c10 |
68 |