1#pragma once
2#include "taichi/aot/module_loader.h"
3#include "taichi/codegen/spirv/spirv_codegen.h"
4#include "taichi/codegen/spirv/snode_struct_compiler.h"
5#include "taichi/codegen/spirv/kernel_utils.h"
6
7#include "taichi/rhi/vulkan/vulkan_device_creator.h"
8#include "taichi/rhi/vulkan/vulkan_utils.h"
9#include "taichi/rhi/vulkan/vulkan_loader.h"
10#include "taichi/runtime/gfx/runtime.h"
11#include "taichi/runtime/gfx/snode_tree_manager.h"
12#include "taichi/cache/gfx/cache_manager.h"
13#include "taichi/rhi/vulkan/vulkan_device.h"
14#include "vk_mem_alloc.h"
15
16#include "taichi/system/memory_pool.h"
17#include "taichi/common/logging.h"
18#include "taichi/struct/snode_tree.h"
19#include "taichi/program/snode_expr_utils.h"
20#include "taichi/program/program_impl.h"
21#include "taichi/program/program.h"
22
23#include <optional>
24
25namespace taichi::lang {
26
// Forward declaration of the Vulkan device creator owned (via unique_ptr) by
// VulkanProgramImpl; presumably the full definition comes from
// vulkan_device_creator.h, which this header includes.
namespace vulkan { class VulkanDeviceCreator; }
30
31class VulkanProgramImpl : public ProgramImpl {
32 public:
33 explicit VulkanProgramImpl(CompileConfig &config);
34 FunctionType compile(const CompileConfig &compile_config,
35 Kernel *kernel) override;
36
37 std::size_t get_snode_num_dynamically_allocated(
38 SNode *snode,
39 uint64 *result_buffer) override {
40 return 0; // TODO: support sparse
41 }
42
43 void compile_snode_tree_types(SNodeTree *tree) override;
44
45 void materialize_runtime(MemoryPool *memory_pool,
46 KernelProfilerBase *profiler,
47 uint64 **result_buffer_ptr) override;
48
49 void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;
50
51 void synchronize() override {
52 if (vulkan_runtime_) {
53 vulkan_runtime_->synchronize();
54 }
55 }
56
57 StreamSemaphore flush() override {
58 return vulkan_runtime_->flush();
59 }
60
61 std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
62 const DeviceCapabilityConfig &caps) override;
63
64 void destroy_snode_tree(SNodeTree *snode_tree) override {
65 TI_ASSERT(snode_tree_mgr_ != nullptr);
66 snode_tree_mgr_->destroy_snode_tree(snode_tree);
67 }
68
69 DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
70 uint64 *result_buffer) override;
71 bool used_in_kernel(DeviceAllocationId id) override {
72 return vulkan_runtime_->used_in_kernel(id);
73 }
74 DeviceAllocation allocate_texture(const ImageParams &params) override;
75
76 Device *get_compute_device() override {
77 if (embedded_device_) {
78 return embedded_device_->device();
79 }
80 return nullptr;
81 }
82
83 Device *get_graphics_device() override {
84 if (embedded_device_) {
85 return embedded_device_->device();
86 }
87 return nullptr;
88 }
89
90 size_t get_field_in_tree_offset(int tree_id, const SNode *child) override {
91 return snode_tree_mgr_->get_field_in_tree_offset(tree_id, child);
92 }
93
94 DevicePtr get_snode_tree_device_ptr(int tree_id) override {
95 return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
96 }
97
98 void enqueue_compute_op_lambda(
99 std::function<void(Device *device, CommandList *cmdlist)> op,
100 const std::vector<ComputeOpImageRef> &image_refs) override;
101
102 void dump_cache_data_to_disk() override;
103
104 const std::unique_ptr<gfx::CacheManager> &get_cache_manager();
105
106 ~VulkanProgramImpl() override;
107
108 private:
109 std::unique_ptr<vulkan::VulkanDeviceCreator> embedded_device_{nullptr};
110 std::unique_ptr<gfx::GfxRuntime> vulkan_runtime_{nullptr};
111 std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr};
112 std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_;
113 std::unique_ptr<gfx::CacheManager> cache_manager_{nullptr};
114};
115} // namespace taichi::lang
116