1 | /** |
2 | * This file is adapted from PyTorch/XLA |
3 | * https://github.com/pytorch/xla/blob/master/third_party/xla_client/metrics.h |
4 | */ |
5 | |
6 | #pragma once |
7 | |
8 | #include <atomic> |
9 | #include <functional> |
10 | #include <map> |
11 | #include <memory> |
12 | #include <mutex> |
13 | #include <string> |
14 | #include <vector> |
15 | |
16 | #include <c10/macros/Export.h> |
17 | |
18 | namespace torch { |
19 | namespace lazy { |
20 | |
21 | struct TORCH_API Sample { |
22 | Sample() = default; |
23 | Sample(int64_t timestamp_ns, double value) |
24 | : timestamp_ns(timestamp_ns), value(value) {} |
25 | |
26 | int64_t timestamp_ns = 0; |
27 | double value = 0; |
28 | }; |
29 | |
// Callback that renders a single metric value as text; the MetricFnValue,
// MetricFnBytes and MetricFnTime formatters in this header have this shape.
using MetricReprFn = std::function<std::string(double value)>;
31 | |
32 | // Class used to collect time-stamped numeric samples. The samples are stored in |
33 | // a circular buffer whose size can be configured at constructor time. |
34 | class TORCH_API MetricData { |
35 | public: |
36 | // Creates a new MetricData object with the internal circular buffer storing |
37 | // max_samples samples. The repr_fn argument allow to specify a function which |
38 | // pretty-prints a sample value. |
39 | MetricData(MetricReprFn repr_fn, size_t max_samples); |
40 | |
41 | // Returns the total values of all the samples being posted to this metric. |
42 | double Accumulator() const; |
43 | |
44 | size_t TotalSamples() const; |
45 | |
46 | void AddSample(int64_t timestamp_ns, double value); |
47 | |
48 | // Returns a vector with all the current samples, from the oldest to the |
49 | // newer. If accumulator is not nullptr, it will receive the current value of |
50 | // the metrics' accumulator (the sum of all posted values). If total_samples |
51 | // is not nullptr, it will receive the count of the posted values. |
52 | std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const; |
53 | |
54 | std::string Repr(double value) const { |
55 | return repr_fn_(value); |
56 | } |
57 | |
58 | void Reset(); |
59 | |
60 | bool IsValid() const { |
61 | return TotalSamples() > 0; |
62 | } |
63 | |
64 | private: |
65 | mutable std::mutex lock_; |
66 | MetricReprFn repr_fn_; |
67 | size_t count_ = 0; |
68 | std::vector<Sample> samples_; |
69 | double accumulator_ = 0.0; |
70 | }; |
71 | |
72 | // Counters are a very lightweight form of metrics which do not need to track |
73 | // sample time. |
74 | class TORCH_API CounterData { |
75 | public: |
76 | CounterData() : value_(0) {} |
77 | |
78 | void AddValue(int64_t value) { |
79 | value_ += value; |
80 | } |
81 | |
82 | int64_t Value() const { |
83 | return value_; |
84 | } |
85 | |
86 | void Reset() { |
87 | value_ = 0; |
88 | } |
89 | |
90 | bool IsValid() const { |
91 | return value_ > 0; |
92 | } |
93 | |
94 | private: |
95 | std::atomic<int64_t> value_; |
96 | }; |
97 | |
98 | class TORCH_API MetricsArena { |
99 | public: |
100 | static MetricsArena* Get(); |
101 | |
102 | void ResetCounters(); |
103 | void ResetMetrics(); |
104 | |
105 | // Registers a new metric in the global arena. |
106 | void RegisterMetric( |
107 | const std::string& name, |
108 | MetricReprFn repr_fn, |
109 | size_t max_samples, |
110 | std::shared_ptr<MetricData>* data); |
111 | |
112 | void RegisterCounter( |
113 | const std::string& name, |
114 | std::shared_ptr<CounterData>* data); |
115 | |
116 | void ForEachMetric( |
117 | const std::function<void(const std::string&, MetricData*)>& metric_func); |
118 | |
119 | void ForEachCounter( |
120 | const std::function<void(const std::string&, CounterData*)>& |
121 | counter_func); |
122 | |
123 | std::vector<std::string> GetMetricNames(); |
124 | |
125 | MetricData* GetMetric(const std::string& name); |
126 | |
127 | std::vector<std::string> GetCounterNames(); |
128 | |
129 | CounterData* GetCounter(const std::string& name); |
130 | |
131 | private: |
132 | std::mutex lock_; |
133 | std::map<std::string, std::shared_ptr<MetricData>> metrics_; |
134 | std::map<std::string, std::shared_ptr<CounterData>> counters_; |
135 | }; |
136 | |
137 | // Emits the value in a to_string() conversion. |
138 | TORCH_API std::string MetricFnValue(double value); |
139 | // Emits the value in a humanized bytes representation. |
140 | TORCH_API std::string MetricFnBytes(double value); |
141 | // Emits the value in a humanized time representation. The value is expressed in |
142 | // nanoseconds EPOCH time. |
143 | TORCH_API std::string MetricFnTime(double value); |
144 | |
145 | // The typical use of a Metric is one in which it gets created either in a |
146 | // global scope context: |
147 | // static Metric* metric = new Metric("RpcCount"); |
148 | // Or within a function scope: |
149 | // void MyFunction(...) { |
150 | // static Metric* metric = new Metric("RpcCount"); |
151 | // ... |
152 | // metric->AddSample(ts_nanos, some_value); |
153 | // } |
154 | class TORCH_API Metric { |
155 | public: |
156 | explicit Metric( |
157 | std::string name, |
158 | MetricReprFn repr_fn = MetricFnValue, |
159 | size_t max_samples = 0); |
160 | |
161 | const std::string& Name() const { |
162 | return name_; |
163 | } |
164 | |
165 | double Accumulator() const; |
166 | |
167 | void AddSample(int64_t timestamp_ns, double value); |
168 | |
169 | void AddSample(double value); |
170 | |
171 | std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const; |
172 | |
173 | std::string Repr(double value) const; |
174 | |
175 | private: |
176 | MetricData* GetData() const; |
177 | |
178 | std::string name_; |
179 | MetricReprFn repr_fn_; |
180 | size_t max_samples_; |
181 | mutable std::shared_ptr<MetricData> data_ptr_; |
182 | mutable std::atomic<MetricData*> data_; |
183 | }; |
184 | |
185 | // A Counter is a lightweight form of metric which tracks an integer value which |
186 | // can increase or decrease. |
187 | // A typical use is as: |
188 | // static Counter* counter = new Counter("MyCounter"); |
189 | // ... |
190 | // counter->AddValue(+1); |
191 | class TORCH_API Counter { |
192 | public: |
193 | explicit Counter(std::string name); |
194 | |
195 | void AddValue(int64_t value) { |
196 | GetData()->AddValue(value); |
197 | } |
198 | |
199 | int64_t Value() const { |
200 | return GetData()->Value(); |
201 | } |
202 | |
203 | private: |
204 | CounterData* GetData() const; |
205 | |
206 | std::string name_; |
207 | mutable std::shared_ptr<CounterData> data_ptr_; |
208 | mutable std::atomic<CounterData*> data_; |
209 | }; |
210 | |
// Adds `value` to the counter called `name`. Each expansion site owns one
// function-local static Counter, created the first time that site executes;
// `name` is evaluated only at that point. The helper variable deliberately
// avoids double-underscore identifiers, which are reserved for the
// implementation.
#define TORCH_LAZY_COUNTER(name, value)                  \
  do {                                                   \
    static ::torch::lazy::Counter* torch_lazy_counter_ = \
        new ::torch::lazy::Counter(name);                \
    torch_lazy_counter_->AddValue(value);                \
  } while (0)

// Adds 1 to a counter named after `ns` followed by the enclosing function's
// name (__func__). Expands to ::c10::str, whose declaration this header does
// not include; the expansion site must make it visible.
#define TORCH_LAZY_FN_COUNTER(ns) \
  TORCH_LAZY_COUNTER(::c10::str(ns, __func__), 1)

// Posts `value` as a sample of the metric called `name`, formatted with
// MetricFnValue. As with TORCH_LAZY_COUNTER, one static Metric is created per
// expansion site on its first execution.
#define TORCH_LAZY_VALUE_METRIC(name, value)                           \
  do {                                                                 \
    static ::torch::lazy::Metric* torch_lazy_metric_ =                 \
        new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnValue); \
    torch_lazy_metric_->AddSample(value);                              \
  } while (0)
226 | |
227 | // Creates a report with the current metrics statistics. |
228 | TORCH_API std::string CreateMetricReport(); |
229 | |
230 | // Creates a report with the selected metrics statistics. |
231 | TORCH_API std::string CreateMetricReport( |
232 | const std::vector<std::string>& counter_names, |
233 | const std::vector<std::string>& metric_names); |
234 | |
235 | // Returns the currently registered metric names. Note that the list can grow |
236 | // since metrics are usually function intialized (they are static function |
237 | // variables). |
238 | TORCH_API std::vector<std::string> GetMetricNames(); |
239 | |
240 | // Retrieves the metric data of a given metric, or nullptr if such metric does |
241 | // not exist. |
242 | TORCH_API MetricData* GetMetric(const std::string& name); |
243 | |
244 | // Returns the currently registered counter names. Note that the list can grow |
245 | // since counters are usually function intialized (they are static function |
246 | // variables). |
247 | TORCH_API std::vector<std::string> GetCounterNames(); |
248 | |
249 | // Retrieves the counter data of a given counter, or nullptr if such counter |
250 | // does not exist. |
251 | TORCH_API CounterData* GetCounter(const std::string& name); |
252 | |
253 | // Retrieves the current EPOCH time in nanoseconds. |
254 | TORCH_API int64_t NowNs(); |
255 | |
256 | // Scope based utility class TORCH_API to measure the time the code takes within |
257 | // a given C++ scope. |
258 | class TORCH_API TimedSection { |
259 | public: |
260 | explicit TimedSection(Metric* metric) : metric_(metric), start_(NowNs()) {} |
261 | |
262 | ~TimedSection() { |
263 | int64_t now = NowNs(); |
264 | metric_->AddSample(now, now - start_); |
265 | } |
266 | |
267 | double Elapsed() const { |
268 | return 1e-9 * static_cast<double>(NowNs() - start_); |
269 | } |
270 | |
271 | private: |
272 | Metric* metric_; |
273 | int64_t start_; |
274 | }; |
275 | |
// Times the rest of the enclosing scope. The elapsed nanoseconds are posted to
// a function-local static Metric called `name`, formatted with MetricFnTime.
// The macro declares two variables in the enclosing scope, `timed_metric` and
// `timed_section`, so it can be used at most once per scope. Code after it may
// call `timed_section.Elapsed()`. Names are fully qualified so that the macro
// also works inside namespaces that contain a nested `torch` namespace.
#define TORCH_LAZY_TIMED(name)                                      \
  static ::torch::lazy::Metric* timed_metric =                      \
      new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnTime); \
  ::torch::lazy::TimedSection timed_section(timed_metric)
280 | |
281 | } // namespace lazy |
282 | } // namespace torch |
283 | |