#include "spirv_types.h"
#include "spirv_ir_builder.h"

namespace taichi::lang {
namespace spirv {

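// Computes the byte size of a struct under the given layout context and
// caches each member's offset along the way: every member is placed at the
// next offset that satisfies its alignment. For example, under std140 a
// struct holding a float followed by a vec3 gets offsets 0 and 16 and a
// total size of 32.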
size_t StructType::memory_size(tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_size(this)) {
    return s;
  }

  ctx.register_aggregate(this, elements_.size());

  size_t size_head = 0;
  int n = 0;
  for (const Type *elem : elements_) {
    TI_ASSERT(elem->is<tinyir::MemRefElementTypeInterface>());
    const tinyir::MemRefElementTypeInterface *mem_ref_type =
        elem->cast<tinyir::MemRefElementTypeInterface>();
    size_t elem_size = mem_ref_type->memory_size(ctx);
    size_t elem_align = mem_ref_type->memory_alignment_size(ctx);
    // First align the head ptr, then add the size
    size_head = tinyir::ceil_div(size_head, elem_align) * elem_align;
    ctx.register_elem_offset(this, n, size_head);
    size_head += elem_size;
    n++;
  }

  if (ctx.is<STD140LayoutContext>()) {
    // With STD140 layout, the next member is rounded up to the alignment size.
    // Thus we should simply size up the struct to the alignment.
    size_t self_alignment = this->memory_alignment_size(ctx);
    size_head = tinyir::ceil_div(size_head, self_alignment) * self_alignment;
  }

  ctx.register_size(this, size_head);
  return size_head;
}

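// A struct's alignment is the largest alignment among its members; std140
// additionally rounds it up to the size of a vec4 (16 bytes).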
size_t StructType::memory_alignment_size(tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_alignment(this)) {
    return s;
  }

  size_t max_align = 0;
  for (const Type *elem : elements_) {
    TI_ASSERT(elem->is<tinyir::MemRefElementTypeInterface>());
    max_align = std::max(max_align,
                         elem->cast<tinyir::MemRefElementTypeInterface>()
                             ->memory_alignment_size(ctx));
  }

  if (ctx.is<STD140LayoutContext>()) {
    // With STD140 layout, struct alignment is rounded up to `sizeof(vec4)`
    constexpr size_t vec4_size = sizeof(float) * 4;
    max_align = tinyir::ceil_div(max_align, vec4_size) * vec4_size;
  }

  ctx.register_alignment(this, max_align);
  return max_align;
}

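// Per-element offsets are registered in the layout context as a side effect
// of memory_size(), so the layout is computed (or fetched from the cache)
// before the offset query.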
size_t StructType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const {
  this->memory_size(ctx);

  return ctx.query_elem_offset(this, n);
}

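// Small vectors model GLSL-style vec2/vec3/vec4 types; single components and
// wider vectors are not representable here.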
SmallVectorType::SmallVectorType(const Type *element_type, int num_elements)
    : element_type_(element_type), num_elements_(num_elements) {
  TI_ASSERT(num_elements > 1 && num_elements <= 4);
}

size_t SmallVectorType::memory_size(tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_size(this)) {
    return s;
  }

  size_t size =
      element_type_->cast<tinyir::MemRefElementTypeInterface>()->memory_size(
          ctx) *
      num_elements_;

  ctx.register_size(this, size);
  return size;
}

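// Vector alignment depends on the layout rules: std140 and std430 align a
// vec2 to twice its component size and a vec3/vec4 to four times it (e.g. a
// vec3 of float aligns to 16 bytes), while scalar layouts keep the plain
// component alignment.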
size_t SmallVectorType::memory_alignment_size(
    tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_alignment(this)) {
    return s;
  }

  size_t align =
      element_type_->cast<tinyir::MemRefElementTypeInterface>()->memory_size(
          ctx);

  if (ctx.is<STD430LayoutContext>() || ctx.is<STD140LayoutContext>()) {
    // For STD140 / STD430, small vectors are power-of-two aligned.
    // Under C rules or the "scalar block layout", vectors are instead aligned
    // to their component alignment.
    if (num_elements_ == 2) {
      align *= 2;
    } else {
      align *= 4;
    }
  }

  ctx.register_alignment(this, align);
  return align;
}

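// Array size is the element stride times the element count. The stride is
// the element's alignment; under std140 it is additionally rounded up to 16
// bytes through memory_alignment_size() below.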
size_t ArrayType::memory_size(tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_size(this)) {
    return s;
  }

  size_t elem_align = element_type_->cast<tinyir::MemRefElementTypeInterface>()
                          ->memory_alignment_size(ctx);

  if (ctx.is<STD140LayoutContext>()) {
    // For STD140, an array's element stride equals the base alignment of the
    // array itself.
    elem_align = this->memory_alignment_size(ctx);
  }
  size_t size = elem_align * size_;

  ctx.register_size(this, size);
  return size;
}

size_t ArrayType::memory_alignment_size(tinyir::LayoutContext &ctx) const {
  if (size_t s = ctx.query_alignment(this)) {
    return s;
  }

  size_t elem_align = element_type_->cast<tinyir::MemRefElementTypeInterface>()
                          ->memory_alignment_size(ctx);

  if (ctx.is<STD140LayoutContext>()) {
    // With STD140 layout, array alignment is rounded up to `sizeof(vec4)`
    constexpr size_t vec4_size = sizeof(float) * 4;
    elem_align = tinyir::ceil_div(elem_align, vec4_size) * vec4_size;
  }

  ctx.register_alignment(this, elem_align);
  return elem_align;
}

size_t ArrayType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const {
  size_t elem_align = this->memory_alignment_size(ctx);

  return elem_align * n;
}

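// Two types can be bitcast to each other when they are both integers of the
// same width, or an integer and a float of the same width (checked in both
// orders via the _inverted flag).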
bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted) {
  if (a->is<IntType>() && b->is<IntType>()) {
    return a->as<IntType>()->num_bits() == b->as<IntType>()->num_bits();
  } else if (a->is<FloatType>() && b->is<IntType>()) {
    return a->as<FloatType>()->num_bits() == b->as<IntType>()->num_bits();
  } else if (!_inverted) {
    return bitcast_possible(b, a, true);
  }
  return false;
}

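// Maps a Taichi primitive DataType onto a scalar tinyir type, emplacing the
// new type node into the given IR module.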
const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module,
                                           const DataType t) {
  if (t->is<PrimitiveType>()) {
    if (t == PrimitiveType::i8) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/8,
                                             /*is_signed=*/true);
    } else if (t == PrimitiveType::i16) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/16,
                                             /*is_signed=*/true);
    } else if (t == PrimitiveType::i32) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/32,
                                             /*is_signed=*/true);
    } else if (t == PrimitiveType::i64) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/64,
                                             /*is_signed=*/true);
    } else if (t == PrimitiveType::u8) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/8,
                                             /*is_signed=*/false);
    } else if (t == PrimitiveType::u16) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/16,
                                             /*is_signed=*/false);
    } else if (t == PrimitiveType::u32) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/32,
                                             /*is_signed=*/false);
    } else if (t == PrimitiveType::u64) {
      return ir_module.emplace_back<IntType>(/*num_bits=*/64,
                                             /*is_signed=*/false);
    } else if (t == PrimitiveType::f16) {
      return ir_module.emplace_back<FloatType>(/*num_bits=*/16);
    } else if (t == PrimitiveType::f32) {
      return ir_module.emplace_back<FloatType>(/*num_bits=*/32);
    } else if (t == PrimitiveType::f64) {
      return ir_module.emplace_back<FloatType>(/*num_bits=*/64);
    } else {
      TI_NOT_IMPLEMENTED;
    }
  } else {
    TI_NOT_IMPLEMENTED;
  }
}

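// Dispatches a type node to the matching visit_* hook; unhandled kinds are
// silently ignored.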
void TypeVisitor::visit_type(const tinyir::Type *type) {
  if (type->is<PhysicalPointerType>()) {
    visit_physical_pointer_type(type->as<PhysicalPointerType>());
  } else if (type->is<SmallVectorType>()) {
    visit_small_vector_type(type->as<SmallVectorType>());
  } else if (type->is<ArrayType>()) {
    visit_array_type(type->as<ArrayType>());
  } else if (type->is<StructType>()) {
    visit_struct_type(type->as<StructType>());
  } else if (type->is<IntType>()) {
    visit_int_type(type->as<IntType>());
  } else if (type->is<FloatType>()) {
    visit_float_type(type->as<FloatType>());
  }
}

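// Debug printer that assigns sequential ids to type nodes in visitation order
// and renders each node as a line such as "T2 = struct {T0, T1, }".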
class TypePrinter : public TypeVisitor {
 private:
  std::string result_;
  STD140LayoutContext layout_ctx_;

  uint32_t head_{0};
  std::unordered_map<const tinyir::Type *, uint32_t> idmap_;

  uint32_t get_id(const tinyir::Type *type) {
    if (idmap_.find(type) == idmap_.end()) {
      uint32_t id = head_++;
      idmap_[type] = id;
      return id;
    } else {
      return idmap_[type];
    }
  }

 public:
  void visit_int_type(const IntType *type) override {
    result_ += fmt::format("T{} = {}int{}_t\n", get_id(type),
                           type->is_signed() ? "" : "u", type->num_bits());
  }

  void visit_float_type(const FloatType *type) override {
    result_ +=
        fmt::format("T{} = float{}_t\n", get_id(type), type->num_bits());
  }

  void visit_physical_pointer_type(const PhysicalPointerType *type) override {
    result_ += fmt::format("T{} = T{} *\n", get_id(type),
                           get_id(type->get_pointed_type()));
  }

  void visit_struct_type(const StructType *type) override {
    result_ += fmt::format("T{} = struct {{", get_id(type));
    for (int i = 0; i < type->get_num_elements(); i++) {
      result_ += fmt::format("T{}, ", get_id(type->nth_element_type(i)));
    }
    // Plain string append, so a single brace closes the "{{" emitted above.
    result_ += "}\n";
  }

  void visit_small_vector_type(const SmallVectorType *type) override {
    result_ += fmt::format("T{} = small_vector<T{}, {}>\n", get_id(type),
                           get_id(type->element_type()),
                           type->get_constant_shape()[0]);
  }

  void visit_array_type(const ArrayType *type) override {
    result_ += fmt::format("T{} = array<T{}, {}>\n", get_id(type),
                           get_id(type->element_type()),
                           type->get_constant_shape()[0]);
  }

  static std::string print_types(const tinyir::Block *block) {
    TypePrinter p;
    p.visit(block);
    return p.result_;
  }
};

std::string ir_print_types(const tinyir::Block *block) {
  return TypePrinter::print_types(block);
}

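// Deduplicates structurally-equal types: every distinct type in the visited
// block is copied exactly once into a fresh block, and oldptr2newptr records
// the mapping from original nodes to their canonical copies.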
class TypeReducer : public TypeVisitor {
 public:
  std::unique_ptr<tinyir::Block> copy{nullptr};
  std::unordered_map<const tinyir::Type *, const tinyir::Type *> &oldptr2newptr;

  explicit TypeReducer(
      std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new)
      : oldptr2newptr(old2new) {
    copy = std::make_unique<tinyir::Block>();
    old2new.clear();
  }

  const tinyir::Type *check_type(const tinyir::Type *type) {
    if (oldptr2newptr.find(type) != oldptr2newptr.end()) {
      return oldptr2newptr[type];
    }
    for (const auto &t : copy->nodes()) {
      if (t->equals(type)) {
        oldptr2newptr[type] = (const tinyir::Type *)t.get();
        return (const tinyir::Type *)t.get();
      }
    }
    return nullptr;
  }

  void visit_int_type(const IntType *type) override {
    if (!check_type(type)) {
      oldptr2newptr[type] = copy->emplace_back<IntType>(*type);
    }
  }

  void visit_float_type(const FloatType *type) override {
    if (!check_type(type)) {
      oldptr2newptr[type] = copy->emplace_back<FloatType>(*type);
    }
  }

  void visit_physical_pointer_type(const PhysicalPointerType *type) override {
    if (!check_type(type)) {
      const tinyir::Type *pointed = check_type(type->get_pointed_type());
      TI_ASSERT(pointed);
      oldptr2newptr[type] = copy->emplace_back<PhysicalPointerType>(pointed);
    }
  }

  void visit_struct_type(const StructType *type) override {
    if (!check_type(type)) {
      std::vector<const tinyir::Type *> elements;
      for (int i = 0; i < type->get_num_elements(); i++) {
        const tinyir::Type *elm = check_type(type->nth_element_type(i));
        TI_ASSERT(elm);
        elements.push_back(elm);
      }
      oldptr2newptr[type] = copy->emplace_back<StructType>(elements);
    }
  }

  void visit_small_vector_type(const SmallVectorType *type) override {
    if (!check_type(type)) {
      const tinyir::Type *element = check_type(type->element_type());
      TI_ASSERT(element);
      oldptr2newptr[type] = copy->emplace_back<SmallVectorType>(
          element, type->get_constant_shape()[0]);
    }
  }

  void visit_array_type(const ArrayType *type) override {
    if (!check_type(type)) {
      const tinyir::Type *element = check_type(type->element_type());
      TI_ASSERT(element);
      oldptr2newptr[type] =
          copy->emplace_back<ArrayType>(element, type->get_constant_shape()[0]);
    }
  }
};

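// A minimal usage sketch (the surrounding identifiers are hypothetical):
//   std::unordered_map<const tinyir::Type *, const tinyir::Type *> old2new;
//   std::unique_ptr<tinyir::Block> reduced =
//       ir_reduce_types(raw_block.get(), old2new);
// After reduction, old2new maps every node of `raw_block` to its canonical
// copy inside `reduced`.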
std::unique_ptr<tinyir::Block> ir_reduce_types(
    tinyir::Block *blk,
    std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new) {
  TypeReducer reducer(old2new);
  reducer.visit(blk);
  return std::move(reducer.copy);
}

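// Lowers tinyir type nodes to SPIR-V type declarations through the IRBuilder,
// recording the resulting SPIR-V ids in ir_node_2_spv_value. Aggregate types
// look up the ids of their element types from that map, so element types are
// expected to appear before their users in the visited block.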
class Translate2Spirv : public TypeVisitor {
 private:
  IRBuilder *spir_builder_{nullptr};
  tinyir::LayoutContext &layout_context_;

 public:
  std::unordered_map<const tinyir::Node *, uint32_t> ir_node_2_spv_value;

  Translate2Spirv(IRBuilder *spir_builder,
                  tinyir::LayoutContext &layout_context)
      : spir_builder_(spir_builder), layout_context_(layout_context) {
  }

  void visit_int_type(const IntType *type) override {
    SType vt;
    if (type->is_signed()) {
      if (type->num_bits() == 8) {
        vt = spir_builder_->i8_type();
      } else if (type->num_bits() == 16) {
        vt = spir_builder_->i16_type();
      } else if (type->num_bits() == 32) {
        vt = spir_builder_->i32_type();
      } else if (type->num_bits() == 64) {
        vt = spir_builder_->i64_type();
      }
    } else {
      if (type->num_bits() == 8) {
        vt = spir_builder_->u8_type();
      } else if (type->num_bits() == 16) {
        vt = spir_builder_->u16_type();
      } else if (type->num_bits() == 32) {
        vt = spir_builder_->u32_type();
      } else if (type->num_bits() == 64) {
        vt = spir_builder_->u64_type();
      }
    }
    ir_node_2_spv_value[type] = vt.id;
  }

  void visit_float_type(const FloatType *type) override {
    SType vt;
    if (type->num_bits() == 16) {
      vt = spir_builder_->f16_type();
    } else if (type->num_bits() == 32) {
      vt = spir_builder_->f32_type();
    } else if (type->num_bits() == 64) {
      vt = spir_builder_->f64_type();
    }
    ir_node_2_spv_value[type] = vt.id;
  }

  void visit_physical_pointer_type(const PhysicalPointerType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(
        spv::OpTypePointer, vt, spv::StorageClassPhysicalStorageBuffer,
        ir_node_2_spv_value[type->get_pointed_type()]);
    ir_node_2_spv_value[type] = vt.id;
  }

  void visit_struct_type(const StructType *type) override {
    std::vector<uint32_t> element_ids;
    for (int i = 0; i < type->get_num_elements(); i++) {
      element_ids.push_back(ir_node_2_spv_value[type->nth_element_type(i)]);
    }
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeStruct, vt, element_ids);
    ir_node_2_spv_value[type] = vt.id;
    for (int i = 0; i < type->get_num_elements(); i++) {
      spir_builder_->decorate(spv::OpMemberDecorate, vt, i,
                              spv::DecorationOffset,
                              type->nth_element_offset(i, layout_context_));
    }
  }

  void visit_small_vector_type(const SmallVectorType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeVector, vt,
                                  ir_node_2_spv_value[type->element_type()],
                                  type->get_constant_shape()[0]);
    ir_node_2_spv_value[type] = vt.id;
  }

  void visit_array_type(const ArrayType *type) override {
    SType vt = spir_builder_->get_null_type();
    spir_builder_->declare_global(spv::OpTypeArray, vt,
                                  ir_node_2_spv_value[type->element_type()],
                                  type->get_constant_shape()[0]);
    ir_node_2_spv_value[type] = vt.id;
    spir_builder_->decorate(spv::OpDecorate, vt, spv::DecorationArrayStride,
                            type->memory_alignment_size(layout_context_));
  }
};

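// Entry point used after type reduction: visits the block and returns the
// tinyir-node-to-SPIR-V-id map, with struct member offsets and array strides
// decorated according to layout_ctx.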
std::unordered_map<const tinyir::Node *, uint32_t> ir_translate_to_spirv(
    const tinyir::Block *blk,
    tinyir::LayoutContext &layout_ctx,
    IRBuilder *spir_builder) {
  Translate2Spirv translator(spir_builder, layout_ctx);
  translator.visit(blk);
  return std::move(translator.ir_node_2_spv_value);
}

}  // namespace spirv
}  // namespace taichi::lang