1#pragma once
2
#include "taichi/rhi/arch.h"
#include "taichi/util/lang_util.h"

#include <algorithm>
#include <map>
#include <memory>
#include <regex>
#include <string>
#include <utility>
#include <vector>
12
13namespace taichi::lang {
14
/**
 * A single traced kernel launch as captured by a profiler backend.
 *
 * Plain aggregate: a default-constructed record has every numeric field set
 * to zero, an empty name and no metric values.
 */
struct KernelProfileTracedRecord {
  // ---- Kernel launch attributes ----
  int register_per_thread = 0;
  int shared_mem_per_block = 0;
  int grid_size = 0;
  int block_size = 0;
  int active_blocks_per_multiprocessor = 0;

  // ---- Timing ----
  float kernel_elapsed_time_in_ms = 0.0f;
  // Offset used when laying launches out on a timeline.
  // NOTE(review): the reference point ("base") is set by the backend.
  float time_since_base = 0.0f;

  // Name of the launched kernel.
  std::string name;
  // Values of the user-selected metrics.
  // NOTE(review): presumably ordered like the configured metric list.
  std::vector<float> metric_values;
};
28
/**
 * Aggregated timing statistics for one kernel, identified by `name`.
 *
 * All numeric fields start at zero. insert_record() and operator< are
 * defined out of line.
 */
struct KernelProfileStatisticalResult {
  std::string name;
  // NOTE(review): presumably the number of records inserted so far.
  int counter{0};
  double min{0.0};
  double max{0.0};
  double total{0.0};

  // `name` is a sink parameter: taken by value and moved into place, so
  // callers passing a temporary avoid a second string copy while lvalue
  // callers still get exactly one copy.
  explicit KernelProfileStatisticalResult(std::string name)
      : name(std::move(name)) {
  }

  // Adds one measurement `t` to the statistics (defined out of line).
  // NOTE(review): presumably updates counter/min/max/total — confirm.
  // TODO: replace `double t` with `KernelProfileTracedRecord record`.
  void insert_record(double t);

  // Ordering used when sorting results (defined out of line; the sort key
  // is not visible in this header).
  bool operator<(const KernelProfileStatisticalResult &o) const;
};
45
46class KernelProfilerBase {
47 protected:
48 std::vector<KernelProfileTracedRecord> traced_records_;
49 std::vector<KernelProfileStatisticalResult> statistical_results_;
50 double total_time_ms_{0};
51
52 public:
53 // Needed for the CUDA backend since we need to know which task to "stop"
54 using TaskHandle = void *;
55
56 virtual bool reinit_with_metrics(const std::vector<std::string> metrics) {
57 return false;
58 }; // public API for all backend, do not use TI_NOT_IMPLEMENTED;
59
60 virtual void clear() = 0;
61
62 virtual void sync() = 0;
63
64 virtual void update() = 0;
65
66 virtual bool set_profiler_toolkit(std::string toolkit_name) {
67 return false;
68 }
69
70 // TODO: remove start and always use start_with_handle
71 virtual void start(const std::string &kernel_name){TI_NOT_IMPLEMENTED};
72
73 virtual TaskHandle start_with_handle(const std::string &kernel_name){
74 TI_NOT_IMPLEMENTED};
75
76 static void profiler_start(KernelProfilerBase *profiler,
77 const char *kernel_name);
78
79 virtual void stop(){TI_NOT_IMPLEMENTED};
80
81 virtual void stop(TaskHandle){TI_NOT_IMPLEMENTED};
82
83 static void profiler_stop(KernelProfilerBase *profiler);
84
85 void query(const std::string &kernel_name,
86 int &counter,
87 double &min,
88 double &max,
89 double &avg);
90
91 std::vector<KernelProfileTracedRecord> get_traced_records() {
92 return traced_records_;
93 }
94
95 double get_total_time() const;
96
97 void insert_record(const std::string &kernel_name, double duration_ms);
98
99 virtual std::string get_device_name() {
100 std::string str(" ");
101 return str;
102 }
103
104 virtual ~KernelProfilerBase() {
105 }
106};
107
108std::unique_ptr<KernelProfilerBase> make_profiler(Arch arch, bool enable);
109
110} // namespace taichi::lang
111