1#pragma once
2
#include "taichi/common/core.h"

#include <memory>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
7
8namespace taichi {
9namespace tinyir {
10
// Integer division rounded toward positive infinity, i.e. ceil(v / div).
// Correct for unsigned and signed T alike; `div` must be non-zero.
template <typename T>
T ceil_div(T v, T div) {
  T quotient = v / div;
  T remainder = v % div;
  // Built-in division truncates toward zero. When the exact quotient is
  // negative, truncation has already rounded up, so only bump the result when
  // there is a remainder AND the exact quotient is positive (remainder and
  // divisor have the same sign).
  if (remainder != 0 && ((remainder > 0) == (div > 0))) {
    ++quotient;
  }
  return quotient;
}
15
// Forward declarations, grouped by role.

// Core object / IR hierarchy.
class Polymorphic;
class Node;
class Type;
class Block;
class Visitor;

// Memory layout bookkeeping.
class LayoutContext;

// Type interfaces.
class MemRefElementTypeInterface;
class MemRefAggregateTypeInterface;
class AggregateTypeInterface;
class ShapedTypeInterface;
class PointerTypeInterface;
28
// Root of the tinyir class hierarchy. Provides RTTI-based type queries and an
// equality operator that dispatches to the private `is_equal` hook.
class Polymorphic {
 public:
  virtual ~Polymorphic() = default;

  // True if the dynamic type of this object is T or derives from T.
  template <typename T>
  bool is() const {
    return dynamic_cast<const T *>(this) != nullptr;
  }

  // Unchecked downcast: the caller must already know this object is a T
  // (e.g. via `is<T>()`); no runtime check is performed.
  template <typename T>
  T *as() {
    return static_cast<T *>(this);
  }

  template <typename T>
  const T *as() const {
    return static_cast<const T *>(this);
  }

  // Checked downcast: returns nullptr when this object is not a T.
  template <typename T>
  T *cast() {
    return dynamic_cast<T *>(this);
  }

  template <typename T>
  const T *cast() const {
    return dynamic_cast<const T *>(this);
  }

  // Equal only if both objects have the same dynamic type and the derived
  // class's `is_equal` reports them equal.
  bool operator==(const Polymorphic &other) const {
    return typeid(*this) == typeid(other) && is_equal(other);
  }

  // C++17 does not synthesize != from ==, so provide it explicitly.
  bool operator!=(const Polymorphic &other) const {
    return !(*this == other);
  }

  // Pointer flavour of operator==; a null `other` never compares equal.
  bool equals(const Polymorphic *other) const {
    return other != nullptr && *this == *other;
  }

 private:
  // Called by operator== only after the dynamic types are known to match, so
  // overrides may safely static_cast `other` to their own type.
  virtual bool is_equal(const Polymorphic &other) const = 0;
};
70
71class Node : public Polymorphic {
72 public:
73 using NodeRefs = const std::vector<const Node *>;
74
75 Node() {
76 }
77
78 ~Node() override {
79 }
80
81 const std::string &debug_name() const {
82 return debug_name_;
83 }
84
85 void set_debug_name(const std::string &s) {
86 debug_name_ = s;
87 }
88
89 virtual NodeRefs incoming() const {
90 return {};
91 }
92
93 virtual NodeRefs outgoing() const {
94 return {};
95 }
96
97 virtual bool is_leaf() const {
98 return false;
99 }
100
101 virtual bool is_tree_node() const {
102 return false;
103 }
104
105 private:
106 bool is_equal(const Polymorphic &other) const override {
107 return false;
108 }
109
110 std::string debug_name_;
111};
112
113class Type : public Node {
114 public:
115 Type() {
116 }
117
118 private:
119 bool is_equal(const Polymorphic &other) const override {
120 return false;
121 }
122};
123
124// The default LayoutContext is the standard C layout
125class LayoutContext : public Polymorphic {
126 private:
127 std::unordered_map<const MemRefElementTypeInterface *, size_t> size_cache_;
128 std::unordered_map<const MemRefElementTypeInterface *, size_t>
129 alignment_cache_;
130 std::unordered_map<const MemRefAggregateTypeInterface *, std::vector<size_t>>
131 elem_offset_cache_;
132
133 public:
134 void register_size(const MemRefElementTypeInterface *t, size_t size) {
135 TI_ASSERT(size != 0);
136 size_cache_[t] = size;
137 }
138
139 void register_alignment(const MemRefElementTypeInterface *t, size_t size) {
140 TI_ASSERT(size != 0);
141 alignment_cache_[t] = size;
142 }
143
144 void register_aggregate(const MemRefAggregateTypeInterface *t, int num_elem) {
145 elem_offset_cache_[t] = {};
146 elem_offset_cache_[t].resize(num_elem, 0);
147 }
148
149 void register_elem_offset(const MemRefAggregateTypeInterface *t,
150 int n,
151 size_t offset) {
152 TI_ASSERT(elem_offset_cache_.find(t) != elem_offset_cache_.end());
153 elem_offset_cache_[t][n] = offset;
154 }
155
156 // Size or alignment can not be zero
157 size_t query_size(const MemRefElementTypeInterface *t) {
158 if (size_cache_.find(t) != size_cache_.end()) {
159 return size_cache_[t];
160 } else {
161 return 0;
162 }
163 }
164
165 size_t query_alignment(const MemRefElementTypeInterface *t) {
166 if (alignment_cache_.find(t) != alignment_cache_.end()) {
167 return alignment_cache_[t];
168 } else {
169 return 0;
170 }
171 }
172
173 size_t query_elem_offset(const MemRefAggregateTypeInterface *t, int n) {
174 if (elem_offset_cache_.find(t) != elem_offset_cache_.end()) {
175 return elem_offset_cache_[t][n];
176 } else {
177 return 0;
178 }
179 }
180
181 private:
182 bool is_equal(const Polymorphic &other) const override {
183 // This is only called when `other` has the same typeid
184 return true;
185 }
186};
187
188class MemRefElementTypeInterface {
189 public:
190 virtual size_t memory_size(LayoutContext &ctx) const = 0;
191 virtual size_t memory_alignment_size(LayoutContext &ctx) const = 0;
192};
193
194class MemRefAggregateTypeInterface : public MemRefElementTypeInterface {
195 public:
196 virtual size_t nth_element_offset(int n, LayoutContext &ctx) const = 0;
197};
198
199class AggregateTypeInterface {
200 public:
201 virtual const Type *nth_element_type(int n) const = 0;
202 virtual int get_num_elements() const = 0;
203};
204
205class ShapedTypeInterface {
206 public:
207 virtual const Type *element_type() const = 0;
208 virtual bool is_constant_shape() const = 0;
209 virtual std::vector<size_t> get_constant_shape() const = 0;
210};
211
212class PointerTypeInterface {
213 public:
214 virtual const Type *get_pointed_type() const = 0;
215};
216
217class Block {
218 public:
219 template <typename T, class... E>
220 T *emplace_back(E... args) {
221 nodes_.push_back(std::make_unique<T>(args...));
222 return static_cast<T *>(nodes_.back().get());
223 }
224
225 template <typename T>
226 T *push_back(std::unique_ptr<T> &&val) {
227 T *ptr = val.get();
228 nodes_.push_back(std::move(val));
229 return ptr;
230 }
231
232 const std::vector<std::unique_ptr<Node>> &nodes() const {
233 return nodes_;
234 }
235
236 private:
237 std::vector<std::unique_ptr<Node>> nodes_;
238};
239
240class Visitor {
241 public:
242 virtual ~Visitor() {
243 }
244
245 virtual void visit(const Node *node) {
246 if (node->is<Type>()) {
247 visit_type(node->as<Type>());
248 }
249 }
250
251 virtual void visit_type(const Type *type) {
252 }
253
254 virtual void visit(const Block *block) {
255 for (auto &n : block->nodes()) {
256 visit(n.get());
257 }
258 }
259};
260
261} // namespace tinyir
262} // namespace taichi
263