1 | #include "spirv_types.h" |
2 | #include "spirv_ir_builder.h" |
3 | |
4 | namespace taichi::lang { |
5 | namespace spirv { |
6 | |
7 | size_t StructType::memory_size(tinyir::LayoutContext &ctx) const { |
8 | if (size_t s = ctx.query_size(this)) { |
9 | return s; |
10 | } |
11 | |
12 | ctx.register_aggregate(this, elements_.size()); |
13 | |
14 | size_t size_head = 0; |
15 | int n = 0; |
16 | for (const Type *elem : elements_) { |
17 | TI_ASSERT(elem->is<tinyir::MemRefElementTypeInterface>()); |
18 | const MemRefElementTypeInterface *mem_ref_type = |
19 | elem->cast<tinyir::MemRefElementTypeInterface>(); |
20 | size_t elem_size = mem_ref_type->memory_size(ctx); |
21 | size_t elem_align = mem_ref_type->memory_alignment_size(ctx); |
22 | // First align the head ptr, then add the size |
23 | size_head = tinyir::ceil_div(size_head, elem_align) * elem_align; |
24 | ctx.register_elem_offset(this, n, size_head); |
25 | size_head += elem_size; |
26 | n++; |
27 | } |
28 | |
29 | if (ctx.is<STD140LayoutContext>()) { |
30 | // With STD140 layout, the next member is rounded up to the alignment size. |
31 | // Thus we should simply size up the struct to the alignment. |
32 | size_t self_alignment = this->memory_alignment_size(ctx); |
33 | size_head = tinyir::ceil_div(size_head, self_alignment) * self_alignment; |
34 | } |
35 | |
36 | ctx.register_size(this, size_head); |
37 | return size_head; |
38 | } |
39 | |
40 | size_t StructType::memory_alignment_size(tinyir::LayoutContext &ctx) const { |
41 | if (size_t s = ctx.query_alignment(this)) { |
42 | return s; |
43 | } |
44 | |
45 | size_t max_align = 0; |
46 | for (const Type *elem : elements_) { |
47 | TI_ASSERT(elem->is<tinyir::MemRefElementTypeInterface>()); |
48 | max_align = std::max( |
49 | max_align, |
50 | elem->cast<MemRefElementTypeInterface>()->memory_alignment_size(ctx)); |
51 | } |
52 | |
53 | if (ctx.is<STD140LayoutContext>()) { |
54 | // With STD140 layout, struct alignment is rounded up to `sizeof(vec4)` |
55 | constexpr size_t vec4_size = sizeof(float) * 4; |
56 | max_align = tinyir::ceil_div(max_align, vec4_size) * vec4_size; |
57 | } |
58 | |
59 | ctx.register_alignment(this, max_align); |
60 | return max_align; |
61 | } |
62 | |
63 | size_t StructType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { |
64 | this->memory_size(ctx); |
65 | |
66 | return ctx.query_elem_offset(this, n); |
67 | } |
68 | |
69 | SmallVectorType::SmallVectorType(const Type *element_type, int num_elements) |
70 | : element_type_(element_type), num_elements_(num_elements) { |
71 | TI_ASSERT(num_elements > 1 && num_elements_ <= 4); |
72 | } |
73 | |
74 | size_t SmallVectorType::memory_size(tinyir::LayoutContext &ctx) const { |
75 | if (size_t s = ctx.query_size(this)) { |
76 | return s; |
77 | } |
78 | |
79 | size_t size = |
80 | element_type_->cast<tinyir::MemRefElementTypeInterface>()->memory_size( |
81 | ctx) * |
82 | num_elements_; |
83 | |
84 | ctx.register_size(this, size); |
85 | return size; |
86 | } |
87 | |
88 | size_t SmallVectorType::memory_alignment_size( |
89 | tinyir::LayoutContext &ctx) const { |
90 | if (size_t s = ctx.query_alignment(this)) { |
91 | return s; |
92 | } |
93 | |
94 | size_t align = |
95 | element_type_->cast<tinyir::MemRefElementTypeInterface>()->memory_size( |
96 | ctx); |
97 | |
98 | if (ctx.is<STD430LayoutContext>() || ctx.is<STD140LayoutContext>()) { |
99 | // For STD140 / STD430, small vectors are Power-of-Two aligned |
100 | // In C or "Scalar block layout", blocks are aligned to its component |
101 | // alignment |
102 | if (num_elements_ == 2) { |
103 | align *= 2; |
104 | } else { |
105 | align *= 4; |
106 | } |
107 | } |
108 | |
109 | ctx.register_alignment(this, align); |
110 | return align; |
111 | } |
112 | |
113 | size_t ArrayType::memory_size(tinyir::LayoutContext &ctx) const { |
114 | if (size_t s = ctx.query_size(this)) { |
115 | return s; |
116 | } |
117 | |
118 | size_t elem_align = element_type_->cast<tinyir::MemRefElementTypeInterface>() |
119 | ->memory_alignment_size(ctx); |
120 | |
121 | if (ctx.is<STD140LayoutContext>()) { |
122 | // For STD140, arrays element stride equals the base alignment of the array |
123 | // itself |
124 | elem_align = this->memory_alignment_size(ctx); |
125 | } |
126 | size_t size = elem_align * size_; |
127 | |
128 | ctx.register_size(this, size); |
129 | return size; |
130 | } |
131 | |
132 | size_t ArrayType::memory_alignment_size(tinyir::LayoutContext &ctx) const { |
133 | if (size_t s = ctx.query_alignment(this)) { |
134 | return s; |
135 | } |
136 | |
137 | size_t elem_align = element_type_->cast<tinyir::MemRefElementTypeInterface>() |
138 | ->memory_alignment_size(ctx); |
139 | |
140 | if (ctx.is<STD140LayoutContext>()) { |
141 | // With STD140 layout, array alignment is rounded up to `sizeof(vec4)` |
142 | constexpr size_t vec4_size = sizeof(float) * 4; |
143 | elem_align = tinyir::ceil_div(elem_align, vec4_size) * vec4_size; |
144 | } |
145 | |
146 | ctx.register_alignment(this, elem_align); |
147 | return elem_align; |
148 | } |
149 | |
150 | size_t ArrayType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { |
151 | size_t elem_align = this->memory_alignment_size(ctx); |
152 | |
153 | return elem_align * n; |
154 | } |
155 | |
156 | bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted) { |
157 | if (a->is<IntType>() && b->is<IntType>()) { |
158 | return a->as<IntType>()->num_bits() == b->as<IntType>()->num_bits(); |
159 | } else if (a->is<FloatType>() && b->is<IntType>()) { |
160 | return a->as<FloatType>()->num_bits() == b->as<IntType>()->num_bits(); |
161 | } else if (!_inverted) { |
162 | return bitcast_possible(b, a, true); |
163 | } |
164 | return false; |
165 | } |
166 | |
167 | const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, |
168 | const DataType t) { |
169 | if (t->is<PrimitiveType>()) { |
170 | if (t == PrimitiveType::i8) { |
171 | return ir_module.emplace_back<IntType>(/*num_bits=*/8, |
172 | /*is_signed=*/true); |
173 | } else if (t == PrimitiveType::i16) { |
174 | return ir_module.emplace_back<IntType>(/*num_bits=*/16, |
175 | /*is_signed=*/true); |
176 | } else if (t == PrimitiveType::i32) { |
177 | return ir_module.emplace_back<IntType>(/*num_bits=*/32, |
178 | /*is_signed=*/true); |
179 | } else if (t == PrimitiveType::i64) { |
180 | return ir_module.emplace_back<IntType>(/*num_bits=*/64, |
181 | /*is_signed=*/true); |
182 | } else if (t == PrimitiveType::u8) { |
183 | return ir_module.emplace_back<IntType>(/*num_bits=*/8, |
184 | /*is_signed=*/false); |
185 | } else if (t == PrimitiveType::u16) { |
186 | return ir_module.emplace_back<IntType>(/*num_bits=*/16, |
187 | /*is_signed=*/false); |
188 | } else if (t == PrimitiveType::u32) { |
189 | return ir_module.emplace_back<IntType>(/*num_bits=*/32, |
190 | /*is_signed=*/false); |
191 | } else if (t == PrimitiveType::u64) { |
192 | return ir_module.emplace_back<IntType>(/*num_bits=*/64, |
193 | /*is_signed=*/false); |
194 | } else if (t == PrimitiveType::f16) { |
195 | return ir_module.emplace_back<FloatType>(/*num_bits=*/16); |
196 | } else if (t == PrimitiveType::f32) { |
197 | return ir_module.emplace_back<FloatType>(/*num_bits=*/32); |
198 | } else if (t == PrimitiveType::f64) { |
199 | return ir_module.emplace_back<FloatType>(/*num_bits=*/64); |
200 | } else { |
201 | TI_NOT_IMPLEMENTED; |
202 | } |
203 | } else { |
204 | TI_NOT_IMPLEMENTED; |
205 | } |
206 | } |
207 | |
208 | void TypeVisitor::visit_type(const tinyir::Type *type) { |
209 | if (type->is<PhysicalPointerType>()) { |
210 | visit_physical_pointer_type(type->as<PhysicalPointerType>()); |
211 | } else if (type->is<SmallVectorType>()) { |
212 | visit_small_vector_type(type->as<SmallVectorType>()); |
213 | } else if (type->is<ArrayType>()) { |
214 | visit_array_type(type->as<ArrayType>()); |
215 | } else if (type->is<StructType>()) { |
216 | visit_struct_type(type->as<StructType>()); |
217 | } else if (type->is<IntType>()) { |
218 | visit_int_type(type->as<IntType>()); |
219 | } else if (type->is<FloatType>()) { |
220 | visit_float_type(type->as<FloatType>()); |
221 | } |
222 | } |
223 | |
224 | class TypePrinter : public TypeVisitor { |
225 | private: |
226 | std::string result_; |
227 | STD140LayoutContext layout_ctx_; |
228 | |
229 | uint32_t head_{0}; |
230 | std::unordered_map<const tinyir::Type *, uint32_t> idmap_; |
231 | |
232 | uint32_t get_id(const tinyir::Type *type) { |
233 | if (idmap_.find(type) == idmap_.end()) { |
234 | uint32_t id = head_++; |
235 | idmap_[type] = id; |
236 | return id; |
237 | } else { |
238 | return idmap_[type]; |
239 | } |
240 | } |
241 | |
242 | public: |
243 | void visit_int_type(const IntType *type) override { |
244 | result_ += fmt::format("T{} = {}int{}_t\n" , get_id(type), |
245 | type->is_signed() ? "" : "u" , type->num_bits()); |
246 | } |
247 | |
248 | void visit_float_type(const FloatType *type) override { |
249 | result_ += fmt::format("T{} = float{}_t\n" , get_id(type), type->num_bits()); |
250 | } |
251 | |
252 | void visit_physical_pointer_type(const PhysicalPointerType *type) override { |
253 | result_ += fmt::format("T{} = T{} *\n" , get_id(type), |
254 | get_id(type->get_pointed_type())); |
255 | } |
256 | |
257 | void visit_struct_type(const StructType *type) override { |
258 | result_ += fmt::format("T{} = struct {{" , get_id(type)); |
259 | for (int i = 0; i < type->get_num_elements(); i++) { |
260 | result_ += fmt::format("T{}, " , get_id(type->nth_element_type(i))); |
261 | } |
262 | result_ += "}}\n" ; |
263 | } |
264 | |
265 | void visit_small_vector_type(const SmallVectorType *type) override { |
266 | result_ += fmt::format("T{} = small_vector<T{}, {}>\n" , get_id(type), |
267 | get_id(type->element_type()), |
268 | type->get_constant_shape()[0]); |
269 | } |
270 | |
271 | void visit_array_type(const ArrayType *type) override { |
272 | result_ += fmt::format("T{} = array<T{}, {}>\n" , get_id(type), |
273 | get_id(type->element_type()), |
274 | type->get_constant_shape()[0]); |
275 | } |
276 | |
277 | static std::string print_types(const tinyir::Block *block) { |
278 | TypePrinter p; |
279 | p.visit(block); |
280 | return p.result_; |
281 | } |
282 | }; |
283 | |
// Free-function wrapper over TypePrinter::print_types(); returns a textual
// listing of every type in `block`.
std::string ir_print_types(const tinyir::Block *block) {
  return TypePrinter::print_types(block);
}
287 | |
288 | class TypeReducer : public TypeVisitor { |
289 | public: |
290 | std::unique_ptr<tinyir::Block> copy{nullptr}; |
291 | std::unordered_map<const tinyir::Type *, const tinyir::Type *> &oldptr2newptr; |
292 | |
293 | explicit TypeReducer( |
294 | std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new) |
295 | : oldptr2newptr(old2new) { |
296 | copy = std::make_unique<tinyir::Block>(); |
297 | old2new.clear(); |
298 | } |
299 | |
300 | const tinyir::Type *check_type(const tinyir::Type *type) { |
301 | if (oldptr2newptr.find(type) != oldptr2newptr.end()) { |
302 | return oldptr2newptr[type]; |
303 | } |
304 | for (const auto &t : copy->nodes()) { |
305 | if (t->equals(type)) { |
306 | oldptr2newptr[type] = (const tinyir::Type *)t.get(); |
307 | return (const tinyir::Type *)t.get(); |
308 | } |
309 | } |
310 | return nullptr; |
311 | } |
312 | |
313 | void visit_int_type(const IntType *type) override { |
314 | if (!check_type(type)) { |
315 | oldptr2newptr[type] = copy->emplace_back<IntType>(*type); |
316 | } |
317 | } |
318 | |
319 | void visit_float_type(const FloatType *type) override { |
320 | if (!check_type(type)) { |
321 | oldptr2newptr[type] = copy->emplace_back<FloatType>(*type); |
322 | } |
323 | } |
324 | |
325 | void visit_physical_pointer_type(const PhysicalPointerType *type) override { |
326 | if (!check_type(type)) { |
327 | const tinyir::Type *pointed = check_type(type->get_pointed_type()); |
328 | TI_ASSERT(pointed); |
329 | oldptr2newptr[type] = copy->emplace_back<PhysicalPointerType>(pointed); |
330 | } |
331 | } |
332 | |
333 | void visit_struct_type(const StructType *type) override { |
334 | if (!check_type(type)) { |
335 | std::vector<const tinyir::Type *> elements; |
336 | for (int i = 0; i < type->get_num_elements(); i++) { |
337 | const tinyir::Type *elm = check_type(type->nth_element_type(i)); |
338 | TI_ASSERT(elm); |
339 | elements.push_back(elm); |
340 | } |
341 | oldptr2newptr[type] = copy->emplace_back<StructType>(elements); |
342 | } |
343 | } |
344 | |
345 | void visit_small_vector_type(const SmallVectorType *type) override { |
346 | if (!check_type(type)) { |
347 | const tinyir::Type *element = check_type(type->element_type()); |
348 | TI_ASSERT(element); |
349 | oldptr2newptr[type] = copy->emplace_back<SmallVectorType>( |
350 | element, type->get_constant_shape()[0]); |
351 | } |
352 | } |
353 | |
354 | void visit_array_type(const ArrayType *type) override { |
355 | if (!check_type(type)) { |
356 | const tinyir::Type *element = check_type(type->element_type()); |
357 | TI_ASSERT(element); |
358 | oldptr2newptr[type] = |
359 | copy->emplace_back<ArrayType>(element, type->get_constant_shape()[0]); |
360 | } |
361 | } |
362 | }; |
363 | |
364 | std::unique_ptr<tinyir::Block> ir_reduce_types( |
365 | tinyir::Block *blk, |
366 | std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new) { |
367 | TypeReducer reducer(old2new); |
368 | reducer.visit(blk); |
369 | return std::move(reducer.copy); |
370 | } |
371 | |
// Walks a (deduplicated) tinyir type block and declares the corresponding
// SPIR-V types through an IRBuilder, recording the SPIR-V result id of every
// translated node in `ir_node_2_spv_value`. Composite visits read the ids of
// their operand types from that map, so child types must be visited before
// the types that reference them.
class Translate2Spirv : public TypeVisitor {
 private:
  IRBuilder *spir_builder_{nullptr};
  // Layout rules used for member offsets and array strides.
  tinyir::LayoutContext &layout_context_;

