1// Codegen for the hierarchical data structure
2#pragma once
3
4#include <unordered_map>
5
6#include "taichi/ir/snode.h"
7
8#include "spirv_types.h"
9
10namespace taichi::lang {
11namespace spirv {
12
13struct SNodeDescriptor {
14 const SNode *snode = nullptr;
15 // Stride (bytes) of a single cell.
16 size_t cell_stride = 0;
17
18 // Bytes of a single container.
19 size_t container_stride = 0;
20
21 // Total number of CELLS of this SNode
22 // For example, for a layout of
23 // ti.root
24 // .dense(ti.ij, (3, 2)) // S1
25 // .dense(ti.ij, (5, 3)) // S2
26 // |total_num_cells_from_root| for S2 is 3x2x5x3 = 90. That is, S2 has a total
27 // of 90 cells. Note that the number of S2 (container) itself is 3x2=6!
28 size_t total_num_cells_from_root = 0;
29 // An SNode can have multiple number of components, where each component
30 // starts at a fixed offset in its parent cell's memory.
31 size_t mem_offset_in_parent_cell = 0;
32
33 SNode *get_child(int ch_i) const {
34 return snode->ch[ch_i].get();
35 }
36};
37
// Maps an SNode ID to the SNodeDescriptor computed for that SNode.
using SNodeDescriptorsMap = std::unordered_map<int, SNodeDescriptor>;
39
40struct CompiledSNodeStructs {
41 // Root buffer size in bytes.
42 size_t root_size{0};
43 // Root SNode
44 const SNode *root{nullptr};
45 // Map from SNode ID to its descriptor.
46 SNodeDescriptorsMap snode_descriptors;
47
48 // TODO: Use the new type compiler
49 // tinyir::Block *type_factory;
50 // const tinyir::Type *root_type;
51};
52
53CompiledSNodeStructs compile_snode_structs(SNode &root);
54
55} // namespace spirv
56} // namespace taichi::lang
57