#include <torch/csrc/profiler/orchestration/vulkan.h>

#include <utility>
---|---|
2 | |
3 | namespace torch { |
4 | namespace profiler { |
5 | namespace impl { |
6 | namespace vulkan { |
7 | namespace { |
8 | |
9 | GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns_fn; |
10 | |
11 | } // namespace |
12 | |
13 | void registerGetShaderNameAndDurationNs( |
14 | GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns) { |
15 | get_shader_name_and_duration_ns_fn = get_shader_name_and_duration_ns; |
16 | } |
17 | |
18 | void deregisterGetShaderNameAndDurationNs() { |
19 | get_shader_name_and_duration_ns_fn = nullptr; |
20 | } |
21 | |
22 | std::tuple<std::string, uint64_t> getShaderNameAndDurationNs( |
23 | const vulkan_id_t& vulkan_id) { |
24 | /* |
25 | We don't need to worry about a race condition with |
26 | deregisterGetShaderNameAndDurationNs here currently because |
27 | deregisterGetShaderNameAndDurationNs is only called within the destructor |
28 | of QueryPool, which would only be called after we're done calling |
29 | getShaderNameAndDurationNs |
30 | */ |
31 | TORCH_CHECK( |
32 | get_shader_name_and_duration_ns_fn != nullptr, |
33 | "Attempting to get shader duration in ", |
34 | "torch::profiler::impl::vulkan::getShaderNameAndDurationNs, but " |
35 | "get_shader_duration_fn is unregistered. Use " |
36 | "torch::profiler::impl::vulkan::registerGetShaderNameAndDurationNs to register " |
37 | "it first"); |
38 | return get_shader_name_and_duration_ns_fn(vulkan_id.value_of()); |
39 | } |
40 | |
41 | } // namespace vulkan |
42 | } // namespace impl |
43 | } // namespace profiler |
44 | } // namespace torch |
45 |