1#include <torch/csrc/profiler/orchestration/vulkan.h>
2
3namespace torch {
4namespace profiler {
5namespace impl {
6namespace vulkan {
7namespace {
8
9GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns_fn;
10
11} // namespace
12
13void registerGetShaderNameAndDurationNs(
14 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns) {
15 get_shader_name_and_duration_ns_fn = get_shader_name_and_duration_ns;
16}
17
18void deregisterGetShaderNameAndDurationNs() {
19 get_shader_name_and_duration_ns_fn = nullptr;
20}
21
22std::tuple<std::string, uint64_t> getShaderNameAndDurationNs(
23 const vulkan_id_t& vulkan_id) {
24 /*
25 We don't need to worry about a race condition with
26 deregisterGetShaderNameAndDurationNs here currently because
27 deregisterGetShaderNameAndDurationNs is only called within the destructor
28 of QueryPool, which would only be called after we're done calling
29 getShaderNameAndDurationNs
30 */
31 TORCH_CHECK(
32 get_shader_name_and_duration_ns_fn != nullptr,
33 "Attempting to get shader duration in ",
34 "torch::profiler::impl::vulkan::getShaderNameAndDurationNs, but "
35 "get_shader_duration_fn is unregistered. Use "
36 "torch::profiler::impl::vulkan::registerGetShaderNameAndDurationNs to register "
37 "it first");
38 return get_shader_name_and_duration_ns_fn(vulkan_id.value_of());
39}
40
41} // namespace vulkan
42} // namespace impl
43} // namespace profiler
44} // namespace torch
45