#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
8 | |
9 | namespace c10 { |
10 | |
11 | class SymNodeImpl; |
12 | using SymNode = c10::intrusive_ptr<SymNodeImpl>; |
13 | |
14 | class C10_API SymNodeImpl : public c10::intrusive_ptr_target { |
15 | public: |
16 | ~SymNodeImpl() override = default; |
17 | |
18 | template <typename T> |
19 | c10::intrusive_ptr<T> dyn_cast() const { |
20 | return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this)); |
21 | } |
22 | |
23 | // these could be pure virtual when we implement LTC versions |
24 | virtual bool is_int() { |
25 | TORCH_CHECK(false, "NYI" ); |
26 | }; |
27 | virtual bool is_bool() { |
28 | TORCH_CHECK(false, "NYI" ); |
29 | }; |
30 | virtual bool is_float() { |
31 | TORCH_CHECK(false, "NYI" ); |
32 | }; |
33 | virtual SymNode add(const SymNode& other) { |
34 | TORCH_CHECK(false, "NYI" ); |
35 | }; |
36 | virtual SymNode sub(const SymNode& other) { |
37 | TORCH_CHECK(false, "NYI" ); |
38 | }; |
39 | virtual SymNode mul(const SymNode& other) { |
40 | TORCH_CHECK(false, "NYI" ); |
41 | }; |
42 | virtual SymNode truediv(const SymNode& other) { |
43 | TORCH_CHECK(false, "NYI" ); |
44 | }; |
45 | virtual SymNode pow(const SymNode& other) { |
46 | TORCH_CHECK(false, "NYI" ); |
47 | }; |
48 | virtual SymNode floordiv(const SymNode& other) { |
49 | TORCH_CHECK(false, "NYI" ); |
50 | }; |
51 | virtual SymNode mod(const SymNode& other) { |
52 | TORCH_CHECK(false, "NYI" ); |
53 | }; |
54 | virtual SymNode eq(const SymNode& other) { |
55 | TORCH_CHECK(false, "NYI" ); |
56 | }; |
57 | virtual SymNode ne(const SymNode& other) { |
58 | TORCH_CHECK(false, "NYI" ); |
59 | }; |
60 | virtual SymNode gt(const SymNode& other) { |
61 | TORCH_CHECK(false, "NYI" ); |
62 | }; |
63 | virtual SymNode lt(const SymNode& other) { |
64 | TORCH_CHECK(false, "NYI" ); |
65 | }; |
66 | virtual SymNode le(const SymNode& other) { |
67 | TORCH_CHECK(false, "NYI" ); |
68 | }; |
69 | virtual SymNode ge(const SymNode& other) { |
70 | TORCH_CHECK(false, "NYI" ); |
71 | }; |
72 | virtual SymNode ceil() { |
73 | TORCH_CHECK(false, "NYI" ); |
74 | }; |
75 | virtual SymNode floor() { |
76 | TORCH_CHECK(false, "NYI" ); |
77 | }; |
78 | virtual SymNode neg() { |
79 | TORCH_CHECK(false, "NYI" ); |
80 | }; |
81 | virtual SymNode sym_min(const SymNode& other) { |
82 | TORCH_CHECK(false, "NYI" ); |
83 | }; |
84 | virtual SymNode sym_max(const SymNode& other) { |
85 | TORCH_CHECK(false, "NYI" ); |
86 | }; |
87 | virtual SymNode sym_or(const SymNode& other) { |
88 | TORCH_CHECK(false, "NYI" ); |
89 | }; |
90 | virtual SymNode sym_and(const SymNode& other) { |
91 | TORCH_CHECK(false, "NYI" ); |
92 | }; |
93 | virtual SymNode sym_not() { |
94 | TORCH_CHECK(false, "NYI" ); |
95 | }; |
96 | // NB: self is ignored here, only the arguments are used |
97 | virtual SymNode is_non_overlapping_and_dense( |
98 | ArrayRef<SymNode> sizes, |
99 | ArrayRef<SymNode> strides) { |
100 | TORCH_CHECK(false, "NYI" ); |
101 | }; |
102 | virtual SymNode clone() { |
103 | TORCH_CHECK(false, "NYI" ); |
104 | }; |
105 | virtual SymNode sym_float() { |
106 | TORCH_CHECK(false, "NYI" ); |
107 | } |
108 | virtual SymNode wrap_int(int64_t num) { |
109 | TORCH_CHECK(false, "NYI" ); |
110 | }; |
111 | virtual SymNode wrap_float(double num) { |
112 | TORCH_CHECK(false, "NYI" ); |
113 | }; |
114 | virtual SymNode wrap_bool(bool num) { |
115 | TORCH_CHECK(false, "NYI" ); |
116 | }; |
117 | virtual int64_t guard_int(const char* file, int64_t line) { |
118 | TORCH_CHECK(false, "NYI" ); |
119 | }; |
120 | virtual bool guard_bool(const char* file, int64_t line) { |
121 | TORCH_CHECK(false, "NYI" ); |
122 | }; |
123 | virtual double guard_float(const char* file, int64_t line) { |
124 | TORCH_CHECK(false, "NYI" ); |
125 | }; |
126 | virtual int64_t int_() { |
127 | TORCH_CHECK(false, "NYI" ); |
128 | }; |
129 | virtual bool bool_() { |
130 | TORCH_CHECK(false, "NYI" ); |
131 | }; |
132 | virtual std::string str() { |
133 | TORCH_CHECK(false, "NYI" ); |
134 | }; |
135 | std::ostream& operator<<(std::ostream& os) { |
136 | os << str(); |
137 | return os; |
138 | }; |
139 | }; |
140 | |
141 | } // namespace c10 |
142 | |