1 | #pragma once |
2 | |
#include "taichi/rhi/arch.h"
#include "taichi/util/lang_util.h"

#include <algorithm>
#include <map>
#include <memory>
#include <regex>
#include <string>
#include <utility>
#include <vector>
12 | |
13 | namespace taichi::lang { |
14 | |
/// One traced kernel launch as collected by a profiler backend.
///
/// Plain aggregate: every numeric field defaults to zero and the containers
/// default to empty, so a default-constructed record is a valid "blank" entry.
struct KernelProfileTracedRecord {
  // Kernel attributes.
  int register_per_thread = 0;
  int shared_mem_per_block = 0;
  int grid_size = 0;
  int block_size = 0;
  int active_blocks_per_multiprocessor = 0;

  // Kernel timing.
  float kernel_elapsed_time_in_ms = 0.0f;
  float time_since_base = 0.0f;  // consumed by the Timeline view

  std::string name;                  // kernel name
  std::vector<float> metric_values;  // values of the user-selected metrics
};
28 | |
/// Aggregated timing statistics for one kernel name.
///
/// All numeric fields start at zero; `insert_record` and `operator<` are
/// defined out of line (not visible in this header).
struct KernelProfileStatisticalResult {
  std::string name;  // kernel the statistics belong to
  int counter{0};
  double min{0};
  double max{0};
  double total{0};

  // `name` is a sink parameter: taken by value and moved into the member so
  // callers passing a temporary avoid an extra string copy. Lvalue callers
  // still get a copy and keep their argument intact.
  explicit KernelProfileStatisticalResult(std::string name)
      : name(std::move(name)) {
  }

  // Adds one sample `t` to the statistics (definition not visible here).
  // TODO replace `double t` with `KernelProfileTracedRecord record`
  void insert_record(double t);

  bool operator<(const KernelProfileStatisticalResult &o) const;
};
45 | |
46 | class KernelProfilerBase { |
47 | protected: |
48 | std::vector<KernelProfileTracedRecord> traced_records_; |
49 | std::vector<KernelProfileStatisticalResult> statistical_results_; |
50 | double total_time_ms_{0}; |
51 | |
52 | public: |
53 | // Needed for the CUDA backend since we need to know which task to "stop" |
54 | using TaskHandle = void *; |
55 | |
56 | virtual bool reinit_with_metrics(const std::vector<std::string> metrics) { |
57 | return false; |
58 | }; // public API for all backend, do not use TI_NOT_IMPLEMENTED; |
59 | |
60 | virtual void clear() = 0; |
61 | |
62 | virtual void sync() = 0; |
63 | |
64 | virtual void update() = 0; |
65 | |
66 | virtual bool set_profiler_toolkit(std::string toolkit_name) { |
67 | return false; |
68 | } |
69 | |
70 | // TODO: remove start and always use start_with_handle |
71 | virtual void start(const std::string &kernel_name){TI_NOT_IMPLEMENTED}; |
72 | |
73 | virtual TaskHandle start_with_handle(const std::string &kernel_name){ |
74 | TI_NOT_IMPLEMENTED}; |
75 | |
76 | static void profiler_start(KernelProfilerBase *profiler, |
77 | const char *kernel_name); |
78 | |
79 | virtual void stop(){TI_NOT_IMPLEMENTED}; |
80 | |
81 | virtual void stop(TaskHandle){TI_NOT_IMPLEMENTED}; |
82 | |
83 | static void profiler_stop(KernelProfilerBase *profiler); |
84 | |
85 | void query(const std::string &kernel_name, |
86 | int &counter, |
87 | double &min, |
88 | double &max, |
89 | double &avg); |
90 | |
91 | std::vector<KernelProfileTracedRecord> get_traced_records() { |
92 | return traced_records_; |
93 | } |
94 | |
95 | double get_total_time() const; |
96 | |
97 | void insert_record(const std::string &kernel_name, double duration_ms); |
98 | |
99 | virtual std::string get_device_name() { |
100 | std::string str(" " ); |
101 | return str; |
102 | } |
103 | |
104 | virtual ~KernelProfilerBase() { |
105 | } |
106 | }; |
107 | |
108 | std::unique_ptr<KernelProfilerBase> make_profiler(Arch arch, bool enable); |
109 | |
110 | } // namespace taichi::lang |
111 | |