1#pragma once
2
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <cstdint>
#include <functional>
#include <string>
#include <tuple>
6
7namespace torch {
8namespace profiler {
9namespace impl {
10namespace vulkan {
11
12// Using function pointer i.e. [std::tuple<std::string, uint64_t> (*)(int64_t)]
13// doesn't work because we need to capture the QueryPool in the lambda context
14// https://stackoverflow.com/a/28746827
15using GetShaderNameAndDurationNsFn =
16 std::function<std::tuple<std::string, uint64_t>(int64_t)>;
17TORCH_API void registerGetShaderNameAndDurationNs(
18 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns);
19
20TORCH_API void deregisterGetShaderNameAndDurationNs();
21
22std::tuple<std::string, uint64_t> getShaderNameAndDurationNs(
23 const vulkan_id_t& vulkan_id);
24
25} // namespace vulkan
26} // namespace impl
27} // namespace profiler
28} // namespace torch
29