1 | #pragma once |
---|---|
2 | |
3 | #include <torch/csrc/profiler/stubs/base.h> |
4 | #include <torch/csrc/profiler/util.h> |
5 | #include <cstdint> |
6 | |
7 | namespace torch { |
8 | namespace profiler { |
9 | namespace impl { |
10 | namespace vulkan { |
11 | |
// Callable type used by the registration API below: it is invoked with an
// int64_t and returns a (std::string, uint64_t) tuple. Judging by the alias
// name, the tuple holds a shader name and its duration in nanoseconds, and the
// int64_t is presumably the numeric Vulkan id -- TODO confirm against the
// implementation.
// NOTE(review): std::function, std::tuple and std::string are used here, but
// <functional>, <tuple> and <string> are not among the includes visible at the
// top of this header; presumably they arrive transitively -- confirm.
12 | // Using function pointer i.e. [std::tuple<std::string, uint64_t> (*)(int64_t)] |
13 | // doesn't work because we need to capture the QueryPool in the lambda context |
14 | // https://stackoverflow.com/a/28746827 |
15 | using GetShaderNameAndDurationNsFn = |
16 | std::function<std::tuple<std::string, uint64_t>(int64_t)>; |
// Registration entry point, marked TORCH_API. Takes a
// GetShaderNameAndDurationNsFn by value and returns nothing.
// NOTE(review): the definition is not in this header; presumably the callable
// is stored and later invoked by getShaderNameAndDurationNs -- confirm.
17 | TORCH_API void registerGetShaderNameAndDurationNs( |
18 | GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns); |
19 | |
// Counterpart to registerGetShaderNameAndDurationNs, marked TORCH_API. Takes
// no arguments and returns nothing.
// NOTE(review): presumably clears the previously registered callable; the
// definition is not visible here -- confirm.
20 | TORCH_API void deregisterGetShaderNameAndDurationNs(); |
21 | |
// Takes a vulkan_id_t by const reference and returns a
// (std::string, uint64_t) tuple -- by its name, a shader name and a duration
// in nanoseconds.
// Unlike the register/deregister declarations above, this one is not marked
// TORCH_API.
// NOTE(review): vulkan_id_t is not declared in this header; presumably it
// comes from one of the included profiler headers. Presumably this forwards
// to the registered GetShaderNameAndDurationNsFn; behavior when nothing is
// registered cannot be determined from this header -- confirm.
22 | std::tuple<std::string, uint64_t> getShaderNameAndDurationNs( |
23 | const vulkan_id_t& vulkan_id); |
24 | |
25 | } // namespace vulkan |
26 | } // namespace impl |
27 | } // namespace profiler |
28 | } // namespace torch |
29 |