1#include "taichi/runtime/gfx/snode_tree_manager.h"
2
3#include "taichi/runtime/gfx/runtime.h"
4
5namespace taichi::lang {
6namespace gfx {
7
8SNodeTreeManager::SNodeTreeManager(GfxRuntime *rtm) : runtime_(rtm) {
9}
10
11void SNodeTreeManager::materialize_snode_tree(SNodeTree *tree) {
12 auto *const root = tree->root();
13 CompiledSNodeStructs compiled_structs = compile_snode_structs(*root);
14 runtime_->add_root_buffer(compiled_structs.root_size);
15 compiled_snode_structs_.push_back(compiled_structs);
16}
17
18void SNodeTreeManager::destroy_snode_tree(SNodeTree *snode_tree) {
19 int root_id = -1;
20 for (int i = 0; i < compiled_snode_structs_.size(); ++i) {
21 if (compiled_snode_structs_[i].root == snode_tree->root()) {
22 root_id = i;
23 }
24 }
25 if (root_id == -1) {
26 TI_ERROR("the tree to be destroyed cannot be found");
27 }
28 runtime_->root_buffers_[root_id].reset();
29}
30
31size_t SNodeTreeManager::get_field_in_tree_offset(int tree_id,
32 const SNode *child) {
33 auto &snode_struct = compiled_snode_structs_[tree_id];
34 TI_ASSERT_INFO(
35 snode_struct.snode_descriptors.find(child->id) !=
36 snode_struct.snode_descriptors.end() &&
37 snode_struct.snode_descriptors.at(child->id).snode == child,
38 "Requested SNode not found in compiled SNodeTree");
39
40 size_t offset = 0;
41 for (const SNode *sn = child; sn; sn = sn->parent) {
42 offset +=
43 snode_struct.snode_descriptors.at(sn->id).mem_offset_in_parent_cell;
44 }
45
46 return offset;
47}
48
49DevicePtr SNodeTreeManager::get_snode_tree_device_ptr(int tree_id) {
50 return runtime_->root_buffers_[tree_id]->get_ptr();
51}
52
53} // namespace gfx
54} // namespace taichi::lang
55