1 | #include "taichi/codegen/spirv/snode_struct_compiler.h" |
2 | |
3 | namespace taichi::lang { |
4 | namespace spirv { |
5 | namespace { |
6 | |
7 | class StructCompiler { |
8 | public: |
9 | CompiledSNodeStructs run(SNode &root) { |
10 | TI_ASSERT(root.type == SNodeType::root); |
11 | |
12 | CompiledSNodeStructs result; |
13 | result.root = &root; |
14 | result.root_size = compute_snode_size(&root); |
15 | result.snode_descriptors = std::move(snode_descriptors_); |
16 | /* |
17 | result.type_factory = new tinyir::Block; |
18 | result.root_type = construct(*result.type_factory, &root); |
19 | */ |
20 | TI_TRACE("RootBuffer size={}" , result.root_size); |
21 | |
22 | /* |
23 | std::unique_ptr<tinyir::Block> b = ir_reduce_types(result.type_factory); |
24 | |
25 | TI_WARN("Original types:\n{}", ir_print_types(result.type_factory)); |
26 | |
27 | TI_WARN("Reduced types:\n{}", ir_print_types(b.get())); |
28 | */ |
29 | |
30 | return result; |
31 | } |
32 | |
33 | private: |
34 | const tinyir::Type *construct(tinyir::Block &ir_module, SNode *sn) { |
35 | const tinyir::Type *cell_type = nullptr; |
36 | |
37 | if (sn->is_place()) { |
38 | // Each cell is a single Type |
39 | cell_type = translate_ti_primitive(ir_module, sn->dt); |
40 | } else { |
41 | // Each cell is a struct |
42 | std::vector<const tinyir::Type *> struct_elements; |
43 | for (auto &ch : sn->ch) { |
44 | const tinyir::Type *elem_type = construct(ir_module, ch.get()); |
45 | struct_elements.push_back(elem_type); |
46 | } |
47 | tinyir::Type *st = ir_module.emplace_back<StructType>(struct_elements); |
48 | st->set_debug_name( |
49 | fmt::format("{}_{}" , snode_type_name(sn->type), sn->get_name())); |
50 | cell_type = st; |
51 | |
52 | if (sn->type == SNodeType::pointer) { |
53 | cell_type = ir_module.emplace_back<PhysicalPointerType>(cell_type); |
54 | } |
55 | } |
56 | |
57 | if (sn->num_cells_per_container == 1 || sn->is_scalar()) { |
58 | return cell_type; |
59 | } else { |
60 | return ir_module.emplace_back<ArrayType>(cell_type, |
61 | sn->num_cells_per_container); |
62 | } |
63 | } |
64 | |
65 | std::size_t compute_snode_size(SNode *sn) { |
66 | const bool is_place = sn->is_place(); |
67 | |
68 | SNodeDescriptor sn_desc; |
69 | sn_desc.snode = sn; |
70 | if (is_place) { |
71 | sn_desc.cell_stride = data_type_size(sn->dt); |
72 | sn_desc.container_stride = sn_desc.cell_stride; |
73 | } else { |
74 | // Sort by size, so that smaller subfields are placed first. |
75 | // This accelerates Nvidia's GLSL compiler, as the compiler tries to |
76 | // place all statically accessed fields |
77 | std::vector<std::pair<size_t, int>> element_strides; |
78 | int i = 0; |
79 | for (auto &ch : sn->ch) { |
80 | element_strides.push_back({compute_snode_size(ch.get()), i}); |
81 | i += 1; |
82 | } |
83 | std::sort( |
84 | element_strides.begin(), element_strides.end(), |
85 | [](const std::pair<size_t, int> &a, const std::pair<size_t, int> &b) { |
86 | return a.first < b.first; |
87 | }); |
88 | |
89 | std::size_t cell_stride = 0; |
90 | for (auto &[snode_size, i] : element_strides) { |
91 | auto &ch = sn->ch[i]; |
92 | auto child_offset = cell_stride; |
93 | auto *ch_snode = ch.get(); |
94 | cell_stride += snode_size; |
95 | snode_descriptors_.find(ch_snode->id) |
96 | ->second.mem_offset_in_parent_cell = child_offset; |
97 | ch_snode->offset_bytes_in_parent_cell = child_offset; |
98 | } |
99 | sn_desc.cell_stride = cell_stride; |
100 | |
101 | if (sn->type == SNodeType::bitmasked) { |
102 | size_t num_cells = sn_desc.snode->num_cells_per_container; |
103 | size_t bitmask_num_words = |
104 | num_cells % 32 == 0 ? (num_cells / 32) : (num_cells / 32 + 1); |
105 | sn_desc.container_stride = |
106 | cell_stride * num_cells + bitmask_num_words * 4; |
107 | } else { |
108 | sn_desc.container_stride = |
109 | cell_stride * sn_desc.snode->num_cells_per_container; |
110 | } |
111 | } |
112 | |
113 | sn->cell_size_bytes = sn_desc.cell_stride; |
114 | |
115 | sn_desc.total_num_cells_from_root = 1; |
116 | for (const auto &e : sn->extractors) { |
117 | // Note that the extractors are set in two places: |
118 | // 1. When a new SNode is first defined |
119 | // 2. StructCompiler::infer_snode_properties() |
120 | // The second step is the finalized result. |
121 | sn_desc.total_num_cells_from_root *= e.num_elements_from_root; |
122 | } |
123 | |
124 | TI_TRACE("SNodeDescriptor" ); |
125 | TI_TRACE("* snode={}" , sn_desc.snode->id); |
126 | TI_TRACE("* type={} (is_place={})" , sn_desc.snode->node_type_name, |
127 | is_place); |
128 | TI_TRACE("* cell_stride={}" , sn_desc.cell_stride); |
129 | TI_TRACE("* num_cells_per_container={}" , |
130 | sn_desc.snode->num_cells_per_container); |
131 | TI_TRACE("* container_stride={}" , sn_desc.container_stride); |
132 | TI_TRACE("* total_num_cells_from_root={}" , |
133 | sn_desc.total_num_cells_from_root); |
134 | TI_TRACE("" ); |
135 | |
136 | TI_ASSERT(snode_descriptors_.find(sn->id) == snode_descriptors_.end()); |
137 | snode_descriptors_[sn->id] = sn_desc; |
138 | return sn_desc.container_stride; |
139 | } |
140 | |
141 | SNodeDescriptorsMap snode_descriptors_; |
142 | }; |
143 | |
144 | } // namespace |
145 | |
146 | CompiledSNodeStructs compile_snode_structs(SNode &root) { |
147 | StructCompiler compiler; |
148 | return compiler.run(root); |
149 | } |
150 | |
151 | } // namespace spirv |
152 | } // namespace taichi::lang |
153 | |