#include "taichi/codegen/spirv/snode_struct_compiler.h"

#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>
2
3namespace taichi::lang {
4namespace spirv {
5namespace {
6
7class StructCompiler {
8 public:
9 CompiledSNodeStructs run(SNode &root) {
10 TI_ASSERT(root.type == SNodeType::root);
11
12 CompiledSNodeStructs result;
13 result.root = &root;
14 result.root_size = compute_snode_size(&root);
15 result.snode_descriptors = std::move(snode_descriptors_);
16 /*
17 result.type_factory = new tinyir::Block;
18 result.root_type = construct(*result.type_factory, &root);
19 */
20 TI_TRACE("RootBuffer size={}", result.root_size);
21
22 /*
23 std::unique_ptr<tinyir::Block> b = ir_reduce_types(result.type_factory);
24
25 TI_WARN("Original types:\n{}", ir_print_types(result.type_factory));
26
27 TI_WARN("Reduced types:\n{}", ir_print_types(b.get()));
28 */
29
30 return result;
31 }
32
33 private:
34 const tinyir::Type *construct(tinyir::Block &ir_module, SNode *sn) {
35 const tinyir::Type *cell_type = nullptr;
36
37 if (sn->is_place()) {
38 // Each cell is a single Type
39 cell_type = translate_ti_primitive(ir_module, sn->dt);
40 } else {
41 // Each cell is a struct
42 std::vector<const tinyir::Type *> struct_elements;
43 for (auto &ch : sn->ch) {
44 const tinyir::Type *elem_type = construct(ir_module, ch.get());
45 struct_elements.push_back(elem_type);
46 }
47 tinyir::Type *st = ir_module.emplace_back<StructType>(struct_elements);
48 st->set_debug_name(
49 fmt::format("{}_{}", snode_type_name(sn->type), sn->get_name()));
50 cell_type = st;
51
52 if (sn->type == SNodeType::pointer) {
53 cell_type = ir_module.emplace_back<PhysicalPointerType>(cell_type);
54 }
55 }
56
57 if (sn->num_cells_per_container == 1 || sn->is_scalar()) {
58 return cell_type;
59 } else {
60 return ir_module.emplace_back<ArrayType>(cell_type,
61 sn->num_cells_per_container);
62 }
63 }
64
65 std::size_t compute_snode_size(SNode *sn) {
66 const bool is_place = sn->is_place();
67
68 SNodeDescriptor sn_desc;
69 sn_desc.snode = sn;
70 if (is_place) {
71 sn_desc.cell_stride = data_type_size(sn->dt);
72 sn_desc.container_stride = sn_desc.cell_stride;
73 } else {
74 // Sort by size, so that smaller subfields are placed first.
75 // This accelerates Nvidia's GLSL compiler, as the compiler tries to
76 // place all statically accessed fields
77 std::vector<std::pair<size_t, int>> element_strides;
78 int i = 0;
79 for (auto &ch : sn->ch) {
80 element_strides.push_back({compute_snode_size(ch.get()), i});
81 i += 1;
82 }
83 std::sort(
84 element_strides.begin(), element_strides.end(),
85 [](const std::pair<size_t, int> &a, const std::pair<size_t, int> &b) {
86 return a.first < b.first;
87 });
88
89 std::size_t cell_stride = 0;
90 for (auto &[snode_size, i] : element_strides) {
91 auto &ch = sn->ch[i];
92 auto child_offset = cell_stride;
93 auto *ch_snode = ch.get();
94 cell_stride += snode_size;
95 snode_descriptors_.find(ch_snode->id)
96 ->second.mem_offset_in_parent_cell = child_offset;
97 ch_snode->offset_bytes_in_parent_cell = child_offset;
98 }
99 sn_desc.cell_stride = cell_stride;
100
101 if (sn->type == SNodeType::bitmasked) {
102 size_t num_cells = sn_desc.snode->num_cells_per_container;
103 size_t bitmask_num_words =
104 num_cells % 32 == 0 ? (num_cells / 32) : (num_cells / 32 + 1);
105 sn_desc.container_stride =
106 cell_stride * num_cells + bitmask_num_words * 4;
107 } else {
108 sn_desc.container_stride =
109 cell_stride * sn_desc.snode->num_cells_per_container;
110 }
111 }
112
113 sn->cell_size_bytes = sn_desc.cell_stride;
114
115 sn_desc.total_num_cells_from_root = 1;
116 for (const auto &e : sn->extractors) {
117 // Note that the extractors are set in two places:
118 // 1. When a new SNode is first defined
119 // 2. StructCompiler::infer_snode_properties()
120 // The second step is the finalized result.
121 sn_desc.total_num_cells_from_root *= e.num_elements_from_root;
122 }
123
124 TI_TRACE("SNodeDescriptor");
125 TI_TRACE("* snode={}", sn_desc.snode->id);
126 TI_TRACE("* type={} (is_place={})", sn_desc.snode->node_type_name,
127 is_place);
128 TI_TRACE("* cell_stride={}", sn_desc.cell_stride);
129 TI_TRACE("* num_cells_per_container={}",
130 sn_desc.snode->num_cells_per_container);
131 TI_TRACE("* container_stride={}", sn_desc.container_stride);
132 TI_TRACE("* total_num_cells_from_root={}",
133 sn_desc.total_num_cells_from_root);
134 TI_TRACE("");
135
136 TI_ASSERT(snode_descriptors_.find(sn->id) == snode_descriptors_.end());
137 snode_descriptors_[sn->id] = sn_desc;
138 return sn_desc.container_stride;
139 }
140
141 SNodeDescriptorsMap snode_descriptors_;
142};
143
144} // namespace
145
146CompiledSNodeStructs compile_snode_structs(SNode &root) {
147 StructCompiler compiler;
148 return compiler.run(root);
149}
150
151} // namespace spirv
152} // namespace taichi::lang
153