1 | #pragma once |
2 | #include "taichi/aot/module_loader.h" |
3 | #include "taichi/codegen/spirv/spirv_codegen.h" |
4 | #include "taichi/codegen/spirv/snode_struct_compiler.h" |
5 | #include "taichi/codegen/spirv/kernel_utils.h" |
6 | |
7 | #include "taichi/rhi/vulkan/vulkan_device_creator.h" |
8 | #include "taichi/rhi/vulkan/vulkan_utils.h" |
9 | #include "taichi/rhi/vulkan/vulkan_loader.h" |
10 | #include "taichi/runtime/gfx/runtime.h" |
11 | #include "taichi/runtime/gfx/snode_tree_manager.h" |
12 | #include "taichi/cache/gfx/cache_manager.h" |
13 | #include "taichi/rhi/vulkan/vulkan_device.h" |
14 | #include "vk_mem_alloc.h" |
15 | |
16 | #include "taichi/system/memory_pool.h" |
17 | #include "taichi/common/logging.h" |
18 | #include "taichi/struct/snode_tree.h" |
19 | #include "taichi/program/snode_expr_utils.h" |
20 | #include "taichi/program/program_impl.h" |
21 | #include "taichi/program/program.h" |
22 | |
23 | #include <optional> |
24 | |
25 | namespace taichi::lang { |
26 | |
// Forward declaration of the device-creator type held by VulkanProgramImpl.
// The header that presumably defines it (vulkan_device_creator.h) is already
// included above, so this is redundant but harmless.
namespace vulkan { class VulkanDeviceCreator; }  // namespace vulkan
30 | |
31 | class VulkanProgramImpl : public ProgramImpl { |
32 | public: |
33 | explicit VulkanProgramImpl(CompileConfig &config); |
34 | FunctionType compile(const CompileConfig &compile_config, |
35 | Kernel *kernel) override; |
36 | |
37 | std::size_t get_snode_num_dynamically_allocated( |
38 | SNode *snode, |
39 | uint64 *result_buffer) override { |
40 | return 0; // TODO: support sparse |
41 | } |
42 | |
43 | void compile_snode_tree_types(SNodeTree *tree) override; |
44 | |
45 | void materialize_runtime(MemoryPool *memory_pool, |
46 | KernelProfilerBase *profiler, |
47 | uint64 **result_buffer_ptr) override; |
48 | |
49 | void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override; |
50 | |
51 | void synchronize() override { |
52 | if (vulkan_runtime_) { |
53 | vulkan_runtime_->synchronize(); |
54 | } |
55 | } |
56 | |
57 | StreamSemaphore flush() override { |
58 | return vulkan_runtime_->flush(); |
59 | } |
60 | |
61 | std::unique_ptr<AotModuleBuilder> make_aot_module_builder( |
62 | const DeviceCapabilityConfig &caps) override; |
63 | |
64 | void destroy_snode_tree(SNodeTree *snode_tree) override { |
65 | TI_ASSERT(snode_tree_mgr_ != nullptr); |
66 | snode_tree_mgr_->destroy_snode_tree(snode_tree); |
67 | } |
68 | |
69 | DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, |
70 | uint64 *result_buffer) override; |
71 | bool used_in_kernel(DeviceAllocationId id) override { |
72 | return vulkan_runtime_->used_in_kernel(id); |
73 | } |
74 | DeviceAllocation allocate_texture(const ImageParams ¶ms) override; |
75 | |
76 | Device *get_compute_device() override { |
77 | if (embedded_device_) { |
78 | return embedded_device_->device(); |
79 | } |
80 | return nullptr; |
81 | } |
82 | |
83 | Device *get_graphics_device() override { |
84 | if (embedded_device_) { |
85 | return embedded_device_->device(); |
86 | } |
87 | return nullptr; |
88 | } |
89 | |
90 | size_t get_field_in_tree_offset(int tree_id, const SNode *child) override { |
91 | return snode_tree_mgr_->get_field_in_tree_offset(tree_id, child); |
92 | } |
93 | |
94 | DevicePtr get_snode_tree_device_ptr(int tree_id) override { |
95 | return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); |
96 | } |
97 | |
98 | void enqueue_compute_op_lambda( |
99 | std::function<void(Device *device, CommandList *cmdlist)> op, |
100 | const std::vector<ComputeOpImageRef> &image_refs) override; |
101 | |
102 | void dump_cache_data_to_disk() override; |
103 | |
104 | const std::unique_ptr<gfx::CacheManager> &get_cache_manager(); |
105 | |
106 | ~VulkanProgramImpl() override; |
107 | |
108 | private: |
109 | std::unique_ptr<vulkan::VulkanDeviceCreator> embedded_device_{nullptr}; |
110 | std::unique_ptr<gfx::GfxRuntime> vulkan_runtime_{nullptr}; |
111 | std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr}; |
112 | std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_; |
113 | std::unique_ptr<gfx::CacheManager> cache_manager_{nullptr}; |
114 | }; |
115 | } // namespace taichi::lang |
116 | |