1 | #pragma once |
2 | |
3 | #include "lib_tiny_ir.h" |
4 | #include "taichi/ir/type.h" |
5 | |
6 | namespace taichi::lang { |
7 | namespace spirv { |
8 | |
9 | class STD140LayoutContext : public tinyir::LayoutContext {}; |
10 | class STD430LayoutContext : public tinyir::LayoutContext {}; |
11 | |
12 | class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { |
13 | public: |
14 | IntType(int num_bits, bool is_signed) |
15 | : num_bits_(num_bits), is_signed_(is_signed) { |
16 | } |
17 | |
18 | int num_bits() const { |
19 | return num_bits_; |
20 | } |
21 | |
22 | bool is_signed() const { |
23 | return is_signed_; |
24 | } |
25 | |
26 | size_t memory_size(tinyir::LayoutContext &ctx) const override { |
27 | return tinyir::ceil_div(num_bits(), 8); |
28 | } |
29 | |
30 | size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { |
31 | return tinyir::ceil_div(num_bits(), 8); |
32 | } |
33 | |
34 | private: |
35 | bool is_equal(const Polymorphic &other) const override { |
36 | const IntType &t = (const IntType &)other; |
37 | return t.num_bits_ == num_bits_ && t.is_signed_ == is_signed_; |
38 | } |
39 | |
40 | int num_bits_{0}; |
41 | bool is_signed_{false}; |
42 | }; |
43 | |
44 | class FloatType : public tinyir::Type, |
45 | public tinyir::MemRefElementTypeInterface { |
46 | public: |
47 | explicit FloatType(int num_bits) : num_bits_(num_bits) { |
48 | } |
49 | |
50 | int num_bits() const { |
51 | return num_bits_; |
52 | } |
53 | |
54 | size_t memory_size(tinyir::LayoutContext &ctx) const override { |
55 | return tinyir::ceil_div(num_bits(), 8); |
56 | } |
57 | |
58 | size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { |
59 | return tinyir::ceil_div(num_bits(), 8); |
60 | } |
61 | |
62 | private: |
63 | int num_bits_{0}; |
64 | |
65 | bool is_equal(const Polymorphic &other) const override { |
66 | const FloatType &t = (const FloatType &)other; |
67 | return t.num_bits_ == num_bits_; |
68 | } |
69 | }; |
70 | |
71 | class PhysicalPointerType : public IntType, |
72 | public tinyir::PointerTypeInterface { |
73 | public: |
74 | explicit PhysicalPointerType(const tinyir::Type *pointed_type) |
75 | : IntType(/*num_bits=*/64, /*is_signed=*/false), |
76 | pointed_type_(pointed_type) { |
77 | } |
78 | |
79 | const tinyir::Type *get_pointed_type() const override { |
80 | return pointed_type_; |
81 | } |
82 | |
83 | private: |
84 | const tinyir::Type *pointed_type_; |
85 | |
86 | bool is_equal(const Polymorphic &other) const override { |
87 | const PhysicalPointerType &pt = (const PhysicalPointerType &)other; |
88 | return IntType::operator==((const IntType &)other) && |
89 | pointed_type_->equals(pt.pointed_type_); |
90 | } |
91 | }; |
92 | |
93 | class StructType : public tinyir::Type, |
94 | public tinyir::AggregateTypeInterface, |
95 | public tinyir::MemRefAggregateTypeInterface { |
96 | public: |
97 | explicit StructType(std::vector<const tinyir::Type *> &elements) |
98 | : elements_(elements) { |
99 | } |
100 | |
101 | const tinyir::Type *nth_element_type(int n) const override { |
102 | return elements_[n]; |
103 | } |
104 | |
105 | int get_num_elements() const override { |
106 | return elements_.size(); |
107 | } |
108 | |
109 | size_t memory_size(tinyir::LayoutContext &ctx) const override; |
110 | |
111 | size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; |
112 | |
113 | size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; |
114 | |
115 | private: |
116 | std::vector<const tinyir::Type *> elements_; |
117 | |
118 | bool is_equal(const Polymorphic &other) const override { |
119 | const StructType &t = (const StructType &)other; |
120 | if (t.get_num_elements() != get_num_elements()) { |
121 | return false; |
122 | } |
123 | for (int i = 0; i < get_num_elements(); i++) { |
124 | if (!elements_[i]->equals(t.elements_[i])) { |
125 | return false; |
126 | } |
127 | } |
128 | return true; |
129 | } |
130 | }; |
131 | |
132 | class SmallVectorType : public tinyir::Type, |
133 | public tinyir::ShapedTypeInterface, |
134 | public tinyir::MemRefElementTypeInterface { |
135 | public: |
136 | SmallVectorType(const tinyir::Type *element_type, int num_elements); |
137 | |
138 | const tinyir::Type *element_type() const override { |
139 | return element_type_; |
140 | } |
141 | |
142 | bool is_constant_shape() const override { |
143 | return true; |
144 | } |
145 | |
146 | std::vector<size_t> get_constant_shape() const override { |
147 | return {size_t(num_elements_)}; |
148 | } |
149 | |
150 | size_t memory_size(tinyir::LayoutContext &ctx) const override; |
151 | |
152 | size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; |
153 | |
154 | private: |
155 | bool is_equal(const Polymorphic &other) const override { |
156 | const SmallVectorType &t = (const SmallVectorType &)other; |
157 | return num_elements_ == t.num_elements_ && |
158 | element_type_->equals(t.element_type_); |
159 | } |
160 | |
161 | const tinyir::Type *element_type_{nullptr}; |
162 | int num_elements_{0}; |
163 | }; |
164 | |
165 | class ArrayType : public tinyir::Type, |
166 | public tinyir::ShapedTypeInterface, |
167 | public tinyir::MemRefAggregateTypeInterface { |
168 | public: |
169 | ArrayType(const tinyir::Type *element_type, size_t size) |
170 | : element_type_(element_type), size_(size) { |
171 | } |
172 | |
173 | const tinyir::Type *element_type() const override { |
174 | return element_type_; |
175 | } |
176 | |
177 | bool is_constant_shape() const override { |
178 | return true; |
179 | } |
180 | |
181 | std::vector<size_t> get_constant_shape() const override { |
182 | return {size_}; |
183 | } |
184 | |
185 | size_t memory_size(tinyir::LayoutContext &ctx) const override; |
186 | |
187 | size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; |
188 | |
189 | size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; |
190 | |
191 | private: |
192 | bool is_equal(const Polymorphic &other) const override { |
193 | const ArrayType &t = (const ArrayType &)other; |
194 | return size_ == t.size_ && element_type_->equals(t.element_type_); |
195 | } |
196 | |
197 | const tinyir::Type *element_type_{nullptr}; |
198 | size_t size_{0}; |
199 | }; |
200 | |
201 | bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted = false); |
202 | |
203 | class TypeVisitor : public tinyir::Visitor { |
204 | public: |
205 | void visit_type(const tinyir::Type *type) override; |
206 | |
207 | virtual void visit_int_type(const IntType *type) { |
208 | } |
209 | |
210 | virtual void visit_float_type(const FloatType *type) { |
211 | } |
212 | |
213 | virtual void visit_physical_pointer_type(const PhysicalPointerType *type) { |
214 | } |
215 | |
216 | virtual void visit_struct_type(const StructType *type) { |
217 | } |
218 | |
219 | virtual void visit_small_vector_type(const SmallVectorType *type) { |
220 | } |
221 | |
222 | virtual void visit_array_type(const ArrayType *type) { |
223 | } |
224 | }; |
225 | |
226 | const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, |
227 | const DataType t); |
228 | |
229 | std::string ir_print_types(const tinyir::Block *block); |
230 | |
231 | std::unique_ptr<tinyir::Block> ir_reduce_types( |
232 | tinyir::Block *blk, |
233 | std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new); |
234 | |
235 | class IRBuilder; |
236 | |
237 | std::unordered_map<const tinyir::Node *, uint32_t> ir_translate_to_spirv( |
238 | const tinyir::Block *blk, |
239 | tinyir::LayoutContext &layout_ctx, |
240 | IRBuilder *spir_builder); |
241 | |
242 | } // namespace spirv |
243 | } // namespace taichi::lang |
244 | |