1 | #pragma once |
2 | #include "taichi/aot/module_loader.h" |
3 | #include "taichi/codegen/spirv/spirv_codegen.h" |
4 | #include "taichi/codegen/spirv/snode_struct_compiler.h" |
5 | #include "taichi/codegen/spirv/kernel_utils.h" |
6 | |
7 | #include "taichi/rhi/metal/metal_device.h" |
8 | #include "taichi/runtime/gfx/runtime.h" |
9 | #include "taichi/runtime/gfx/snode_tree_manager.h" |
10 | #include "taichi/cache/gfx/cache_manager.h" |
11 | |
12 | #include "taichi/system/memory_pool.h" |
13 | #include "taichi/common/logging.h" |
14 | #include "taichi/struct/snode_tree.h" |
15 | #include "taichi/program/snode_expr_utils.h" |
16 | #include "taichi/program/program_impl.h" |
17 | #include "taichi/program/program.h" |
18 | |
19 | namespace taichi::lang { |
20 | |
21 | class MetalProgramImpl : public ProgramImpl { |
22 | public: |
23 | explicit MetalProgramImpl(CompileConfig &config); |
24 | FunctionType compile(const CompileConfig &compile_config, |
25 | Kernel *kernel) override; |
26 | |
27 | std::size_t get_snode_num_dynamically_allocated( |
28 | SNode *snode, |
29 | uint64 *result_buffer) override { |
30 | return 0; // TODO: support sparse |
31 | } |
32 | |
33 | void compile_snode_tree_types(SNodeTree *tree) override; |
34 | |
35 | void materialize_runtime(MemoryPool *memory_pool, |
36 | KernelProfilerBase *profiler, |
37 | uint64 **result_buffer_ptr) override; |
38 | |
39 | void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override; |
40 | |
41 | void synchronize() override { |
42 | if (gfx_runtime_) { |
43 | gfx_runtime_->synchronize(); |
44 | } |
45 | } |
46 | |
47 | StreamSemaphore flush() override { |
48 | return gfx_runtime_->flush(); |
49 | } |
50 | |
51 | std::unique_ptr<AotModuleBuilder> make_aot_module_builder( |
52 | const DeviceCapabilityConfig &caps) override; |
53 | |
54 | void destroy_snode_tree(SNodeTree *snode_tree) override { |
55 | TI_ASSERT(snode_tree_mgr_ != nullptr); |
56 | snode_tree_mgr_->destroy_snode_tree(snode_tree); |
57 | } |
58 | |
59 | DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, |
60 | uint64 *result_buffer) override; |
61 | DeviceAllocation allocate_texture(const ImageParams ¶ms) override; |
62 | |
63 | bool used_in_kernel(DeviceAllocationId id) override { |
64 | return gfx_runtime_->used_in_kernel(id); |
65 | } |
66 | |
67 | Device *get_compute_device() override { |
68 | if (embedded_device_) { |
69 | return embedded_device_.get(); |
70 | } |
71 | return nullptr; |
72 | } |
73 | |
74 | Device *get_graphics_device() override { |
75 | if (embedded_device_) { |
76 | return embedded_device_.get(); |
77 | } |
78 | return nullptr; |
79 | } |
80 | |
81 | size_t get_field_in_tree_offset(int tree_id, const SNode *child) override { |
82 | return snode_tree_mgr_->get_field_in_tree_offset(tree_id, child); |
83 | } |
84 | |
85 | DevicePtr get_snode_tree_device_ptr(int tree_id) override { |
86 | return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); |
87 | } |
88 | |
89 | void enqueue_compute_op_lambda( |
90 | std::function<void(Device *device, CommandList *cmdlist)> op, |
91 | const std::vector<ComputeOpImageRef> &image_refs) override; |
92 | |
93 | void dump_cache_data_to_disk() override; |
94 | |
95 | const std::unique_ptr<gfx::CacheManager> &get_cache_manager(); |
96 | |
97 | ~MetalProgramImpl() override; |
98 | |
99 | private: |
100 | std::unique_ptr<metal::MetalDevice> embedded_device_{nullptr}; |
101 | std::unique_ptr<gfx::GfxRuntime> gfx_runtime_{nullptr}; |
102 | std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr}; |
103 | std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_{}; |
104 | std::unique_ptr<gfx::CacheManager> cache_manager_{nullptr}; |
105 | }; |
106 | |
107 | } // namespace taichi::lang |
108 | |