1 | #ifdef USE_KINETO |
---|---|
2 | #include <libkineto.h> |
3 | #include <torch/csrc/autograd/profiler_kineto.h> |
4 | #include <cstdlib> |
5 | |
// On-demand tracing is not supported on Apple or edge platforms, so the
// global observer (automatic libkineto client registration) is compiled out
// there and enabled everywhere else.
#if !defined(__APPLE__) && !defined(EDGE_PROFILER_USE_KINETO)
#define ENABLE_GLOBAL_OBSERVER (1)
#else
#define ENABLE_GLOBAL_OBSERVER (0)
#endif
12 | |
13 | namespace torch { |
14 | namespace profiler { |
15 | namespace impl { |
16 | |
17 | namespace { |
18 | |
19 | using namespace torch::autograd::profiler; |
20 | |
21 | class LibKinetoClient : public libkineto::ClientInterface { |
22 | public: |
23 | void init() override {} |
24 | |
25 | void warmup(bool setupOpInputsCollection) override { |
26 | reportInputShapes_ = setupOpInputsCollection; |
27 | } |
28 | |
29 | void start() override { |
30 | ProfilerConfig cfg{ |
31 | ProfilerState::KINETO_ONDEMAND, |
32 | /*report_input_shapes=*/reportInputShapes_, |
33 | /*profile_memory=*/false, |
34 | /*with_stack=*/withStack_, |
35 | /*with_flops=*/false, |
36 | /*with_modules=*/false}; |
37 | std::set<ActivityType> activities{ActivityType::CPU}; |
38 | std::unordered_set<at::RecordScope> scopes; |
39 | scopes.insert(at::RecordScope::FUNCTION); |
40 | scopes.insert(at::RecordScope::USER_SCOPE); |
41 | scopes.insert(at::RecordScope::BACKWARD_FUNCTION); |
42 | enableProfiler(cfg, activities, scopes); |
43 | } |
44 | |
45 | void stop() override { |
46 | (void)disableProfiler(); |
47 | } |
48 | |
49 | // NOLINTNEXTLINE(modernize-use-override) |
50 | void set_withstack(bool withStack) override { |
51 | withStack_ = withStack; |
52 | } |
53 | |
54 | private: |
55 | bool reportInputShapes_{true}; |
56 | bool withStack_{false}; |
57 | }; |
58 | |
59 | } // namespace |
60 | |
61 | } // namespace impl |
62 | } // namespace profiler |
63 | |
#if ENABLE_GLOBAL_OBSERVER
namespace {

// True whenever KINETO_USE_DAEMON is present in the environment, whatever its
// value (an empty string or "0" also counts as present).
bool kinetoDaemonRequested() {
  return std::getenv("KINETO_USE_DAEMON") != nullptr;
}

// Registers a LibKinetoClient with libkineto while this translation unit's
// namespace-scope objects are being initialized (see the instance declared
// right after the struct body).
struct RegisterLibKinetoClient {
  RegisterLibKinetoClient() {
    // Function-local static so the registered object outlives this
    // constructor; presumably libkineto keeps the pointer passed to
    // registerClient() — confirm against libkineto's ClientInterface docs.
    static profiler::impl::LibKinetoClient client;

    // When the daemon is requested, bring libkineto up eagerly (not CPU-only,
    // log on error) and quiet its log output.
    if (kinetoDaemonRequested()) {
      libkineto_init(/*cpuOnly=*/false, /*logOnError=*/true);
      libkineto::api().suppressLogMessages();
    }

    libkineto::api().registerClient(&client);
  }
} register_libkineto_client;

} // namespace
#endif
82 | |
83 | } // namespace torch |
84 | #endif // USE_KINETO |
85 |