 public:
  // Maps each visited tinyir node to its SPIR-V <id>.
  std::unordered_map<const tinyir::Node *, uint32_t> ir_node_2_spv_value;

  Translate2Spirv(IRBuilder *spir_builder,
                  tinyir::LayoutContext &layout_context)
      : spir_builder_(spir_builder), layout_context_(layout_context) {
  }

  // Integer types map onto the builder's cached i8..i64 / u8..u64 types.
  // NOTE(review): widths other than 8/16/32/64 leave `vt` default-initialized
  // — presumably unreachable for validated input; confirm upstream checks.
  void visit_int_type(const IntType *type) override {
    SType vt;
    if (type->is_signed()) {
      if (type->num_bits() == 8) {
        vt = spir_builder_->i8_type();
      } else if (type->num_bits() == 16) {
        vt = spir_builder_->i16_type();
      } else if (type->num_bits() == 32) {
        vt = spir_builder_->i32_type();
      } else if (type->num_bits() == 64) {
        vt = spir_builder_->i64_type();
      }
    } else {
      if (type->num_bits() == 8) {
        vt = spir_builder_->u8_type();
      } else if (type->num_bits() == 16) {
        vt = spir_builder_->u16_type();
      } else if (type->num_bits() == 32) {
        vt = spir_builder_->u32_type();
      } else if (type->num_bits() == 64) {
        vt = spir_builder_->u64_type();
      }
    }
    ir_node_2_spv_value[type] = vt.id;
  }

