1#pragma once
2
3#include "taichi/rhi/vulkan/vulkan_common.h"
4
5#include <vk_mem_alloc.h>
6
7#include <memory>
8#include <vector>
9#include <stack>
10#include <unordered_map>
11
12namespace vkapi {
13
14struct DeviceObj {
15 VkDevice device{VK_NULL_HANDLE};
16 virtual ~DeviceObj() = default;
17};
18using IDeviceObj = std::shared_ptr<DeviceObj>;
19IDeviceObj create_device_obj(VkDevice device);
20
21// VkSemaphore
22struct DeviceObjVkSemaphore : public DeviceObj {
23 VkSemaphore semaphore{VK_NULL_HANDLE};
24 ~DeviceObjVkSemaphore() override;
25};
26using IVkSemaphore = std::shared_ptr<DeviceObjVkSemaphore>;
27IVkSemaphore create_semaphore(VkDevice device,
28 VkSemaphoreCreateFlags flags,
29 void *pnext = nullptr);
30
31// VkFence
32struct DeviceObjVkFence : public DeviceObj {
33 VkFence fence{VK_NULL_HANDLE};
34 ~DeviceObjVkFence() override;
35};
36using IVkFence = std::shared_ptr<DeviceObjVkFence>;
37IVkFence create_fence(VkDevice device,
38 VkFenceCreateFlags flags,
39 void *pnext = nullptr);
40
41// VkDescriptorSetLayout
42struct DeviceObjVkDescriptorSetLayout : public DeviceObj {
43 VkDescriptorSetLayout layout{VK_NULL_HANDLE};
44 ~DeviceObjVkDescriptorSetLayout() override;
45};
46using IVkDescriptorSetLayout = std::shared_ptr<DeviceObjVkDescriptorSetLayout>;
47IVkDescriptorSetLayout create_descriptor_set_layout(
48 VkDevice device,
49 VkDescriptorSetLayoutCreateInfo *create_info);
50
51// VkDescriptorPool
52struct DeviceObjVkDescriptorPool : public DeviceObj {
53 VkDescriptorPool pool{VK_NULL_HANDLE};
54 // Can recycling of this actually be trivial?
55 // std::unordered_multimap<VkDescriptorSetLayout, VkDescriptorSet> free_list;
56 ~DeviceObjVkDescriptorPool() override;
57};
58using IVkDescriptorPool = std::shared_ptr<DeviceObjVkDescriptorPool>;
59IVkDescriptorPool create_descriptor_pool(
60 VkDevice device,
61 VkDescriptorPoolCreateInfo *create_info);
62
63// VkDescriptorSet
64struct DeviceObjVkDescriptorSet : public DeviceObj {
65 VkDescriptorSet set{VK_NULL_HANDLE};
66 IVkDescriptorSetLayout ref_layout{nullptr};
67 IVkDescriptorPool ref_pool{nullptr};
68 std::vector<IDeviceObj> ref_binding_objs;
69 ~DeviceObjVkDescriptorSet() override;
70};
71using IVkDescriptorSet = std::shared_ptr<DeviceObjVkDescriptorSet>;
72// Returns nullptr is pool is full
73IVkDescriptorSet allocate_descriptor_sets(IVkDescriptorPool pool,
74 IVkDescriptorSetLayout layout,
75 void *pnext = nullptr);
76
77// VkCommandPool
78struct DeviceObjVkCommandPool : public DeviceObj {
79 VkCommandPool pool{VK_NULL_HANDLE};
80 uint32_t queue_family_index{0};
81 std::stack<VkCommandBuffer> free_primary;
82 std::stack<VkCommandBuffer> free_secondary;
83 ~DeviceObjVkCommandPool() override;
84};
85using IVkCommandPool = std::shared_ptr<DeviceObjVkCommandPool>;
86IVkCommandPool create_command_pool(VkDevice device,
87 VkCommandPoolCreateFlags flags,
88 uint32_t queue_family_index);
89
90// VkCommandBuffer
91// Should keep track of used objects in the ref_pool
92struct DeviceObjVkCommandBuffer : public DeviceObj {
93 VkCommandBuffer buffer{VK_NULL_HANDLE};
94 VkCommandBufferLevel level{VK_COMMAND_BUFFER_LEVEL_PRIMARY};
95 IVkCommandPool ref_pool{nullptr};
96 std::vector<IDeviceObj> refs;
97 ~DeviceObjVkCommandBuffer() override;
98};
99using IVkCommandBuffer = std::shared_ptr<DeviceObjVkCommandBuffer>;
100IVkCommandBuffer allocate_command_buffer(
101 IVkCommandPool pool,
102 VkCommandBufferLevel level = VK_COMMAND_BUFFER_LEVEL_PRIMARY);
103
104// VkRenderPass
105struct DeviceObjVkRenderPass : public DeviceObj {
106 VkRenderPass renderpass{VK_NULL_HANDLE};
107 ~DeviceObjVkRenderPass() override;
108};
109using IVkRenderPass = std::shared_ptr<DeviceObjVkRenderPass>;
110IVkRenderPass create_render_pass(VkDevice device,
111 VkRenderPassCreateInfo *create_info);
112
113// VkPipelineLayout
114struct DeviceObjVkPipelineLayout : public DeviceObj {
115 VkPipelineLayout layout{VK_NULL_HANDLE};
116 std::vector<IVkDescriptorSetLayout> ref_desc_layouts;
117 ~DeviceObjVkPipelineLayout() override;
118};
119using IVkPipelineLayout = std::shared_ptr<DeviceObjVkPipelineLayout>;
120IVkPipelineLayout create_pipeline_layout(
121 VkDevice device,
122 std::vector<IVkDescriptorSetLayout> &set_layouts,
123 uint32_t push_constant_range_count = 0,
124 VkPushConstantRange *push_constant_ranges = nullptr);
125
126// VkPipelineCache
127struct DeviceObjVkPipelineCache : public DeviceObj {
128 VkPipelineCache cache{VK_NULL_HANDLE};
129 ~DeviceObjVkPipelineCache() override;
130};
131using IVkPipelineCache = std::shared_ptr<DeviceObjVkPipelineCache>;
132IVkPipelineCache create_pipeline_cache(VkDevice device,
133 VkPipelineCacheCreateFlags flags,
134 size_t initial_size = 0,
135 const void *initial_data = nullptr);
136
137// VkPipeline
138struct DeviceObjVkPipeline : public DeviceObj {
139 VkPipeline pipeline{VK_NULL_HANDLE};
140 IVkPipelineLayout ref_layout{nullptr};
141 IVkRenderPass ref_renderpass{nullptr};
142 IVkPipelineCache ref_cache{nullptr};
143 std::vector<std::shared_ptr<DeviceObjVkPipeline>> ref_pipeline_libraries;
144 ~DeviceObjVkPipeline() override;
145};
146using IVkPipeline = std::shared_ptr<DeviceObjVkPipeline>;
147IVkPipeline create_compute_pipeline(VkDevice device,
148 VkPipelineCreateFlags flags,
149 VkPipelineShaderStageCreateInfo &stage,
150 IVkPipelineLayout layout,
151 IVkPipelineCache cache = nullptr,
152 IVkPipeline base_pipeline = nullptr);
153IVkPipeline create_graphics_pipeline(VkDevice device,
154 VkGraphicsPipelineCreateInfo *create_info,
155 IVkRenderPass renderpass,
156 IVkPipelineLayout layout,
157 IVkPipelineCache cache = nullptr,
158 IVkPipeline base_pipeline = nullptr);
159IVkPipeline create_graphics_pipeline_dynamic(
160 VkDevice device,
161 VkGraphicsPipelineCreateInfo *create_info,
162 VkPipelineRenderingCreateInfoKHR *rendering_info,
163 IVkPipelineLayout layout,
164 IVkPipelineCache cache = nullptr,
165 IVkPipeline base_pipeline = nullptr);
166IVkPipeline create_raytracing_pipeline(
167 VkDevice device,
168 VkRayTracingPipelineCreateInfoKHR *create_info,
169 IVkPipelineLayout layout,
170 std::vector<IVkPipeline> &pipeline_libraries,
171 VkDeferredOperationKHR deferredOperation = VK_NULL_HANDLE,
172 IVkPipelineCache cache = nullptr,
173 IVkPipeline base_pipeline = nullptr);
174
175// VkSampler
176struct DeviceObjVkSampler : public DeviceObj {
177 VkSampler sampler{VK_NULL_HANDLE};
178 ~DeviceObjVkSampler() override;
179};
180using IVkSampler = std::shared_ptr<DeviceObjVkSampler>;
181IVkSampler create_sampler(VkDevice device, const VkSamplerCreateInfo &info);
182
183// VkImage
184struct DeviceObjVkImage : public DeviceObj {
185 VkImage image{VK_NULL_HANDLE};
186 VkFormat format{VK_FORMAT_UNDEFINED};
187 VkImageType type{VK_IMAGE_TYPE_2D};
188 uint32_t width{1};
189 uint32_t height{1};
190 uint32_t depth{1};
191 uint32_t mip_levels{1};
192 uint32_t array_layers{1};
193 VkImageUsageFlags usage{0};
194 VmaAllocator allocator{nullptr};
195 VmaAllocation allocation{nullptr};
196 ~DeviceObjVkImage() override;
197};
198using IVkImage = std::shared_ptr<DeviceObjVkImage>;
199// Allocate image
200IVkImage create_image(VkDevice device,
201 VmaAllocator allocator,
202 VkImageCreateInfo *image_info,
203 VmaAllocationCreateInfo *alloc_info);
204// Importing external image
205IVkImage create_image(VkDevice device,
206 VkImage image,
207 VkFormat format,
208 VkImageType type,
209 VkExtent3D extent,
210 uint32_t mip_levels,
211 uint32_t array_layers,
212 VkImageUsageFlags usage);
213
214// VkImageView
215struct DeviceObjVkImageView : public DeviceObj {
216 VkImageView view{VK_NULL_HANDLE};
217 VkImageViewType type{VK_IMAGE_VIEW_TYPE_2D};
218 VkImageSubresourceRange subresource_range{
219 VK_IMAGE_ASPECT_COLOR_BIT | VK_IMAGE_ASPECT_DEPTH_BIT, 0, 1, 0, 1};
220 IVkImage ref_image{nullptr};
221 ~DeviceObjVkImageView() override;
222};
223using IVkImageView = std::shared_ptr<DeviceObjVkImageView>;
224IVkImageView create_image_view(VkDevice device,
225 IVkImage image,
226 VkImageViewCreateInfo *create_info);
227
228// VkFramebuffer
229struct DeviceObjVkFramebuffer : public DeviceObj {
230 VkFramebuffer framebuffer{VK_NULL_HANDLE};
231 uint32_t width{0};
232 uint32_t height{0};
233 uint32_t layers{1};
234 std::vector<IVkImageView> ref_attachments;
235 IVkRenderPass ref_renderpass{nullptr};
236 ~DeviceObjVkFramebuffer() override;
237};
238using IVkFramebuffer = std::shared_ptr<DeviceObjVkFramebuffer>;
239IVkFramebuffer create_framebuffer(VkFramebufferCreateFlags flags,
240 IVkRenderPass renderpass,
241 const std::vector<IVkImageView> &attachments,
242 uint32_t width,
243 uint32_t height,
244 uint32_t layers = 1,
245 void *pnext = nullptr);
246
247// VkBuffer
248struct DeviceObjVkBuffer : public DeviceObj {
249 VkBuffer buffer{VK_NULL_HANDLE};
250 VkBufferUsageFlags usage{0};
251 VmaAllocator allocator{nullptr};
252 VmaAllocation allocation{nullptr};
253 ~DeviceObjVkBuffer() override;
254};
255using IVkBuffer = std::shared_ptr<DeviceObjVkBuffer>;
256// Allocate buffer
257IVkBuffer create_buffer(VkDevice device,
258 VmaAllocator allocator,
259 VkBufferCreateInfo *buffer_info,
260 VmaAllocationCreateInfo *alloc_info);
261// Importing external buffer
262IVkBuffer create_buffer(VkDevice device,
263 VkBuffer buffer,
264 VkBufferUsageFlags usage);
265
266// VkBufferView
267struct DeviceObjVkBufferView : public DeviceObj {
268 VkBufferView view{VK_NULL_HANDLE};
269 VkFormat format{VK_FORMAT_UNDEFINED};
270 VkDeviceSize offset{0};
271 VkDeviceSize range{0};
272 IVkBuffer ref_buffer{nullptr};
273 ~DeviceObjVkBufferView() override;
274};
275using IVkBufferView = std::shared_ptr<DeviceObjVkBufferView>;
276IVkBufferView create_buffer_view(IVkBuffer buffer,
277 VkBufferViewCreateFlags flags,
278 VkFormat format,
279 VkDeviceSize offset,
280 VkDeviceSize range);
281
282// VkAccelerationStructureKHR
283struct DeviceObjVkAccelerationStructureKHR : public DeviceObj {
284 VkAccelerationStructureKHR accel{VK_NULL_HANDLE};
285 VkAccelerationStructureTypeKHR type{
286 VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR};
287 VkDeviceSize offset{0};
288 VkDeviceSize size{0};
289 IVkBuffer ref_buffer{nullptr};
290 ~DeviceObjVkAccelerationStructureKHR() override;
291};
292using IVkAccelerationStructureKHR =
293 std::shared_ptr<DeviceObjVkAccelerationStructureKHR>;
294IVkAccelerationStructureKHR create_acceleration_structure(
295 VkAccelerationStructureCreateFlagsKHR flags,
296 IVkBuffer buffer,
297 VkDeviceSize offset,
298 VkDeviceSize size,
299 VkAccelerationStructureTypeKHR type);
300
301// VkQueryPool
302struct DeviceObjVkQueryPool : public DeviceObj {
303 VkQueryPool query_pool{VK_NULL_HANDLE};
304 ~DeviceObjVkQueryPool() override;
305};
306using IVkQueryPool = std::shared_ptr<DeviceObjVkQueryPool>;
307IVkQueryPool create_query_pool(VkDevice device);
308
309} // namespace vkapi
310