1#pragma once
2
3#include <memory>
4#include <string>
5
// Skip the Kineto dependency on mobile unless it is explicitly requested.
// It is requested by defining EDGE_PROFILER_USE_KINETO: KinetoEdgeCPUProfiler
// uses KinetoProfiler for CPU event profiling, which depends on the CPU-only
// build of libkineto.
10#if defined(USE_KINETO) && defined(C10_MOBILE) && \
11 !defined(EDGE_PROFILER_USE_KINETO)
12#undef USE_KINETO
13#endif
14
15#include <ActivityType.h>
16
17#include <torch/csrc/Export.h>
18#include <torch/csrc/profiler/api.h>
19
#ifdef USE_KINETO
// The Kineto types used below are only forward-declared, which keeps
// `libkineto.h` out of this header.
namespace libkineto {
class ActivityTraceInterface;
struct CpuTraceBuffer;
class GenericTraceActivity;
} // namespace libkineto
#endif // USE_KINETO
28
29namespace torch {
30namespace profiler {
31
// Compile-time flag: true exactly when `USE_KINETO` is still defined at this
// point in the header.
#ifndef USE_KINETO
constexpr bool kKinetoAvailable = false;
#else
constexpr bool kKinetoAvailable = true;
#endif // USE_KINETO
37
38namespace impl {
39namespace kineto {
40
41// ----------------------------------------------------------------------------
42// -- Interface (Does not require Kineto) -------------------------------------
43// ----------------------------------------------------------------------------
// Pair of 32-bit ids that is passed as one unit to
// `TraceWrapper::addCPUActivity`. Neither field has a default member
// initializer.
struct DeviceAndResource {
  int32_t device;
  int32_t resource;
};

// Returns a `DeviceAndResource` by value; the definition is not part of this
// header.
// NOTE(review): presumably the ids describe the calling process / thread --
// confirm against the implementation.
const DeviceAndResource kineto_ids();
49
// Aliases that let the rest of this header read the same with and without
// Kineto:
//   trace_t           - buffer owned by TraceWrapper
//   interface_trace_t - trace owned by ActivityTraceWrapper
//   activity_t        - a single activity, only handled through pointers
#ifdef USE_KINETO
using trace_t = libkineto::CpuTraceBuffer;
using interface_trace_t = libkineto::ActivityTraceInterface;
using activity_t = libkineto::GenericTraceActivity;
#else
// Empty placeholders, one per Kineto type that a wrapper owns.
struct DummyTraceBuffer {};
struct DummyTraceInterface {};

using trace_t = DummyTraceBuffer;
// Must name the interface placeholder rather than DummyTraceBuffer so that
// trace_t and interface_trace_t stay distinct types, as they are in the
// Kineto build.
using interface_trace_t = DummyTraceInterface;
// Declared but never defined in this header: only `activity_t*` can be formed.
struct activity_t;
#endif // USE_KINETO
62
63void addMetadata(
64 const activity_t* activity,
65 const std::string& key,
66 const std::string& value);
67
68// Wraps: libkineto::CpuTraceBuffer
69struct TraceWrapper {
70 TraceWrapper(const int64_t start_time, const std::string& name);
71 TraceWrapper(TraceWrapper&&) = default;
72 TraceWrapper(const TraceWrapper&) = delete;
73 ~TraceWrapper();
74
75 // The caller is expected to hold a mutex when calling `addCPUActivity`.
76 activity_t* addCPUActivity(
77 const std::string& name,
78 const libkineto::ActivityType type,
79 const DeviceAndResource device_and_resource,
80 const uint64_t correlation_id,
81 const int64_t start_time,
82 const int64_t end_time);
83
84 void transferCpuTrace(int64_t end_time);
85
86 explicit operator bool() const;
87
88 std::unique_ptr<trace_t>& get() {
89 return cpu_trace_;
90 }
91
92 private:
93 std::unique_ptr<trace_t> cpu_trace_;
94};
95
96// Wraps libkineto::ActivityTraceInterface
97struct ActivityTraceWrapper {
98 explicit ActivityTraceWrapper(std::unique_ptr<interface_trace_t>&& trace);
99 ActivityTraceWrapper() = default;
100 ActivityTraceWrapper(ActivityTraceWrapper&&) = default;
101 ActivityTraceWrapper(const ActivityTraceWrapper&) = delete;
102 explicit operator bool() const;
103 void save(const std::string& path);
104
105 const std::unique_ptr<interface_trace_t>& get() {
106 return trace_;
107 }
108
109 private:
110 std::unique_ptr<interface_trace_t> trace_;
111#ifdef USE_KINETO
112 bool saved_ = false; // Kineto's save is destructive
113#endif
114};
115
116using ActivitySet = std::set<torch::autograd::profiler::ActivityType>;
117void prepareTrace(
118 const bool cpuOnly,
119 const ActivitySet& activities,
120 const torch::profiler::impl::ExperimentalConfig& config);
121void startTrace();
122ActivityTraceWrapper stopTrace();
123void pushCorrelationId(uint64_t correlation_id);
124void pushUserCorrelationId(uint64_t correlation_id);
125void popCorrelationId();
126void popUserCorrelationId();
127void recordThreadInfo();
128
// Declaration only; the definition is not part of this header. All four
// arguments are strings: the violated assertion, an error description, and
// two profile identifiers.
// NOTE(review): where the report ends up is not visible from here -- confirm
// against the implementation.
void logInvariantViolation(
    const std::string& assertion,
    const std::string& error,
    const std::string& profile_id,
    const std::string& group_profile_id);
134
135} // namespace kineto
136} // namespace impl
137} // namespace profiler
138
139namespace autograd {
140namespace profiler {
141c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type);
142
143TORCH_API void addMetadataJson(
144 const std::string& key,
145 const std::string& value);
146
147TORCH_API void profilerStep();
148
149} // namespace profiler
150} // namespace autograd
151} // namespace torch
152