  // Float types map onto the builder's cached f16/f32/f64 types; same
  // caveat as visit_int_type for unsupported widths.
  void visit_float_type(const FloatType *type) override {
    SType vt;
    if (type->num_bits() == 16) {
      vt = spir_builder_->f16_type();
    } else if (type->num_bits() == 32) {
      vt = spir_builder_->f32_type();
    } else if (type->num_bits() == 64) {
      vt = spir_builder_->f64_type();
    }
    ir_node_2_spv_value[type] = vt.id;
  }

  // Declares OpTypePointer with PhysicalStorageBuffer storage class; the
  // pointee id must already be present in the map.
  void visit_physical_pointer_type(const PhysicalPointerType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(
        spv::OpTypePointer, vt, spv::StorageClassPhysicalStorageBuffer,
        ir_node_2_spv_value[type->get_pointed_type()]);
    ir_node_2_spv_value[type] = vt.id;
  }

  // Declares OpTypeStruct from already-translated member ids, then emits one
  // OpMemberDecorate Offset per member using the active layout context.
  void visit_struct_type(const StructType *type) override {
    std::vector<uint32_t> element_ids;
    for (int i = 0; i < type->get_num_elements(); i++) {
      element_ids.push_back(ir_node_2_spv_value[type->nth_element_type(i)]);
    }
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeStruct, vt, element_ids);
    ir_node_2_spv_value[type] = vt.id;
    for (int i = 0; i < type->get_num_elements(); i++) {
      spir_builder_->decorate(spv::OpMemberDecorate, vt, i,
                              spv::DecorationOffset,
                              type->nth_element_offset(i, layout_context_));
    }
  }

  // Declares OpTypeVector <component id, component count>.
  void visit_small_vector_type(const SmallVectorType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeVector, vt,
                                  ir_node_2_spv_value[type->element_type()],
                                  type->get_constant_shape()[0]);
    ir_node_2_spv_value[type] = vt.id;
  }

  // Declares OpTypeArray and decorates it with ArrayStride. The stride is
  // the element alignment size, matching ArrayType::memory_size /
  // nth_element_offset.
  void visit_array_type(const ArrayType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeArray, vt,
                                  ir_node_2_spv_value[type->element_type()],
                                  type->get_constant_shape()[0]);
    ir_node_2_spv_value[type] = vt.id;
    spir_builder_->decorate(spv::OpDecorate, vt, spv::DecorationArrayStride,
                            type->memory_alignment_size(layout_context_));
  }
};
464 | |
465 | std::unordered_map<const tinyir::Node *, uint32_t> ir_translate_to_spirv( |
466 | const tinyir::Block *blk, |
467 | tinyir::LayoutContext &layout_ctx, |
468 | IRBuilder *spir_builder) { |
469 | Translate2Spirv translator(spir_builder, layout_ctx); |
470 | translator.visit(blk); |
471 | return std::move(translator.ir_node_2_spv_value); |
472 | } |
473 | |
474 | } // namespace spirv |
475 | } // namespace taichi::lang |
476 | |