1#pragma once
2
3#include <functional>
4#include <memory>
5
6#include <c10/util/strong_type.h>
7#include <torch/csrc/Export.h>
8
9struct CUevent_st;
10
11namespace torch {
12namespace profiler {
13namespace impl {
14
15// ----------------------------------------------------------------------------
16// -- Annotation --------------------------------------------------------------
17// ----------------------------------------------------------------------------
18using ProfilerEventStub = std::shared_ptr<CUevent_st>;
19
// Abstract table of profiler hooks that a backend implements. Instances are
// handed to the register*Methods() functions declared below and read back
// through the corresponding *Stubs() getters as `const ProfilerStubs*`, which
// is why every hook is a const member function.
//
// NOTE: this class is exported (TORCH_API) and polymorphic; the declaration
// order of the virtual members defines its vtable layout, so do not reorder
// them without rebuilding every consumer.
struct TORCH_API ProfilerStubs {
  // Fills its three out-parameters. Presumably: the current device index, a
  // newly recorded event, and a CPU timestamp in nanoseconds (per the
  // `cpu_ns` name) -- TODO confirm against the backend implementations,
  // including which pointers may be null.
  virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
      const = 0;
  // Time elapsed between two events previously produced by record(). The
  // unit is not visible here (NOTE(review): likely milliseconds, as for
  // cudaEventElapsedTime -- confirm).
  virtual float elapsed(
      const ProfilerEventStub* event,
      const ProfilerEventStub* event2) const = 0;
  // Emits a single named marker (presumably an instantaneous annotation).
  virtual void mark(const char* name) const = 0;
  // Opens a named range; presumably closed by a matching rangePop().
  virtual void rangePush(const char* name) const = 0;
  // Closes the most recently pushed range (assumed stack discipline).
  virtual void rangePop() const = 0;
  // The base implementation reports "disabled"; a backend that is actually
  // available is expected to override this.
  virtual bool enabled() const {
    return false;
  }
  // Invokes `op` once per device; the int argument is presumably the device
  // index -- confirm in the backends.
  virtual void onEachDevice(std::function<void(int)> op) const = 0;
  // Backend-wide synchronization point (exact scope is backend-defined).
  virtual void synchronize() const = 0;
  // Virtual so that deleting through a ProfilerStubs* is well-defined; the
  // definition lives out of line (no body in this header).
  virtual ~ProfilerStubs();
};
36
37TORCH_API void registerCUDAMethods(ProfilerStubs* stubs);
38TORCH_API const ProfilerStubs* cudaStubs();
39TORCH_API void registerITTMethods(ProfilerStubs* stubs);
40TORCH_API const ProfilerStubs* ittStubs();
41
42using vulkan_id_t = strong::type<
43 int64_t,
44 struct _VulkanID,
45 strong::regular,
46 strong::convertible_to<int64_t>,
47 strong::hashable>;
48
49} // namespace impl
50} // namespace profiler
51} // namespace torch
52