1#pragma once
2
3#include <c10/core/SymNodeImpl.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/Exception.h>
6#include <c10/util/intrusive_ptr.h>
7
8namespace c10 {
9
10class C10_API SymBool {
11 public:
12 /*implicit*/ SymBool(bool b) : data_(b){};
13 SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
14 TORCH_CHECK(ptr_->is_bool());
15 };
16 SymBool() : data_(false) {}
17
18 SymNodeImpl* toSymNodeImplUnowned() const {
19 return ptr_.get();
20 }
21
22 SymNodeImpl* release() && {
23 return std::move(ptr_).release();
24 }
25
26 SymNode toSymNodeImpl() const;
27
28 bool expect_bool() const {
29 TORCH_CHECK(!is_symbolic());
30 return data_;
31 }
32
33 SymBool sym_and(const SymBool&) const;
34 SymBool sym_or(const SymBool&) const;
35 SymBool sym_not() const;
36
37 SymBool operator&(const SymBool& other) const {
38 return sym_and(other);
39 }
40 SymBool operator|(const SymBool& other) const {
41 return sym_or(other);
42 }
43 SymBool operator~() const {
44 return sym_not();
45 }
46
47 // Insert a guard for the bool to be its concrete value, and then return
48 // that value. Note that C++ comparison operations default to returning
49 // bool, so it's not so common to have to call this
50 bool guard_bool(const char* file, int64_t line) const;
51
52 C10_ALWAYS_INLINE bool is_symbolic() const {
53 return ptr_;
54 }
55
56 bool as_bool_unchecked() const {
57 return data_;
58 }
59
60 private:
61 // TODO: optimize to union
62 bool data_;
63 SymNode ptr_;
64};
65
66C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
67} // namespace c10
68