1#pragma once
2
3#include "taichi/system/timeline.h"
4#include "taichi/program/kernel_profiler.h"
5#include "taichi/rhi/cuda/cuda_driver.h"
6#include "taichi/rhi/cuda/cuda_context.h"
7#include "taichi/rhi/cuda/cupti_toolkit.h"
8
9#include <string>
10#include <stdint.h>
11
12namespace taichi::lang {
13
// Identifies a profiling toolkit.
// The enumerator values are written out explicitly; they are the same values
// the implicit numbering would assign (0, 1, 2).
enum class ProfilingToolkit : int {
  undef = 0,
  event = 1,
  cupti = 2,
};
19
20class EventToolkit;
21
22// A CUDA kernel profiler
23class KernelProfilerCUDA : public KernelProfilerBase {
24 public:
25 explicit KernelProfilerCUDA(bool enable);
26
27 std::string get_device_name() override;
28
29 bool reinit_with_metrics(const std::vector<std::string> metrics) override;
30 void trace(KernelProfilerBase::TaskHandle &task_handle,
31 const std::string &kernel_name,
32 void *kernel,
33 uint32_t grid_size,
34 uint32_t block_size,
35 uint32_t dynamic_smem_size);
36 void sync() override;
37 void update() override;
38 void clear() override;
39 void stop(KernelProfilerBase::TaskHandle handle) override;
40
41 bool set_profiler_toolkit(std::string toolkit_name) override;
42
43 bool statistics_on_traced_records();
44
45 KernelProfilerBase::TaskHandle start_with_handle(
46 const std::string &kernel_name) override;
47
48 bool record_kernel_attributes(void *kernel,
49 uint32_t grid_size,
50 uint32_t block_size,
51 uint32_t dynamic_smem_size);
52
53 private:
54 ProfilingToolkit tool_ = ProfilingToolkit::undef;
55
56 // Instances of these toolkits may exist at the same time,
57 // but only one will be enabled.
58 std::unique_ptr<EventToolkit> event_toolkit_{nullptr};
59 std::unique_ptr<CuptiToolkit> cupti_toolkit_{nullptr};
60 std::vector<std::string> metric_list_;
61 uint32_t records_size_after_sync_{0};
62};
63
64// default profiling toolkit
65class EventToolkit {
66 public:
67 void update_record(uint32_t records_size_after_sync,
68 std::vector<KernelProfileTracedRecord> &traced_records);
69 KernelProfilerBase::TaskHandle start_with_handle(
70 const std::string &kernel_name);
71 void update_timeline(std::vector<KernelProfileTracedRecord> &traced_records);
72 void clear() {
73 event_records_.clear();
74 }
75
76 private:
77 struct EventRecord {
78 std::string name;
79 float kernel_elapsed_time_in_ms{0.0};
80 float time_since_base{0.0};
81 void *start_event{nullptr};
82 void *stop_event{nullptr};
83 };
84 float64 base_time_{0.0};
85 void *base_event_{nullptr};
86 // for cuEvent profiling, clear after sync()
87 std::vector<EventRecord> event_records_;
88
89 public:
90 EventRecord *get_current_event_record() {
91 return &(event_records_.back());
92 }
93 void *get_base_event() const {
94 return base_event_;
95 }
96};
97
98} // namespace taichi::lang
99