1 | #pragma once |
2 | |
3 | #include "taichi/common/core.h" |
4 | |
#include <memory>
#include <utility>
#include <vector>
7 | |
8 | namespace taichi { |
9 | namespace tinyir { |
10 | |
// Integer division rounded toward positive infinity. `div` must be non-zero.
//
// Works for unsigned and signed T. Built-in division truncates toward zero,
// which already is the ceiling when the exact quotient is negative, so the
// truncated quotient is only bumped when the exact quotient is positive and
// inexact.
template <typename T>
T ceil_div(T v, T div) {
  const T quotient = v / div;
  const T remainder = v % div;
  // Since C++11 the remainder takes the sign of the dividend; a non-zero
  // remainder whose sign matches the divisor therefore means both operands
  // share a sign, i.e. the exact quotient is positive and was rounded down.
  if (remainder != 0 && ((remainder > 0) == (div > 0))) {
    return quotient + 1;
  }
  return quotient;
}
15 | |
16 | // Forward decl |
17 | class Polymorphic; |
18 | class Node; |
19 | class Type; |
20 | class LayoutContext; |
21 | class MemRefElementTypeInterface; |
22 | class MemRefAggregateTypeInterface; |
23 | class ShapedTypeInterface; |
24 | class AggregateTypeInterface; |
25 | class PointerTypeInterface; |
26 | class Block; |
27 | class Visitor; |
28 | |
// Root of the tinyir class hierarchy. Provides RTTI-based type queries and
// casts, plus equality that first compares dynamic types and then defers to
// the private virtual `is_equal`.
class Polymorphic {
 public:
  virtual ~Polymorphic() = default;

  // True if the dynamic type of this object is T or derives from T.
  template <typename T>
  bool is() const {
    return dynamic_cast<const T *>(this) != nullptr;
  }

  // Unchecked downcast (static_cast). The caller must already know this
  // object is a T (e.g. via is<T>()); otherwise the result is undefined.
  template <typename T>
  T *as() {
    return static_cast<T *>(this);
  }

  template <typename T>
  const T *as() const {
    return static_cast<const T *>(this);
  }

  // Checked cast (dynamic_cast): returns nullptr if this object is not a T.
  template <typename T>
  T *cast() {
    return dynamic_cast<T *>(this);
  }

  template <typename T>
  const T *cast() const {
    return dynamic_cast<const T *>(this);
  }

  // Equal only if both objects have exactly the same dynamic type and the
  // type-specific is_equal() agrees.
  bool operator==(const Polymorphic &other) const {
    return typeid(*this) == typeid(other) && is_equal(other);
  }

  // Pointer form of operator==. A null `other` compares unequal instead of
  // being dereferenced.
  bool equals(const Polymorphic *other) const {
    return other != nullptr && (*this) == (*other);
  }

 private:
  // Invoked by operator== only after it has verified that `other` has the
  // same dynamic type as *this, so overrides may static_cast `other`.
  virtual bool is_equal(const Polymorphic &other) const = 0;
};
70 | |
71 | class Node : public Polymorphic { |
72 | public: |
73 | using NodeRefs = const std::vector<const Node *>; |
74 | |
75 | Node() { |
76 | } |
77 | |
78 | ~Node() override { |
79 | } |
80 | |
81 | const std::string &debug_name() const { |
82 | return debug_name_; |
83 | } |
84 | |
85 | void set_debug_name(const std::string &s) { |
86 | debug_name_ = s; |
87 | } |
88 | |
89 | virtual NodeRefs incoming() const { |
90 | return {}; |
91 | } |
92 | |
93 | virtual NodeRefs outgoing() const { |
94 | return {}; |
95 | } |
96 | |
97 | virtual bool is_leaf() const { |
98 | return false; |
99 | } |
100 | |
101 | virtual bool is_tree_node() const { |
102 | return false; |
103 | } |
104 | |
105 | private: |
106 | bool is_equal(const Polymorphic &other) const override { |
107 | return false; |
108 | } |
109 | |
110 | std::string debug_name_; |
111 | }; |
112 | |
113 | class Type : public Node { |
114 | public: |
115 | Type() { |
116 | } |
117 | |
118 | private: |
119 | bool is_equal(const Polymorphic &other) const override { |
120 | return false; |
121 | } |
122 | }; |
123 | |
124 | // The default LayoutContext is the standard C layout |
125 | class LayoutContext : public Polymorphic { |
126 | private: |
127 | std::unordered_map<const MemRefElementTypeInterface *, size_t> size_cache_; |
128 | std::unordered_map<const MemRefElementTypeInterface *, size_t> |
129 | alignment_cache_; |
130 | std::unordered_map<const MemRefAggregateTypeInterface *, std::vector<size_t>> |
131 | elem_offset_cache_; |
132 | |
133 | public: |
134 | void register_size(const MemRefElementTypeInterface *t, size_t size) { |
135 | TI_ASSERT(size != 0); |
136 | size_cache_[t] = size; |
137 | } |
138 | |
139 | void register_alignment(const MemRefElementTypeInterface *t, size_t size) { |
140 | TI_ASSERT(size != 0); |
141 | alignment_cache_[t] = size; |
142 | } |
143 | |
144 | void register_aggregate(const MemRefAggregateTypeInterface *t, int num_elem) { |
145 | elem_offset_cache_[t] = {}; |
146 | elem_offset_cache_[t].resize(num_elem, 0); |
147 | } |
148 | |
149 | void register_elem_offset(const MemRefAggregateTypeInterface *t, |
150 | int n, |
151 | size_t offset) { |
152 | TI_ASSERT(elem_offset_cache_.find(t) != elem_offset_cache_.end()); |
153 | elem_offset_cache_[t][n] = offset; |
154 | } |
155 | |
156 | // Size or alignment can not be zero |
157 | size_t query_size(const MemRefElementTypeInterface *t) { |
158 | if (size_cache_.find(t) != size_cache_.end()) { |
159 | return size_cache_[t]; |
160 | } else { |
161 | return 0; |
162 | } |
163 | } |
164 | |
165 | size_t query_alignment(const MemRefElementTypeInterface *t) { |
166 | if (alignment_cache_.find(t) != alignment_cache_.end()) { |
167 | return alignment_cache_[t]; |
168 | } else { |
169 | return 0; |
170 | } |
171 | } |
172 | |
173 | size_t query_elem_offset(const MemRefAggregateTypeInterface *t, int n) { |
174 | if (elem_offset_cache_.find(t) != elem_offset_cache_.end()) { |
175 | return elem_offset_cache_[t][n]; |
176 | } else { |
177 | return 0; |
178 | } |
179 | } |
180 | |
181 | private: |
182 | bool is_equal(const Polymorphic &other) const override { |
183 | // This is only called when `other` has the same typeid |
184 | return true; |
185 | } |
186 | }; |
187 | |
188 | class MemRefElementTypeInterface { |
189 | public: |
190 | virtual size_t memory_size(LayoutContext &ctx) const = 0; |
191 | virtual size_t memory_alignment_size(LayoutContext &ctx) const = 0; |
192 | }; |
193 | |
194 | class MemRefAggregateTypeInterface : public MemRefElementTypeInterface { |
195 | public: |
196 | virtual size_t nth_element_offset(int n, LayoutContext &ctx) const = 0; |
197 | }; |
198 | |
199 | class AggregateTypeInterface { |
200 | public: |
201 | virtual const Type *nth_element_type(int n) const = 0; |
202 | virtual int get_num_elements() const = 0; |
203 | }; |
204 | |
205 | class ShapedTypeInterface { |
206 | public: |
207 | virtual const Type *element_type() const = 0; |
208 | virtual bool is_constant_shape() const = 0; |
209 | virtual std::vector<size_t> get_constant_shape() const = 0; |
210 | }; |
211 | |
212 | class PointerTypeInterface { |
213 | public: |
214 | virtual const Type *get_pointed_type() const = 0; |
215 | }; |
216 | |
217 | class Block { |
218 | public: |
219 | template <typename T, class... E> |
220 | T *emplace_back(E... args) { |
221 | nodes_.push_back(std::make_unique<T>(args...)); |
222 | return static_cast<T *>(nodes_.back().get()); |
223 | } |
224 | |
225 | template <typename T> |
226 | T *push_back(std::unique_ptr<T> &&val) { |
227 | T *ptr = val.get(); |
228 | nodes_.push_back(std::move(val)); |
229 | return ptr; |
230 | } |
231 | |
232 | const std::vector<std::unique_ptr<Node>> &nodes() const { |
233 | return nodes_; |
234 | } |
235 | |
236 | private: |
237 | std::vector<std::unique_ptr<Node>> nodes_; |
238 | }; |
239 | |
240 | class Visitor { |
241 | public: |
242 | virtual ~Visitor() { |
243 | } |
244 | |
245 | virtual void visit(const Node *node) { |
246 | if (node->is<Type>()) { |
247 | visit_type(node->as<Type>()); |
248 | } |
249 | } |
250 | |
251 | virtual void visit_type(const Type *type) { |
252 | } |
253 | |
254 | virtual void visit(const Block *block) { |
255 | for (auto &n : block->nodes()) { |
256 | visit(n.get()); |
257 | } |
258 | } |
259 | }; |
260 | |
261 | } // namespace tinyir |
262 | } // namespace taichi |
263 | |