1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <c10/util/ArrayRef.h>
5#include <c10/util/Exception.h>
6#include <c10/util/intrusive_ptr.h>
7#include <memory>
8
9namespace c10 {
10
11class SymNodeImpl;
12using SymNode = c10::intrusive_ptr<SymNodeImpl>;
13
14class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
15 public:
16 ~SymNodeImpl() override = default;
17
18 template <typename T>
19 c10::intrusive_ptr<T> dyn_cast() const {
20 return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
21 }
22
23 // these could be pure virtual when we implement LTC versions
24 virtual bool is_int() {
25 TORCH_CHECK(false, "NYI");
26 };
27 virtual bool is_bool() {
28 TORCH_CHECK(false, "NYI");
29 };
30 virtual bool is_float() {
31 TORCH_CHECK(false, "NYI");
32 };
33 virtual SymNode add(const SymNode& other) {
34 TORCH_CHECK(false, "NYI");
35 };
36 virtual SymNode sub(const SymNode& other) {
37 TORCH_CHECK(false, "NYI");
38 };
39 virtual SymNode mul(const SymNode& other) {
40 TORCH_CHECK(false, "NYI");
41 };
42 virtual SymNode truediv(const SymNode& other) {
43 TORCH_CHECK(false, "NYI");
44 };
45 virtual SymNode pow(const SymNode& other) {
46 TORCH_CHECK(false, "NYI");
47 };
48 virtual SymNode floordiv(const SymNode& other) {
49 TORCH_CHECK(false, "NYI");
50 };
51 virtual SymNode mod(const SymNode& other) {
52 TORCH_CHECK(false, "NYI");
53 };
54 virtual SymNode eq(const SymNode& other) {
55 TORCH_CHECK(false, "NYI");
56 };
57 virtual SymNode ne(const SymNode& other) {
58 TORCH_CHECK(false, "NYI");
59 };
60 virtual SymNode gt(const SymNode& other) {
61 TORCH_CHECK(false, "NYI");
62 };
63 virtual SymNode lt(const SymNode& other) {
64 TORCH_CHECK(false, "NYI");
65 };
66 virtual SymNode le(const SymNode& other) {
67 TORCH_CHECK(false, "NYI");
68 };
69 virtual SymNode ge(const SymNode& other) {
70 TORCH_CHECK(false, "NYI");
71 };
72 virtual SymNode ceil() {
73 TORCH_CHECK(false, "NYI");
74 };
75 virtual SymNode floor() {
76 TORCH_CHECK(false, "NYI");
77 };
78 virtual SymNode neg() {
79 TORCH_CHECK(false, "NYI");
80 };
81 virtual SymNode sym_min(const SymNode& other) {
82 TORCH_CHECK(false, "NYI");
83 };
84 virtual SymNode sym_max(const SymNode& other) {
85 TORCH_CHECK(false, "NYI");
86 };
87 virtual SymNode sym_or(const SymNode& other) {
88 TORCH_CHECK(false, "NYI");
89 };
90 virtual SymNode sym_and(const SymNode& other) {
91 TORCH_CHECK(false, "NYI");
92 };
93 virtual SymNode sym_not() {
94 TORCH_CHECK(false, "NYI");
95 };
96 // NB: self is ignored here, only the arguments are used
97 virtual SymNode is_non_overlapping_and_dense(
98 ArrayRef<SymNode> sizes,
99 ArrayRef<SymNode> strides) {
100 TORCH_CHECK(false, "NYI");
101 };
102 virtual SymNode clone() {
103 TORCH_CHECK(false, "NYI");
104 };
105 virtual SymNode sym_float() {
106 TORCH_CHECK(false, "NYI");
107 }
108 virtual SymNode wrap_int(int64_t num) {
109 TORCH_CHECK(false, "NYI");
110 };
111 virtual SymNode wrap_float(double num) {
112 TORCH_CHECK(false, "NYI");
113 };
114 virtual SymNode wrap_bool(bool num) {
115 TORCH_CHECK(false, "NYI");
116 };
117 virtual int64_t guard_int(const char* file, int64_t line) {
118 TORCH_CHECK(false, "NYI");
119 };
120 virtual bool guard_bool(const char* file, int64_t line) {
121 TORCH_CHECK(false, "NYI");
122 };
123 virtual double guard_float(const char* file, int64_t line) {
124 TORCH_CHECK(false, "NYI");
125 };
126 virtual int64_t int_() {
127 TORCH_CHECK(false, "NYI");
128 };
129 virtual bool bool_() {
130 TORCH_CHECK(false, "NYI");
131 };
132 virtual std::string str() {
133 TORCH_CHECK(false, "NYI");
134 };
135 std::ostream& operator<<(std::ostream& os) {
136 os << str();
137 return os;
138 };
139};
140
141} // namespace c10
142