1 | #pragma once |
2 | |
#include <cstdint>
#include <functional>
#include <memory>

#include <c10/util/strong_type.h>
#include <torch/csrc/Export.h>
8 | |
9 | struct CUevent_st; |
10 | |
11 | namespace torch { |
12 | namespace profiler { |
13 | namespace impl { |
14 | |
15 | // ---------------------------------------------------------------------------- |
16 | // -- Annotation -------------------------------------------------------------- |
17 | // ---------------------------------------------------------------------------- |
18 | using ProfilerEventStub = std::shared_ptr<CUevent_st>; |
19 | |
20 | struct TORCH_API ProfilerStubs { |
21 | virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) |
22 | const = 0; |
23 | virtual float elapsed( |
24 | const ProfilerEventStub* event, |
25 | const ProfilerEventStub* event2) const = 0; |
26 | virtual void mark(const char* name) const = 0; |
27 | virtual void rangePush(const char* name) const = 0; |
28 | virtual void rangePop() const = 0; |
29 | virtual bool enabled() const { |
30 | return false; |
31 | } |
32 | virtual void onEachDevice(std::function<void(int)> op) const = 0; |
33 | virtual void synchronize() const = 0; |
34 | virtual ~ProfilerStubs(); |
35 | }; |
36 | |
37 | TORCH_API void registerCUDAMethods(ProfilerStubs* stubs); |
38 | TORCH_API const ProfilerStubs* cudaStubs(); |
39 | TORCH_API void registerITTMethods(ProfilerStubs* stubs); |
40 | TORCH_API const ProfilerStubs* ittStubs(); |
41 | |
42 | using vulkan_id_t = strong::type< |
43 | int64_t, |
44 | struct _VulkanID, |
45 | strong::regular, |
46 | strong::convertible_to<int64_t>, |
47 | strong::hashable>; |
48 | |
49 | } // namespace impl |
50 | } // namespace profiler |
51 | } // namespace torch |
52 | |