#include "taichi/rhi/vulkan/vulkan_device_creator.h"

#include <algorithm>
#include <array>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

#include "taichi/rhi/vulkan/vulkan_common.h"
#include "taichi/rhi/vulkan/vulkan_loader.h"
#include "taichi/rhi/vulkan/vulkan_device.h"
#include "taichi/common/utils.h"
13
14namespace taichi::lang {
15namespace vulkan {
16
17namespace {
18
// Layers enabled on both the instance and the device whenever validation is
// requested (and found to be available).
const std::vector<const char *> kValidationLayers{
    "VK_LAYER_KHRONOS_validation"};
22
23bool check_validation_layer_support() {
24 uint32_t layer_count;
25 vkEnumerateInstanceLayerProperties(&layer_count, nullptr);
26
27 std::vector<VkLayerProperties> available_layers(layer_count);
28 vkEnumerateInstanceLayerProperties(&layer_count, available_layers.data());
29
30 std::unordered_set<std::string> available_layer_names;
31 for (const auto &layer_props : available_layers) {
32 available_layer_names.insert(layer_props.layerName);
33 }
34 for (const char *name : kValidationLayers) {
35 if (available_layer_names.count(std::string(name)) == 0) {
36 return false;
37 }
38 }
39 return true;
40}
41
// Returns true for validation message IDs that must not be escalated to an
// error when running on CI.
[[maybe_unused]] bool vk_ignore_validation_warning(
    const std::string &msg_name) {
  // Truncated Debug Printf messages are reported under this ID.
  const bool is_debug_printf = (msg_name == "UNASSIGNED-DEBUG-PRINTF");
  // FIXME: Remove this case after upgrading Vulkan driver for built bots
  const bool is_undefined_vuid = (msg_name == "VUID_Undefined");
  return is_debug_printf || is_undefined_vuid;
}
56
57VKAPI_ATTR VkBool32 VKAPI_CALL
58vk_debug_callback(VkDebugUtilsMessageSeverityFlagBitsEXT message_severity,
59 VkDebugUtilsMessageTypeFlagsEXT message_type,
60 const VkDebugUtilsMessengerCallbackDataEXT *p_callback_data,
61 void *p_user_data) {
62 if (message_type == VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT &&
63 message_severity == VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT &&
64 strstr(p_callback_data->pMessage, "DEBUG-PRINTF") != nullptr) {
65 // Message format is "BLABLA | MessageID=xxxxx | <DEBUG_PRINT_MSG>"
66 std::string msg(p_callback_data->pMessage);
67 auto const pos = msg.find_last_of("|");
68 std::cout << msg.substr(pos + 2);
69 }
70
71 if (message_severity > VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT) {
72 char msg_buf[4096];
73 snprintf(msg_buf, sizeof(msg_buf), "Vulkan validation layer: %d, %s",
74 message_type, p_callback_data->pMessage);
75
76 if (is_ci()) {
77 auto msg_name = std::string(p_callback_data->pMessageIdName);
78 if (!vk_ignore_validation_warning(msg_name))
79 TI_ERROR(msg_buf);
80 } else {
81 RHI_LOG_ERROR(msg_buf);
82 }
83 }
84
85 return VK_FALSE;
86}
87
88void populate_debug_messenger_create_info(
89 VkDebugUtilsMessengerCreateInfoEXT *create_info) {
90 *create_info = {};
91 create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
92 create_info->messageSeverity =
93 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT |
94 VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
95 VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT |
96 VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
97 create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT |
98 VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT |
99 VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
100 create_info->pfnUserCallback = vk_debug_callback;
101 create_info->pUserData = nullptr;
102}
103
104VkResult create_debug_utils_messenger_ext(
105 VkInstance instance,
106 const VkDebugUtilsMessengerCreateInfoEXT *p_create_info,
107 const VkAllocationCallbacks *p_allocator,
108 VkDebugUtilsMessengerEXT *p_debug_messenger) {
109 auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(
110 instance, "vkCreateDebugUtilsMessengerEXT");
111 if (func != nullptr) {
112 return func(instance, p_create_info, p_allocator, p_debug_messenger);
113 } else {
114 return VK_ERROR_EXTENSION_NOT_PRESENT;
115 }
116}
117
118void destroy_debug_utils_messenger_ext(
119 VkInstance instance,
120 VkDebugUtilsMessengerEXT debug_messenger,
121 const VkAllocationCallbacks *p_allocator) {
122 auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(
123 instance, "vkDestroyDebugUtilsMessengerEXT");
124 if (func != nullptr) {
125 func(instance, debug_messenger, p_allocator);
126 }
127}
128
129std::vector<const char *> get_required_extensions(bool enable_validation) {
130 std::vector<const char *> extensions;
131 if (enable_validation) {
132 extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
133 }
134 return extensions;
135}
136
// Chooses queue family indices on `device`:
//  - compute_family: the first pass prefers a family with the compute bit and
//    WITHOUT the graphics bit (a dedicated "async compute" family); the second
//    pass falls back to any compute-capable family;
//  - graphics_family: a family with the graphics bit;
//  - present_family: a family that can present to `surface` (only queried when
//    `surface` is not VK_NULL_HANDLE).
// Any family that is not found is left unset in the returned indices.
VulkanQueueFamilyIndices find_queue_families(VkPhysicalDevice device,
                                             VkSurfaceKHR surface) {
  VulkanQueueFamilyIndices indices;

  uint32_t queue_family_count = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count,
                                           nullptr);
  std::vector<VkQueueFamilyProperties> queue_families(queue_family_count);
  vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count,
                                           queue_families.data());
  // Clears the transfer and sparse-binding bits from each family's flags.
  // Only the compute and graphics bits are tested below, so this mask does not
  // change which families are selected.
  constexpr VkQueueFlags kFlagMask =
      (~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT));

  // first try and find a queue that has just the compute bit set
  // (compute without graphics). Within this pass a later matching family
  // overwrites an earlier one. It returns early as soon as the indices are
  // complete for both compute and UI use.
  for (int i = 0; i < (int)queue_family_count; ++i) {
    const VkQueueFlags masked_flags = kFlagMask & queue_families[i].queueFlags;
    if ((masked_flags & VK_QUEUE_COMPUTE_BIT) &&
        !(masked_flags & VK_QUEUE_GRAPHICS_BIT)) {
      indices.compute_family = i;
    }
    if (masked_flags & VK_QUEUE_GRAPHICS_BIT) {
      indices.graphics_family = i;
    }

    if (surface != VK_NULL_HANDLE) {
      VkBool32 present_support = false;
      vkGetPhysicalDeviceSurfaceSupportKHR(device, i, surface,
                                           &present_support);

      if (present_support) {
        indices.present_family = i;
      }
    }

    if (indices.is_complete() && indices.is_complete_for_ui()) {
      char msg_buf[128];
      RHI_DEBUG_SNPRINTF(msg_buf, sizeof(msg_buf),
                         "Found async compute queue %d, graphics queue %d",
                         indices.compute_family.value(),
                         indices.graphics_family.value());
      RHI_LOG_DEBUG(msg_buf);
      return indices;
    }
  }

  // lastly get any queue that will work
  // Every compute-capable family visited here overwrites compute_family,
  // including a dedicated one chosen by the first pass. The loop returns at
  // the first index where is_complete() holds.
  // NOTE(review): confirm that replacing a dedicated compute family here is
  // intended.
  for (int i = 0; i < (int)queue_family_count; ++i) {
    const VkQueueFlags masked_flags = kFlagMask & queue_families[i].queueFlags;
    if (masked_flags & VK_QUEUE_COMPUTE_BIT) {
      indices.compute_family = i;
    }
    if (indices.is_complete()) {
      return indices;
    }
  }
  return indices;
}
195
196size_t get_device_score(VkPhysicalDevice device, VkSurfaceKHR surface) {
197 auto indices = find_queue_families(device, surface);
198 VkPhysicalDeviceFeatures features{};
199 vkGetPhysicalDeviceFeatures(device, &features);
200 VkPhysicalDeviceProperties properties{};
201 vkGetPhysicalDeviceProperties(device, &properties);
202
203 size_t score = 0;
204
205 if (surface != VK_NULL_HANDLE) {
206 // this means we need ui
207 score = size_t(indices.is_complete_for_ui()) * 1000;
208 } else {
209 score = size_t(indices.is_complete()) * 1000;
210 }
211
212 score += features.wideLines * 100;
213 score +=
214 size_t(properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) *
215 500;
216 score +=
217 size_t(properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) *
218 1000;
219 score += VK_API_VERSION_MINOR(properties.apiVersion) * 100;
220
221 return score;
222}
223
224} // namespace
225
226VulkanDeviceCreator::VulkanDeviceCreator(
227 const VulkanDeviceCreator::Params &params)
228 : params_(params) {
229 if (!VulkanLoader::instance().init()) {
230 throw std::runtime_error("Error loading vulkan");
231 }
232
233 ti_device_ = std::make_unique<VulkanDevice>();
234 uint32_t vk_api_version;
235 bool manual_create;
236 if (params_.api_version.has_value()) {
237 // The version client specified to use
238 //
239 // If the user provided an API version then the device creation process is
240 // totally directed by the information provided externally.
241 vk_api_version = params_.api_version.value();
242 manual_create = true;
243 } else {
244 // The highest version designed to use
245 vk_api_version = VulkanEnvSettings::k_api_version();
246 manual_create = false;
247 }
248
249 create_instance(vk_api_version, manual_create);
250 setup_debug_messenger();
251 if (params_.is_for_ui) {
252 create_surface();
253 }
254 pick_physical_device();
255 create_logical_device(manual_create);
256
257 {
258 VulkanDevice::Params params;
259 params.instance = instance_;
260 params.physical_device = physical_device_;
261 params.device = device_;
262 params.compute_queue = compute_queue_;
263 params.compute_queue_family_index =
264 queue_family_indices_.compute_family.value();
265 params.graphics_queue = graphics_queue_;
266 params.graphics_queue_family_index =
267 queue_family_indices_.graphics_family.value();
268 ti_device_->init_vulkan_structs(params);
269 }
270}
271
// Releases everything created by the constructor, in dependency order. The
// Taichi device is reset first because it was initialized with the handles
// destroyed below. The VkInstance is destroyed last, since the surface and
// the debug messenger are destroyed through it.
VulkanDeviceCreator::~VulkanDeviceCreator() {
  ti_device_.reset();
  if (surface_ != VK_NULL_HANDLE) {
    vkDestroySurfaceKHR(instance_, surface_, kNoVkAllocCallbacks);
  }
  // The messenger only exists when validation stayed enabled
  // (see setup_debug_messenger()).
  if (params_.enable_validation_layer) {
    destroy_debug_utils_messenger_ext(instance_, debug_messenger_,
                                      kNoVkAllocCallbacks);
  }
  vkDestroyDevice(device_, kNoVkAllocCallbacks);
  vkDestroyInstance(instance_, kNoVkAllocCallbacks);
}
284
285void VulkanDeviceCreator::create_instance(uint32_t vk_api_version,
286 bool manual_create) {
287 VkApplicationInfo app_info{};
288 app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
289 app_info.pApplicationName = "Taichi Vulkan Backend";
290 app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
291 app_info.pEngineName = "No Engine";
292 app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
293 app_info.apiVersion = VulkanEnvSettings::k_api_version();
294
295 VkInstanceCreateInfo create_info{};
296 create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
297 create_info.pApplicationInfo = &app_info;
298
299 if (params_.enable_validation_layer) {
300 if (!check_validation_layer_support()) {
301 RHI_LOG_ERROR(
302 "Validation layers requested but not available, turning off... "
303 "Please make sure Vulkan SDK from https://vulkan.lunarg.com/sdk/home "
304 "is installed.");
305 params_.enable_validation_layer = false;
306 }
307 }
308
309 VkDebugUtilsMessengerCreateInfoEXT debug_create_info{};
310
311 if (params_.enable_validation_layer) {
312 create_info.enabledLayerCount = (uint32_t)kValidationLayers.size();
313 create_info.ppEnabledLayerNames = kValidationLayers.data();
314
315 populate_debug_messenger_create_info(&debug_create_info);
316 create_info.pNext = &debug_create_info;
317 } else {
318 create_info.enabledLayerCount = 0;
319 create_info.pNext = nullptr;
320 }
321
322 // Response to `DebugPrintf`.
323 std::array<VkValidationFeatureEnableEXT, 1> vfes = {
324 VK_VALIDATION_FEATURE_ENABLE_DEBUG_PRINTF_EXT};
325 VkValidationFeaturesEXT vf = {};
326 if (params_.enable_validation_layer) {
327 vf.sType = VK_STRUCTURE_TYPE_VALIDATION_FEATURES_EXT;
328 vf.pNext = create_info.pNext;
329 vf.enabledValidationFeatureCount = vfes.size();
330 vf.pEnabledValidationFeatures = vfes.data();
331 create_info.pNext = &vf;
332 }
333
334 std::unordered_set<std::string> extensions;
335 for (auto &ext : get_required_extensions(params_.enable_validation_layer)) {
336 extensions.insert(std::string(ext));
337 }
338 for (auto &ext : params_.additional_instance_extensions) {
339 extensions.insert(std::string(ext));
340 }
341
342 uint32_t num_instance_extensions = 0;
343 // FIXME: (penguinliong) This was NOT called when `manual_create` is true.
344 vkEnumerateInstanceExtensionProperties(nullptr, &num_instance_extensions,
345 nullptr);
346 std::vector<VkExtensionProperties> supported_extensions(
347 num_instance_extensions);
348 vkEnumerateInstanceExtensionProperties(nullptr, &num_instance_extensions,
349 supported_extensions.data());
350
351 for (auto &ext : supported_extensions) {
352 std::string name = ext.extensionName;
353 if (name == VK_KHR_SURFACE_EXTENSION_NAME) {
354 extensions.insert(name);
355 ti_device_->vk_caps().surface = true;
356 } else if (name == VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME) {
357 extensions.insert(name);
358 ti_device_->vk_caps().physical_device_features2 = true;
359 } else if (name == VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME) {
360 extensions.insert(name);
361 } else if (name == VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_EXTENSION_NAME) {
362 extensions.insert(name);
363 } else if (name == VK_EXT_DEBUG_UTILS_EXTENSION_NAME) {
364 extensions.insert(name);
365 }
366 }
367
368 std::vector<const char *> confirmed_extensions;
369 confirmed_extensions.reserve(extensions.size());
370 for (auto &ext : extensions) {
371 confirmed_extensions.push_back(ext.data());
372 }
373
374 create_info.enabledExtensionCount = (uint32_t)confirmed_extensions.size();
375 create_info.ppEnabledExtensionNames = confirmed_extensions.data();
376
377 VkResult res =
378 vkCreateInstance(&create_info, kNoVkAllocCallbacks, &instance_);
379
380 if (res == VK_ERROR_INCOMPATIBLE_DRIVER) {
381 // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkApplicationInfo.html
382 // Vulkan 1.0 implementation will return this when api version is not 1.0
383 // Vulkan 1.1+ implementation will work with maximum version set
384 ti_device_->vk_caps().vk_api_version = VK_API_VERSION_1_0;
385 app_info.apiVersion = VK_API_VERSION_1_0;
386
387 res = vkCreateInstance(&create_info, kNoVkAllocCallbacks, &instance_);
388 } else {
389 ti_device_->vk_caps().vk_api_version = vk_api_version;
390 }
391
392 if (res != VK_SUCCESS) {
393 throw std::runtime_error("failed to create instance");
394 }
395
396 VulkanLoader::instance().load_instance(instance_);
397}
398
399void VulkanDeviceCreator::setup_debug_messenger() {
400 if (!params_.enable_validation_layer) {
401 return;
402 }
403 VkDebugUtilsMessengerCreateInfoEXT create_info{};
404 populate_debug_messenger_create_info(&create_info);
405
406 BAIL_ON_VK_BAD_RESULT_NO_RETURN(
407 create_debug_utils_messenger_ext(instance_, &create_info,
408 kNoVkAllocCallbacks, &debug_messenger_),
409 "failed to set up debug messenger");
410}
411
412void VulkanDeviceCreator::create_surface() {
413 surface_ = params_.surface_creator(instance_);
414 RHI_ASSERT(surface_ && "failed to create window surface!");
415}
416
417void VulkanDeviceCreator::pick_physical_device() {
418 uint32_t device_count = 0;
419 vkEnumeratePhysicalDevices(instance_, &device_count, nullptr);
420 RHI_ASSERT(device_count > 0 && "failed to find GPUs with Vulkan support");
421
422 std::vector<VkPhysicalDevice> devices(device_count);
423 vkEnumeratePhysicalDevices(instance_, &device_count, devices.data());
424 physical_device_ = VK_NULL_HANDLE;
425
426 for (int i = 0; i < device_count; i++) {
427 VkPhysicalDeviceProperties properties{};
428 vkGetPhysicalDeviceProperties(devices[i], &properties);
429
430 char msg_buf[128];
431 RHI_DEBUG_SNPRINTF(msg_buf, sizeof(msg_buf), "Found Vulkan Device %d (%s)",
432 i, properties.deviceName);
433 RHI_LOG_DEBUG(msg_buf);
434 }
435
436 auto device_id = VulkanLoader::instance().visible_device_id;
437 bool has_visible_device{false};
438 if (!device_id.empty()) {
439 int id = std::stoi(device_id);
440 if (id < 0 || id >= device_count) {
441 char msg_buf[128];
442 snprintf(msg_buf, sizeof(msg_buf),
443 "TI_VISIBLE_DEVICE=%d is not valid, found %d devices available",
444 id, device_count);
445 RHI_LOG_ERROR(msg_buf);
446 } else if (get_device_score(devices[id], surface_)) {
447 physical_device_ = devices[id];
448 has_visible_device = true;
449 }
450 }
451
452 if (!has_visible_device) {
453 // could not find a user defined visible device, use the first one suitable
454 size_t max_score = 0;
455 for (const auto &device : devices) {
456 size_t score = get_device_score(device, surface_);
457 if (score > max_score) {
458 physical_device_ = device;
459 max_score = score;
460 }
461 }
462 }
463 RHI_ASSERT(physical_device_ != VK_NULL_HANDLE &&
464 "failed to find a suitable GPU");
465
466 queue_family_indices_ = find_queue_families(physical_device_, surface_);
467}
468
469void VulkanDeviceCreator::create_logical_device(bool manual_create) {
470 DeviceCapabilityConfig caps{};
471
472 std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
473 std::unordered_set<uint32_t> unique_families;
474
475 if (queue_family_indices_.compute_family.has_value()) {
476 unique_families.insert(queue_family_indices_.compute_family.value());
477 }
478 if (queue_family_indices_.graphics_family.has_value()) {
479 unique_families.insert(queue_family_indices_.graphics_family.value());
480 }
481
482 float queue_priority = 1.0f;
483 for (uint32_t queue_family : unique_families) {
484 VkDeviceQueueCreateInfo queueCreateInfo{};
485 queueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
486 queueCreateInfo.queueFamilyIndex = queue_family;
487 queueCreateInfo.queueCount = 1;
488 queueCreateInfo.pQueuePriorities = &queue_priority;
489 queue_create_infos.push_back(queueCreateInfo);
490 }
491
492 VkDeviceCreateInfo create_info{};
493 create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
494 create_info.pQueueCreateInfos = queue_create_infos.data();
495 create_info.queueCreateInfoCount = queue_create_infos.size();
496
497 // Get device properties
498 VkPhysicalDeviceProperties physical_device_properties{};
499 vkGetPhysicalDeviceProperties(physical_device_, &physical_device_properties);
500
501 {
502 char msg_buf[256];
503 RHI_DEBUG_SNPRINTF(
504 msg_buf, sizeof(msg_buf),
505 "Vulkan Device \"%s\" supports Vulkan %d version %d.%d.%d",
506 physical_device_properties.deviceName,
507 VK_API_VERSION_VARIANT(physical_device_properties.apiVersion),
508 VK_API_VERSION_MAJOR(physical_device_properties.apiVersion),
509 VK_API_VERSION_MINOR(physical_device_properties.apiVersion),
510 VK_API_VERSION_PATCH(physical_device_properties.apiVersion));
511 RHI_LOG_DEBUG(msg_buf);
512 }
513
514 // (penguinliong) The actual logical device is created with lastest version of
515 // Vulkan but we use the device like it has a lower version (if the user
516 // wanted a lower version device).
517 uint32_t vk_api_version = physical_device_properties.apiVersion;
518 ti_device_->vk_caps().vk_api_version = vk_api_version;
519 if (vk_api_version >= VK_API_VERSION_1_3) {
520 caps.set(DeviceCapability::spirv_version, 0x10500);
521 } else if (vk_api_version >= VK_API_VERSION_1_2) {
522 caps.set(DeviceCapability::spirv_version, 0x10500);
523 } else if (vk_api_version >= VK_API_VERSION_1_1) {
524 caps.set(DeviceCapability::spirv_version, 0x10300);
525 } else {
526 caps.set(DeviceCapability::spirv_version, 0x10000);
527 }
528
529 // Detect extensions
530 std::vector<const char *> enabled_extensions;
531
532 uint32_t extension_count = 0;
533 // FIXME: (penguinliong) This was NOT called when `manual_create` is true.
534 vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
535 &extension_count, nullptr);
536 std::vector<VkExtensionProperties> extension_properties(extension_count);
537 vkEnumerateDeviceExtensionProperties(
538 physical_device_, nullptr, &extension_count, extension_properties.data());
539
540 bool has_swapchain = false;
541
542 [[maybe_unused]] bool portability_subset_enabled = false;
543
544 for (auto &ext : extension_properties) {
545 char msg_buf[256];
546 RHI_DEBUG_SNPRINTF(msg_buf, sizeof(msg_buf),
547 "Vulkan device extension {%s} (%x)", ext.extensionName,
548 ext.specVersion);
549 RHI_LOG_DEBUG(msg_buf);
550
551 std::string name = std::string(ext.extensionName);
552
553 if (name == "VK_KHR_portability_subset") {
554 RHI_LOG_ERROR(
555 "Potential non-conformant Vulkan implementation, enabling "
556 "VK_KHR_portability_subset");
557 portability_subset_enabled = true;
558 enabled_extensions.push_back(ext.extensionName);
559 } else if (name == VK_KHR_SWAPCHAIN_EXTENSION_NAME) {
560 has_swapchain = true;
561 enabled_extensions.push_back(ext.extensionName);
562 } else if (name == VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME) {
563 enabled_extensions.push_back(ext.extensionName);
564 } else if (name == VK_EXT_SHADER_ATOMIC_FLOAT_2_EXTENSION_NAME) {
565 enabled_extensions.push_back(ext.extensionName);
566 } else if (name == VK_KHR_SHADER_ATOMIC_INT64_EXTENSION_NAME) {
567 enabled_extensions.push_back(ext.extensionName);
568 } else if (name == VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME) {
569 enabled_extensions.push_back(ext.extensionName);
570 } else if (name == VK_KHR_SPIRV_1_4_EXTENSION_NAME) {
571 if (caps.get(DeviceCapability::spirv_version) < 0x10400) {
572 caps.set(DeviceCapability::spirv_version, 0x10400);
573 enabled_extensions.push_back(ext.extensionName);
574 }
575 } else if (name == VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME ||
576 name == VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME) {
577 ti_device_->vk_caps().external_memory = true;
578 enabled_extensions.push_back(ext.extensionName);
579 } else if (name == VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME) {
580 enabled_extensions.push_back(ext.extensionName);
581 } else if (name == VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) {
582 enabled_extensions.push_back(ext.extensionName);
583 } else if (name == VK_KHR_GET_MEMORY_REQUIREMENTS_2_EXTENSION_NAME) {
584 enabled_extensions.push_back(ext.extensionName);
585 } else if (name == VK_KHR_DEDICATED_ALLOCATION_EXTENSION_NAME) {
586 enabled_extensions.push_back(ext.extensionName);
587 } else if (name == VK_KHR_BIND_MEMORY_2_EXTENSION_NAME) {
588 enabled_extensions.push_back(ext.extensionName);
589 } else if (name == VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) {
590 enabled_extensions.push_back(ext.extensionName);
591 } else if (name == VK_KHR_DYNAMIC_RENDERING_EXTENSION_NAME) {
592 enabled_extensions.push_back(ext.extensionName);
593 } else if (name == VK_KHR_SHADER_NON_SEMANTIC_INFO_EXTENSION_NAME &&
594 params_.enable_validation_layer) {
595 // VK_KHR_shader_non_semantic_info isn't supported on molten-vk.
596 // Tracking issue: https://github.com/KhronosGroup/MoltenVK/issues/1214
597 caps.set(DeviceCapability::spirv_has_non_semantic_info, true);
598 enabled_extensions.push_back(ext.extensionName);
599 } else if (std::find(params_.additional_device_extensions.begin(),
600 params_.additional_device_extensions.end(),
601 name) != params_.additional_device_extensions.end()) {
602 enabled_extensions.push_back(ext.extensionName);
603 }
604 // Vulkan doesn't seem to support SPV_KHR_no_integer_wrap_decoration at all.
605 }
606
607 if (has_swapchain) {
608 ti_device_->vk_caps().present = true;
609 }
610
611 VkPhysicalDeviceFeatures device_features{};
612
613 VkPhysicalDeviceFeatures device_supported_features;
614 vkGetPhysicalDeviceFeatures(physical_device_, &device_supported_features);
615
616 if (device_supported_features.shaderInt16) {
617 device_features.shaderInt16 = true;
618 caps.set(DeviceCapability::spirv_has_int16, true);
619 }
620 if (device_supported_features.shaderInt64) {
621 device_features.shaderInt64 = true;
622 caps.set(DeviceCapability::spirv_has_int64, true);
623 }
624 if (device_supported_features.shaderFloat64) {
625 device_features.shaderFloat64 = true;
626 caps.set(DeviceCapability::spirv_has_float64, true);
627 }
628 if (device_supported_features.wideLines) {
629 device_features.wideLines = true;
630 ti_device_->vk_caps().wide_line = true;
631 }
632
633 if (ti_device_->vk_caps().vk_api_version >= VK_API_VERSION_1_1) {
634 VkPhysicalDeviceSubgroupProperties subgroup_properties{};
635 subgroup_properties.sType =
636 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
637 subgroup_properties.pNext = nullptr;
638
639 VkPhysicalDeviceProperties2 physical_device_properties{};
640 physical_device_properties.sType =
641 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
642 physical_device_properties.pNext = &subgroup_properties;
643
644 vkGetPhysicalDeviceProperties2(physical_device_,
645 &physical_device_properties);
646
647 if (subgroup_properties.supportedOperations &
648 VK_SUBGROUP_FEATURE_BASIC_BIT) {
649 caps.set(DeviceCapability::spirv_has_subgroup_basic, true);
650 }
651 if (subgroup_properties.supportedOperations &
652 VK_SUBGROUP_FEATURE_VOTE_BIT) {
653 caps.set(DeviceCapability::spirv_has_subgroup_vote, true);
654 }
655 if (subgroup_properties.supportedOperations &
656 VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) {
657 caps.set(DeviceCapability::spirv_has_subgroup_arithmetic, true);
658 }
659 if (subgroup_properties.supportedOperations &
660 VK_SUBGROUP_FEATURE_BALLOT_BIT) {
661 caps.set(DeviceCapability::spirv_has_subgroup_ballot, true);
662 }
663 }
664
665 create_info.pEnabledFeatures = &device_features;
666 create_info.enabledExtensionCount = enabled_extensions.size();
667 create_info.ppEnabledExtensionNames = enabled_extensions.data();
668
669 void **pNextEnd = (void **)&create_info.pNext;
670
671 // Use physicalDeviceFeatures2 to features enabled by extensions
672 VkPhysicalDeviceVariablePointersFeaturesKHR variable_ptr_feature{};
673 variable_ptr_feature.sType =
674 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTERS_FEATURES_KHR;
675 VkPhysicalDeviceShaderAtomicFloatFeaturesEXT shader_atomic_float_feature{};
676 shader_atomic_float_feature.sType =
677 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
678 VkPhysicalDeviceShaderAtomicFloat2FeaturesEXT shader_atomic_float_2_feature{};
679 shader_atomic_float_2_feature.sType =
680 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_2_FEATURES_EXT;
681 VkPhysicalDeviceFloat16Int8FeaturesKHR shader_f16_i8_feature{};
682 shader_f16_i8_feature.sType =
683 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR;
684 VkPhysicalDeviceBufferDeviceAddressFeaturesKHR
685 buffer_device_address_feature{};
686 buffer_device_address_feature.sType =
687 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_KHR;
688 VkPhysicalDeviceDynamicRenderingFeaturesKHR dynamic_rendering_feature{};
689 dynamic_rendering_feature.sType =
690 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DYNAMIC_RENDERING_FEATURES_KHR;
691
692 if (ti_device_->vk_caps().physical_device_features2) {
693 VkPhysicalDeviceFeatures2KHR features2{};
694 features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
695
696#define CHECK_EXTENSION(ext) \
697 std::find_if(enabled_extensions.begin(), enabled_extensions.end(), \
698 [=](const char *o) { return strcmp(ext, o) == 0; }) != \
699 enabled_extensions.end()
700
701 uint32_t vk_api_version = ti_device_->vk_caps().vk_api_version;
702#define CHECK_VERSION(major, minor) \
703 vk_api_version >= VK_MAKE_API_VERSION(0, major, minor, 0)
704
705 // Variable ptr
706 if (CHECK_VERSION(1, 1) ||
707 CHECK_EXTENSION(VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME)) {
708 features2.pNext = &variable_ptr_feature;
709 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
710
711 if (variable_ptr_feature.variablePointers &&
712 variable_ptr_feature.variablePointersStorageBuffer) {
713 caps.set(DeviceCapability::spirv_has_variable_ptr, true);
714 }
715 *pNextEnd = &variable_ptr_feature;
716 pNextEnd = &variable_ptr_feature.pNext;
717 }
718
719 // Atomic float
720 if (CHECK_EXTENSION(VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME)) {
721 features2.pNext = &shader_atomic_float_feature;
722 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
723 if (shader_atomic_float_feature.shaderBufferFloat32AtomicAdd) {
724 caps.set(DeviceCapability::spirv_has_atomic_float_add, true);
725 }
726 if (shader_atomic_float_feature.shaderBufferFloat64AtomicAdd) {
727 caps.set(DeviceCapability::spirv_has_atomic_float64_add, true);
728 }
729 if (shader_atomic_float_feature.shaderBufferFloat32Atomics) {
730 caps.set(DeviceCapability::spirv_has_atomic_float, true);
731 }
732 if (shader_atomic_float_feature.shaderBufferFloat64Atomics) {
733 caps.set(DeviceCapability::spirv_has_atomic_float64, true);
734 }
735 *pNextEnd = &shader_atomic_float_feature;
736 pNextEnd = &shader_atomic_float_feature.pNext;
737 }
738
739 // Atomic float 2
740 if (CHECK_EXTENSION(VK_EXT_SHADER_ATOMIC_FLOAT_2_EXTENSION_NAME)) {
741 features2.pNext = &shader_atomic_float_2_feature;
742 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
743 if (shader_atomic_float_2_feature.shaderBufferFloat16AtomicAdd) {
744 caps.set(DeviceCapability::spirv_has_atomic_float_add, true);
745 }
746 if (shader_atomic_float_2_feature.shaderBufferFloat16AtomicMinMax) {
747 caps.set(DeviceCapability::spirv_has_atomic_float16_minmax, true);
748 }
749 if (shader_atomic_float_2_feature.shaderBufferFloat16Atomics) {
750 caps.set(DeviceCapability::spirv_has_atomic_float16, true);
751 }
752 if (shader_atomic_float_2_feature.shaderBufferFloat32AtomicMinMax) {
753 caps.set(DeviceCapability::spirv_has_atomic_float_minmax, true);
754 }
755 if (shader_atomic_float_2_feature.shaderBufferFloat64AtomicMinMax) {
756 caps.set(DeviceCapability::spirv_has_atomic_float64_minmax, true);
757 }
758 *pNextEnd = &shader_atomic_float_2_feature;
759 pNextEnd = &shader_atomic_float_2_feature.pNext;
760 }
761
762 // F16 / I8
763 if (CHECK_VERSION(1, 2) ||
764 CHECK_EXTENSION(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME)) {
765 features2.pNext = &shader_f16_i8_feature;
766 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
767
768 if (shader_f16_i8_feature.shaderFloat16) {
769 caps.set(DeviceCapability::spirv_has_float16, true);
770 }
771 if (shader_f16_i8_feature.shaderInt8) {
772 caps.set(DeviceCapability::spirv_has_int8, true);
773 }
774 *pNextEnd = &shader_f16_i8_feature;
775 pNextEnd = &shader_f16_i8_feature.pNext;
776 }
777
778 // Buffer Device Address
779 if (CHECK_VERSION(1, 2) ||
780 CHECK_EXTENSION(VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME)) {
781 features2.pNext = &buffer_device_address_feature;
782 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
783
784 if (CHECK_VERSION(1, 3) ||
785 buffer_device_address_feature.bufferDeviceAddress) {
786 if (device_supported_features.shaderInt64) {
787// Temporarily disable it on macOS:
788// https://github.com/taichi-dev/taichi/issues/6295
789// (penguinliong) Temporarily disabled (until device capability is ready).
790#if !defined(__APPLE__) && false
791 caps.set(DeviceCapability::spirv_has_physical_storage_buffer, true);
792#endif
793 }
794 }
795 *pNextEnd = &buffer_device_address_feature;
796 pNextEnd = &buffer_device_address_feature.pNext;
797 }
798
799 // Dynamic rendering
800 // TODO: Figure out how to integrate this correctly with ImGui,
801 // and then figure out the layout & barrier stuff
802 /*
803 if (CHECK_EXTENSION(VK_KHR_DYNAMIC_RENDERING_EXTENSION_NAME)) {
804 features2.pNext = &dynamic_rendering_feature;
805 vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);
806
807 if (dynamic_rendering_feature.dynamicRendering) {
808 ti_device_->vk_caps().dynamic_rendering = true;
809 }
810
811 *pNextEnd = &dynamic_rendering_feature;
812 pNextEnd = &dynamic_rendering_feature.pNext;
813 }
814 */
815
816 // TODO: add atomic min/max feature
817 }
818
819 if (params_.enable_validation_layer) {
820 create_info.enabledLayerCount = (uint32_t)kValidationLayers.size();
821 create_info.ppEnabledLayerNames = kValidationLayers.data();
822 } else {
823 create_info.enabledLayerCount = 0;
824 }
825 BAIL_ON_VK_BAD_RESULT_NO_RETURN(vkCreateDevice(physical_device_, &create_info,
826 kNoVkAllocCallbacks, &device_),
827 "failed to create logical device");
828 VulkanLoader::instance().load_device(device_);
829
830 if (queue_family_indices_.compute_family.has_value()) {
831 vkGetDeviceQueue(device_, queue_family_indices_.compute_family.value(), 0,
832 &compute_queue_);
833 }
834 if (queue_family_indices_.graphics_family.has_value()) {
835 vkGetDeviceQueue(device_, queue_family_indices_.graphics_family.value(), 0,
836 &graphics_queue_);
837 }
838
839 // Dump capabilities
840 caps.dbg_print_all();
841 ti_device_->set_caps(std::move(caps));
842}
843
844} // namespace vulkan
845} // namespace taichi::lang
846