1#pragma once
2#include "taichi/aot/module_loader.h"
3#include "taichi/codegen/spirv/spirv_codegen.h"
4#include "taichi/codegen/spirv/snode_struct_compiler.h"
5#include "taichi/codegen/spirv/kernel_utils.h"
6
7#include "taichi/rhi/metal/metal_device.h"
8#include "taichi/runtime/gfx/runtime.h"
9#include "taichi/runtime/gfx/snode_tree_manager.h"
10#include "taichi/cache/gfx/cache_manager.h"
11
12#include "taichi/system/memory_pool.h"
13#include "taichi/common/logging.h"
14#include "taichi/struct/snode_tree.h"
15#include "taichi/program/snode_expr_utils.h"
16#include "taichi/program/program_impl.h"
17#include "taichi/program/program.h"
18
19namespace taichi::lang {
20
21class MetalProgramImpl : public ProgramImpl {
22 public:
23 explicit MetalProgramImpl(CompileConfig &config);
24 FunctionType compile(const CompileConfig &compile_config,
25 Kernel *kernel) override;
26
27 std::size_t get_snode_num_dynamically_allocated(
28 SNode *snode,
29 uint64 *result_buffer) override {
30 return 0; // TODO: support sparse
31 }
32
33 void compile_snode_tree_types(SNodeTree *tree) override;
34
35 void materialize_runtime(MemoryPool *memory_pool,
36 KernelProfilerBase *profiler,
37 uint64 **result_buffer_ptr) override;
38
39 void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;
40
41 void synchronize() override {
42 if (gfx_runtime_) {
43 gfx_runtime_->synchronize();
44 }
45 }
46
47 StreamSemaphore flush() override {
48 return gfx_runtime_->flush();
49 }
50
51 std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
52 const DeviceCapabilityConfig &caps) override;
53
54 void destroy_snode_tree(SNodeTree *snode_tree) override {
55 TI_ASSERT(snode_tree_mgr_ != nullptr);
56 snode_tree_mgr_->destroy_snode_tree(snode_tree);
57 }
58
59 DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
60 uint64 *result_buffer) override;
61 DeviceAllocation allocate_texture(const ImageParams &params) override;
62
63 bool used_in_kernel(DeviceAllocationId id) override {
64 return gfx_runtime_->used_in_kernel(id);
65 }
66
67 Device *get_compute_device() override {
68 if (embedded_device_) {
69 return embedded_device_.get();
70 }
71 return nullptr;
72 }
73
74 Device *get_graphics_device() override {
75 if (embedded_device_) {
76 return embedded_device_.get();
77 }
78 return nullptr;
79 }
80
81 size_t get_field_in_tree_offset(int tree_id, const SNode *child) override {
82 return snode_tree_mgr_->get_field_in_tree_offset(tree_id, child);
83 }
84
85 DevicePtr get_snode_tree_device_ptr(int tree_id) override {
86 return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
87 }
88
89 void enqueue_compute_op_lambda(
90 std::function<void(Device *device, CommandList *cmdlist)> op,
91 const std::vector<ComputeOpImageRef> &image_refs) override;
92
93 void dump_cache_data_to_disk() override;
94
95 const std::unique_ptr<gfx::CacheManager> &get_cache_manager();
96
97 ~MetalProgramImpl() override;
98
99 private:
100 std::unique_ptr<metal::MetalDevice> embedded_device_{nullptr};
101 std::unique_ptr<gfx::GfxRuntime> gfx_runtime_{nullptr};
102 std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr};
103 std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_{};
104 std::unique_ptr<gfx::CacheManager> cache_manager_{nullptr};
105};
106
107} // namespace taichi::lang
108